Commit f431756f by Ting PAN

Update with the new frontend API

Summary:
The new frontend unifies the two execution modes while starting from
a single tensor class. In addition, it emits operator execution through
a common path that works for both dragon and torch.
1 parent 6bfe3e73
Showing with 1664 additions and 1580 deletions
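For context, a minimal sketch of how the unified frontend is expected to be used. The names come from the docs touched by this diff (`dragon.constant`, `dragon.Tensor`, `dragon.eager_mode`, `dragon.graph_mode`); the exact call sites below are an illustration, not part of this commit:

```python
# Sketch only: illustrates the "single tensor class, two execution modes"
# idea described in the summary. Names are assumed from the docs in this diff.
import dragon

# One tensor class serves both execution modes.
x = dragon.constant([1., 2., 3.])

# Eager mode: operators execute immediately through the common emission path.
with dragon.eager_mode():
    y = x + 1.

# Graph mode: the same operators are recorded symbolically and executed later,
# e.g. by a compiled function.
with dragon.graph_mode():
    z = x + 1.
```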
...@@ -16,10 +16,8 @@ from __future__ import print_function ...@@ -16,10 +16,8 @@ from __future__ import print_function
import numpy import numpy
from dragon.core.autograph.tensor import TensorRef from dragon.core.autograph import context
from dragon.core.eager import context as eager_context from dragon.core.framework.tensor import Tensor
from dragon.core.framework import context
from dragon.core.util import logging
from dragon.vm.caffe.core.proto import caffe_pb2 from dragon.vm.caffe.core.proto import caffe_pb2
...@@ -36,20 +34,10 @@ class Layer(object): ...@@ -36,20 +34,10 @@ class Layer(object):
""" """
self._proto = layer_param self._proto = layer_param
self._name = layer_param.name self._bottom_names = [name for name in layer_param.bottom]
self._arguments, self.arguments = {'name': 'output'}, {} self._top_names = [name for name in layer_param.top]
# Store the inputs, outputs and trainable parameters. self._blobs = []
self._bottom, self._top, self._blobs = [], [], [] self._call_layer = None
for blob in layer_param.bottom:
self._bottom.append(blob)
for blob in layer_param.top:
self._top.append(blob)
# Store the loss weight to apply gradients.
self._loss_weight = layer_param.loss_weight \
if len(layer_param.loss_weight) > 0 else None
# Optional mirror stage argument for memory optimization.
if layer_param.HasField('mirror_stage'):
self._arguments['mirror_stage'] = layer_param.mirror_stage
@property @property
def blobs(self): def blobs(self):
...@@ -59,36 +47,37 @@ class Layer(object): ...@@ -59,36 +47,37 @@ class Layer(object):
@property @property
def bottom(self): def bottom(self):
"""Return the bottom names.""" """Return the bottom names."""
return self._bottom return self._bottom_names
@property
def loss_weight(self):
"""Return the loss weight."""
return self._loss_weight
@property @property
def name(self): def name(self):
"""Return the layer name.""" """Return the layer name."""
return self._name return self._proto.name
@property @property
def top(self): def top(self):
"""Return the top names.""" """Return the top names."""
return self._top return self._top_names
def add_blob(self, value=None, filler=None, no_grad=False): def add_blob(self, shape, filler, requires_grad=True):
"""Add a blob into this layer.""" """Add a blob into this layer."""
# Set the name for reference explicitly. data = Tensor(shape, name='blob%d' % (len(self._blobs) + 1))
data_name = context.get_name_scope() + 'param:{}'.format(len(self._blobs)) if filler.type == 'constant':
data, diff = TensorRef(data_name), TensorRef(data_name + '_grad') data.fill(filler.value)
if filler is not None: elif filler.type == 'gaussian':
data._register_as(**filler) data.normal(filler.mean, filler.std)
elif filler.type == 'uniform':
data.uniform(filler.min, filler.max)
elif filler.type == 'xavier':
norm_modes = {0: 'fan_in', 1: 'fan_out', 2: 'fan_avg'}
data.glorot_uniform(norm_modes[filler.variance_norm])
elif filler.type == 'msra':
norm_modes = {0: 'fan_in', 1: 'fan_out', 2: 'fan_avg'}
data.glorot_normal(norm_modes[filler.variance_norm])
else: else:
# Register a constant filler by default. raise ValueError('Unknown filler type: ' + filler.type)
value = value if value else 0 data.requires_grad = requires_grad
data.constant(value=value) self._blobs.append({'data': data, 'diff': None})
# Append to the blobs.
self._blobs.append({'data': data, 'diff': None if no_grad else diff})
def from_proto(self, proto): def from_proto(self, proto):
"""Deserialize from the proto. """Deserialize from the proto.
...@@ -110,16 +99,14 @@ class Layer(object): ...@@ -110,16 +99,14 @@ class Layer(object):
raise ValueError('Neither <data> or <double_data> in blob proto.') raise ValueError('Neither <data> or <double_data> in blob proto.')
if len(blob_proto.shape.dim) > 0: if len(blob_proto.shape.dim) > 0:
value = value.reshape([dim for dim in blob_proto.shape.dim]) value = value.reshape([dim for dim in blob_proto.shape.dim])
self._blobs[i]['data'].set_value(value) self._blobs[i]['data']._impl.FromNumpy(value, False)
logging.info('Blob({}/param:{}) loaded, shape: {}, size: {}'
.format(self._name, i, value.shape, value.size))
def setup(self, bottom): def setup(self, bottom):
"""Setup the layer.""" """Setup the layer."""
self.arguments = dict(self.arguments, **self._arguments)
bottom = bottom[0] if len(bottom) == 1 else bottom bottom = bottom[0] if len(bottom) == 1 else bottom
with eager_context.graph_mode(): with context.graph_mode():
return self.__call__(bottom) call_layer = self._call_layer or self
return call_layer.__call__(bottom)
def to_proto(self): def to_proto(self):
"""Serialize to the proto. """Serialize to the proto.
...@@ -133,7 +120,7 @@ class Layer(object): ...@@ -133,7 +120,7 @@ class Layer(object):
proto = caffe_pb2.LayerParameter() proto = caffe_pb2.LayerParameter()
proto.CopyFrom(self._proto) proto.CopyFrom(self._proto)
for blob in self._blobs: for blob in self._blobs:
value = blob['data'].get_value() value = blob['data'].numpy()
if str(value.dtype) == 'float32': if str(value.dtype) == 'float32':
blob_proto = caffe_pb2.BlobProto( blob_proto = caffe_pb2.BlobProto(
data=value.flatten(), data=value.flatten(),
...@@ -147,21 +134,6 @@ class Layer(object): ...@@ -147,21 +134,6 @@ class Layer(object):
proto.blobs.extend([blob_proto]) proto.blobs.extend([blob_proto])
return proto return proto
@staticmethod
def get_filler(proto, filler_name):
"""Return the filler from proto."""
if proto.HasField(filler_name):
filler = getattr(proto, filler_name)
return {
'type': filler.type.lower(),
'value': filler.value,
'low': filler.min,
'high': filler.max,
'mean': filler.mean,
'std': filler.std,
}
return None
def __call__(self, bottom): def __call__(self, bottom):
"""Define the forward pipeline.""" """Define the forward pipeline."""
raise NotImplementedError raise NotImplementedError
...@@ -48,7 +48,5 @@ from dragon.vm.caffe.core.layers.vision import Convolution ...@@ -48,7 +48,5 @@ from dragon.vm.caffe.core.layers.vision import Convolution
from dragon.vm.caffe.core.layers.vision import Deconvolution from dragon.vm.caffe.core.layers.vision import Deconvolution
from dragon.vm.caffe.core.layers.vision import LRN from dragon.vm.caffe.core.layers.vision import LRN
from dragon.vm.caffe.core.layers.vision import Pooling from dragon.vm.caffe.core.layers.vision import Pooling
from dragon.vm.caffe.core.layers.vision import ROIAlign
from dragon.vm.caffe.core.layers.vision import ROIPooling
__all__ = [_s for _s in dir() if not _s.startswith('_')] __all__ = [_s for _s in dir() if not _s.startswith('_')]
...@@ -23,18 +23,18 @@ from dragon.vm.caffe.core.layer import Layer ...@@ -23,18 +23,18 @@ from dragon.vm.caffe.core.layer import Layer
class _DataPlugin(object): class _DataPlugin(object):
"""Embedded plugin for **Data** layer.""" """Embedded plugin for data layer."""
def setup(self, inputs, outputs): def setup(self, inputs, outputs):
kwargs = eval(self.kwargs_str) kwargs = eval(self.kwargs_str)
self.iterator = vision.DataIterator( default_ws = workspace.get_workspace()
dataset=KPLRecordDataset, **kwargs) self.outputs = [default_ws.get_tensor(output) for output in outputs]
self.iterator = vision.DataIterator(dataset=KPLRecordDataset, **kwargs)
def forward(self, inputs, outputs): def forward(self, inputs, outputs):
blobs = self.iterator.next() blobs = self.iterator.next()
current_ws = workspace.get_workspace()
for i, blob in enumerate(blobs): for i, blob in enumerate(blobs):
current_ws.feed_tensor(outputs[i], blob) self.outputs[i].FromNumpy(blob)
class Data(Layer): class Data(Layer):
...@@ -118,8 +118,8 @@ class Data(Layer): ...@@ -118,8 +118,8 @@ class Data(Layer):
'num_outputs': 2, 'num_outputs': 2,
} }
data, label = framework_ops.python_plugin([], **args) data, label = framework_ops.python_plugin([], **args)
data.shape = (self.data_args['batch_size'], data._shape = (self.data_args['batch_size'],
None, None, len(self.norm_args['mean'])) None, None, len(self.norm_args['mean']))
label.shape = (self.data_args['batch_size'], None) label._shape = (self.data_args['batch_size'], None)
data = array_ops.channel_normalize(data, **self.norm_args) data = array_ops.channel_normalize(data, **self.norm_args)
return data, label return data, label
...@@ -51,16 +51,17 @@ class EuclideanLoss(Layer): ...@@ -51,16 +51,17 @@ class EuclideanLoss(Layer):
reduction = 'mean' reduction = 'mean'
else: else:
reduction = norm_dict[param.normalization] reduction = norm_dict[param.normalization]
self.arguments = {'reduction': reduction} self.call_args = {'reduction': reduction}
self.loss_weight = (layer_param.loss_weight or [1])[0]
def __call__(self, bottom): def __call__(self, bottom):
loss = loss_ops.l2_loss(bottom, **self.arguments) loss = loss_ops.l2_loss(bottom, **self.call_args)
loss_weight = 1. if self.loss_weight is None else self.loss_weight loss_weight = 1. if self.loss_weight is None else self.loss_weight
return loss * (loss_weight * 0.5) return loss * (loss_weight * 0.5)
class SigmoidCrossEntropyLoss(Layer): class SigmoidCrossEntropyLoss(Layer):
r"""Compute the sigmoid cross entropy with contiguous targets. """Compute the loss of sigmoid cross entropy.
Examples: Examples:
...@@ -88,11 +89,12 @@ class SigmoidCrossEntropyLoss(Layer): ...@@ -88,11 +89,12 @@ class SigmoidCrossEntropyLoss(Layer):
reduction = 'batch_mean' reduction = 'batch_mean'
else: else:
reduction = norm_dict[param.normalization] reduction = norm_dict[param.normalization]
self.arguments = {'reduction': reduction} self.call_args = {'reduction': reduction}
self.loss_weight = (layer_param.loss_weight or [1])[0]
def __call__(self, bottom): def __call__(self, bottom):
loss = loss_ops.sigmoid_cross_entropy(bottom, **self.arguments) loss = loss_ops.sigmoid_cross_entropy_loss(bottom, **self.call_args)
if self.loss_weight is not None: if self.loss_weight != 1:
loss *= self.loss_weight loss *= self.loss_weight
return loss return loss
...@@ -131,24 +133,18 @@ class SmoothL1Loss(Layer): ...@@ -131,24 +133,18 @@ class SmoothL1Loss(Layer):
else: else:
reduction = norm_dict[param.normalization] reduction = norm_dict[param.normalization]
sigma2 = smooth_l1_param.sigma * smooth_l1_param.sigma sigma2 = smooth_l1_param.sigma * smooth_l1_param.sigma
self.arguments = { self.call_args = {'beta': float(1. / sigma2), 'reduction': reduction}
'beta': float(1. / sigma2), self.loss_weight = (layer_param.loss_weight or [1])[0]
'reduction': reduction,
}
def __call__(self, bottom): def __call__(self, bottom):
loss = loss_ops.smooth_l1_loss(bottom, **self.arguments) loss = loss_ops.smooth_l1_loss(bottom, **self.call_args)
if self.loss_weight is not None: if self.loss_weight != 1:
loss *= self.loss_weight loss *= self.loss_weight
return loss return loss
class SoftmaxWithLoss(Layer): class SoftmaxWithLoss(Layer):
r"""Compute the softmax cross entropy with sparse labels. """Compute the loss of softmax cross entropy.
The **CrossEntropy** function is defined as:
.. math:: \text{CrossEntropy}(p_{t}) = -\log(p_{t})
Examples: Examples:
...@@ -181,15 +177,16 @@ class SoftmaxWithLoss(Layer): ...@@ -181,15 +177,16 @@ class SoftmaxWithLoss(Layer):
reduction = 'batch_mean' reduction = 'batch_mean'
else: else:
reduction = norm_dict[param.normalization] reduction = norm_dict[param.normalization]
self.arguments = { self.call_args = {
'axis': softmax_param.axis, 'axis': softmax_param.axis,
'reduction': reduction, 'reduction': reduction,
'ignore_index': param.ignore_label 'ignore_index': param.ignore_label
if param.HasField('ignore_label') else None, if param.HasField('ignore_label') else None,
} }
self.loss_weight = (layer_param.loss_weight or [1])[0]
def __call__(self, bottom): def __call__(self, bottom):
loss = loss_ops.sparse_softmax_cross_entropy(bottom, **self.arguments) loss = loss_ops.softmax_cross_entropy_loss(bottom, **self.call_args)
if self.loss_weight is not None: if self.loss_weight != 1:
loss *= self.loss_weight loss *= self.loss_weight
return loss return loss
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
from dragon.core.ops import activation_ops from dragon.core.ops import activation_ops
from dragon.core.ops import math_ops from dragon.core.ops import math_ops
from dragon.vm.caffe.core.layer import Layer from dragon.vm.caffe.core.layer import Layer
from dragon.vm.caffe.core.proto import caffe_pb2
class Dropout(Layer): class Dropout(Layer):
...@@ -47,10 +48,10 @@ class Dropout(Layer): ...@@ -47,10 +48,10 @@ class Dropout(Layer):
param = layer_param.dropout_param param = layer_param.dropout_param
if not param.scale_train: if not param.scale_train:
raise ValueError('Unscaled dropout is not supported.') raise ValueError('Unscaled dropout is not supported.')
self.arguments = {'ratio': param.dropout_ratio} self.call_args = {'ratio': param.dropout_ratio}
def __call__(self, bottom): def __call__(self, bottom):
return activation_ops.dropout(bottom, **self.arguments) return activation_ops.dropout(bottom, **self.call_args)
class ELU(Layer): class ELU(Layer):
...@@ -83,10 +84,10 @@ class ELU(Layer): ...@@ -83,10 +84,10 @@ class ELU(Layer):
def __init__(self, layer_param): def __init__(self, layer_param):
super(ELU, self).__init__(layer_param) super(ELU, self).__init__(layer_param)
self.arguments = {'alpha': float(layer_param.elu_param.alpha)} self.call_args = {'alpha': float(layer_param.elu_param.alpha)}
def __call__(self, bottom): def __call__(self, bottom):
return activation_ops.elu(bottom, **self.arguments) return activation_ops.elu(bottom, **self.call_args)
class Power(Layer): class Power(Layer):
...@@ -123,7 +124,7 @@ class Power(Layer): ...@@ -123,7 +124,7 @@ class Power(Layer):
bottom = bottom * self.scale bottom = bottom * self.scale
if self.shift != 0: if self.shift != 0:
bottom = bottom + self.shift bottom = bottom + self.shift
return math_ops.pow([bottom, self.power], **self.arguments) return math_ops.pow([bottom, self.power])
class PReLU(Layer): class PReLU(Layer):
...@@ -163,15 +164,24 @@ class PReLU(Layer): ...@@ -163,15 +164,24 @@ class PReLU(Layer):
def __init__(self, layer_param): def __init__(self, layer_param):
super(PReLU, self).__init__(layer_param) super(PReLU, self).__init__(layer_param)
param = layer_param.prelu_param param = layer_param.prelu_param
self.arguments = { self.filler = caffe_pb2.FillerParameter(type='constant', value=0.25)
'channel_shared': param.channel_shared, self.filler = param.filler if param.HasField('filler') else self.filler
'data_format': 'NCHW', self.channel_shared = param.channel_shared
}
self.add_blob(filler=self.get_filler(param, 'filler'), value=0.25) def build(self, bottom):
if self.channel_shared:
weight_shape = [1]
elif len(bottom.shape) > 1:
weight_shape = [bottom.shape[1]]
else:
weight_shape = [bottom.shape[0]]
self.add_blob(weight_shape, self.filler)
def __call__(self, bottom): def __call__(self, bottom):
if len(self.blobs) == 0:
self.build(bottom)
inputs = [bottom] + [blob['data'] for blob in self._blobs] inputs = [bottom] + [blob['data'] for blob in self._blobs]
return activation_ops.prelu(inputs, **self.arguments) return activation_ops.prelu(inputs)
class ReLU(Layer): class ReLU(Layer):
...@@ -205,11 +215,12 @@ class ReLU(Layer): ...@@ -205,11 +215,12 @@ class ReLU(Layer):
def __init__(self, layer_param): def __init__(self, layer_param):
super(ReLU, self).__init__(layer_param) super(ReLU, self).__init__(layer_param)
param = layer_param.relu_param param = layer_param.relu_param
if param.HasField('negative_slope'): self.negative_slope = param.negative_slope
self.arguments = {'alpha': param.negative_slope}
def __call__(self, bottom): def __call__(self, bottom):
return activation_ops.relu(bottom, **self.arguments) if self.negative_slope > 0:
return activation_ops.leaky_relu(bottom, self.negative_slope)
return activation_ops.relu(bottom)
class Sigmoid(Layer): class Sigmoid(Layer):
...@@ -235,7 +246,7 @@ class Sigmoid(Layer): ...@@ -235,7 +246,7 @@ class Sigmoid(Layer):
super(Sigmoid, self).__init__(layer_param) super(Sigmoid, self).__init__(layer_param)
def __call__(self, bottom): def __call__(self, bottom):
return activation_ops.sigmoid(bottom, **self.arguments) return activation_ops.sigmoid(bottom)
class TanH(Layer): class TanH(Layer):
...@@ -261,4 +272,4 @@ class TanH(Layer): ...@@ -261,4 +272,4 @@ class TanH(Layer):
super(TanH, self).__init__(layer_param) super(TanH, self).__init__(layer_param)
def __call__(self, bottom): def __call__(self, bottom):
return activation_ops.tanh(bottom, **self.arguments) return activation_ops.tanh(bottom)
...@@ -53,33 +53,39 @@ class Convolution(Layer): ...@@ -53,33 +53,39 @@ class Convolution(Layer):
def __init__(self, layer_param): def __init__(self, layer_param):
super(Convolution, self).__init__(layer_param) super(Convolution, self).__init__(layer_param)
param = layer_param.convolution_param param = layer_param.convolution_param
self.arguments = { self.kernel_shape = param.kernel_size or [1]
'out_channels': param.num_output, self.strides = param.stride or [1]
'kernel_shape': [int(e) for e in param.kernel_size], self.pads = param.pad or [0]
'strides': [int(e) for e in param.stride] if len(param.stride) > 0 else [1], self.dilations = param.dilation or [1]
'pads': [int(e) for e in param.pad] if len(param.pad) > 0 else [0], self.out_channels = param.num_output
'dilations': [int(e) for e in param.dilation] if len(param.dilation) > 0 else [1], self.weight_filler = param.weight_filler
'group': int(param.group), self.bias_filler = param.bias_filler
'padding': 'VALID', self.bias_term = param.bias_term
'data_format': 'NCHW', self.call_args = {'group': param.group}
}
if param.HasField('kernel_h'): def build(self, bottom):
assert param.HasField('kernel_w') num_axes = len(bottom.shape) - 2
self.arguments['kernel_shape'] = [param.kernel_h, param.kernel_w] if num_axes < 1:
if param.HasField('stride_h'): raise ValueError(
assert param.HasField('stride_w') 'Bottom 0 of layer "{}" is {}d, expected 3d/4d/5d.'
self.arguments['strides'] = [param.stride_h, param.stride_w] .format(self.name, len(bottom.shape)))
if param.HasField('pad_h'): in_channels = bottom.shape[1]
assert param.HasField('pad_w') for k in ('kernel_shape', 'strides', 'pads', 'dilations'):
self.arguments['pads'] = [param.pad_h, param.pad_w] self.call_args[k] = [int(d) for d in getattr(self, k)]
self.add_blob(filler=self.get_filler(param, 'weight_filler')) if len(self.call_args[k]) < num_axes:
if param.bias_term: reps = num_axes - len(self.call_args[k])
self.add_blob(filler=self.get_filler(param, 'bias_filler')) self.call_args[k] += [self.call_args[k][-1]] * reps
weight_shape = [self.out_channels, in_channels] + self.call_args['kernel_shape']
self.add_blob(weight_shape, self.weight_filler)
if self.bias_term:
self.add_blob([self.out_channels], self.bias_filler)
def __call__(self, bottom): def __call__(self, bottom):
if len(self.blobs) == 0:
self.build(bottom)
inputs = [bottom] + [blob['data'] for blob in self._blobs] inputs = [bottom] + [blob['data'] for blob in self._blobs]
conv_op = 'conv{}d'.format(len(bottom.shape) - 2) conv_op = 'conv{}d'.format(len(self.call_args['kernel_shape']))
return getattr(vision_ops, conv_op)(inputs, **self.arguments) return getattr(vision_ops, conv_op)(inputs, **self.call_args)
class Deconvolution(Convolution): class Deconvolution(Convolution):
...@@ -116,10 +122,29 @@ class Deconvolution(Convolution): ...@@ -116,10 +122,29 @@ class Deconvolution(Convolution):
def __init__(self, layer_param): def __init__(self, layer_param):
super(Deconvolution, self).__init__(layer_param) super(Deconvolution, self).__init__(layer_param)
def build(self, bottom):
num_axes = len(bottom.shape) - 2
if num_axes < 1:
raise ValueError(
'Bottom 0 of layer "{}" is {}d, expected 3d/4d/5d.'
.format(self.name, len(bottom.shape)))
in_channels = bottom.shape[1]
for k in ('kernel_shape', 'strides', 'pads', 'dilations'):
self.call_args[k] = [int(d) for d in getattr(self, k)]
if len(self.call_args[k]) < num_axes:
reps = num_axes - len(self.call_args[k])
self.call_args[k] += [self.call_args[k][-1]] * reps
weight_shape = [in_channels, self.out_channels] + self.call_args['kernel_shape']
self.add_blob(weight_shape, self.weight_filler)
if self.bias_term:
self.add_blob([self.out_channels], self.bias_filler)
def __call__(self, bottom): def __call__(self, bottom):
if len(self.blobs) == 0:
self.build(bottom)
inputs = [bottom] + [blob['data'] for blob in self._blobs] inputs = [bottom] + [blob['data'] for blob in self._blobs]
conv_op = 'conv{}d_transpose'.format(len(bottom.shape) - 2) conv_op = 'conv{}d_transpose'.format(len(self.call_args['kernel_shape']))
return getattr(vision_ops, conv_op)(inputs, **self.arguments) return getattr(vision_ops, conv_op)(inputs, **self.call_args)
class LRN(Layer): class LRN(Layer):
...@@ -148,17 +173,14 @@ class LRN(Layer): ...@@ -148,17 +173,14 @@ class LRN(Layer):
super(LRN, self).__init__(layer_param) super(LRN, self).__init__(layer_param)
param = layer_param.lrn_param param = layer_param.lrn_param
if param.norm_region > 0: if param.norm_region > 0:
raise NotImplementedError('WITHIN_CHANNEL mode is not implemented.') raise NotImplementedError('<WITHIN_CHANNEL> mode is not implemented.')
self.arguments = { self.op_args = {'size': param.local_size,
'size': param.local_size, 'alpha': param.alpha,
'alpha': param.alpha, 'beta': param.beta,
'beta': param.beta, 'bias': param.k}
'bias': param.k,
'data_format': 'NCHW',
}
def __call__(self, bottom): def __call__(self, bottom):
return normalization_ops.local_response_norm(bottom, **self.arguments) return normalization_ops.local_response_norm(bottom, **self.op_args)
class Pooling(Layer): class Pooling(Layer):
...@@ -184,93 +206,26 @@ class Pooling(Layer): ...@@ -184,93 +206,26 @@ class Pooling(Layer):
def __init__(self, layer_param): def __init__(self, layer_param):
super(Pooling, self).__init__(layer_param) super(Pooling, self).__init__(layer_param)
param = layer_param.pooling_param param = layer_param.pooling_param
self.arguments = { self.kernel_shape = [param.kernel_size or 1]
self.strides = [param.stride or 1]
self.pads = [param.pad or 0]
self.call_args = {
'ceil_mode': True, 'ceil_mode': True,
'mode': {0: 'MAX', 1: 'AVG'}[param.pool], 'mode': {0: 'MAX', 1: 'AVG'}[param.pool],
'data_format': 'NCHW',
'global_pool': param.global_pooling, 'global_pool': param.global_pooling,
} }
if not param.HasField('kernel_h'):
self.arguments['kernel_shape'] = [param.kernel_size]
else:
self.arguments['kernel_shape'] = [param.kernel_h, param.kernel_w]
if not param.HasField('pad_h'):
self.arguments['pads'] = [param.pad]
else:
self.arguments['pads'] = [param.pad_h, param.pad_w]
if not param.HasField('stride_h'):
self.arguments['strides'] = [param.stride]
else:
self.arguments['strides'] = [param.stride_h, param.stride_w]
def __call__(self, bottom):
pool_op = 'pool{}d'.format(len(bottom.shape) - 2)
return getattr(vision_ops, pool_op)(bottom, **self.arguments)
class ROIAlign(Layer):
r"""Apply the average roi align.
`[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
Examples:
```python
layer {
type: "ROIAlign"
bottom: "conv5_3"
top: "roi_pool4"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625
}
}
```
"""
def __init__(self, layer_param):
super(ROIAlign, self).__init__(layer_param)
param = layer_param.roi_pooling_param
self.arguments = {
'pool_h': int(param.pooled_h),
'pool_w': int(param.pooled_w),
'spatial_scale': param.spatial_scale,
}
def __call__(self, bottom):
return vision_ops.roi_align(bottom, **self.arguments)
class ROIPooling(Layer):
r"""Apply the max roi pooling.
`[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
Examples:
```python
layer {
type: "ROIPooling"
bottom: "conv5_3"
top: "roi_pool4"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625
}
}
```
"""
def __init__(self, layer_param):
super(ROIPooling, self).__init__(layer_param)
param = layer_param.roi_pooling_param
self.arguments = {
'pool_h': int(param.pooled_h),
'pool_w': int(param.pooled_w),
'spatial_scale': param.spatial_scale,
}
def __call__(self, bottom): def __call__(self, bottom):
return vision_ops.roi_pool(bottom, **self.arguments) num_axes = len(bottom.shape) - 2
if num_axes < 1:
raise ValueError(
'Bottom 0 of layer "{}" is {}d, expected 3d/4d/5d.'
.format(self.name, len(bottom.shape)))
call_args = self.call_args.copy()
for k in ('kernel_shape', 'strides', 'pads'):
call_args[k] = [int(d) for d in getattr(self, k)]
if len(call_args[k]) < num_axes:
reps = num_axes - len(call_args[k])
call_args[k] += [call_args[k][-1]] * reps
pool_op = 'pool{}d'.format(num_axes)
return getattr(vision_ops, pool_op)(bottom, **call_args)
...@@ -409,11 +409,10 @@ message LayerParameter { ...@@ -409,11 +409,10 @@ message LayerParameter {
optional ThresholdParameter threshold_param = 128; optional ThresholdParameter threshold_param = 128;
optional TileParameter tile_param = 138; optional TileParameter tile_param = 138;
optional WindowDataParameter window_data_param = 129; optional WindowDataParameter window_data_param = 129;
// Following parameters are extended on BVLC-Caffe. // Following parameters are extended on BVLC/caffe.
optional ROIPoolingParameter roi_pooling_param = 151; optional SmoothL1LossParameter smooth_l1_loss_param = 151;
optional SmoothL1LossParameter smooth_l1_loss_param = 152; optional PermuteParameter permute_param = 152;
optional PermuteParameter permute_param = 153; optional NormalizeParameter normalize_param = 153;
optional NormalizeParameter normalize_param = 154;
} }
// Message that stores parameters used to apply transformation // Message that stores parameters used to apply transformation
...@@ -931,17 +930,6 @@ message PoolingParameter { ...@@ -931,17 +930,6 @@ message PoolingParameter {
optional bool global_pooling = 12 [default = false]; optional bool global_pooling = 12 [default = false];
} }
// Message that stores parameters used by ROIPoolingLayer
message ROIPoolingParameter {
// Pad, kernel size, and stride are all given as a single value for equal
// dimensions in height and width or as Y, X pairs.
optional uint32 pooled_h = 1 [default = 0]; // The pooled output height
optional uint32 pooled_w = 2 [default = 0]; // The pooled output width
// Multiplicative spatial scale factor to translate ROI coords from their
// input scale to the scale used when pooling
optional float spatial_scale = 3 [default = 1];
}
message PowerParameter { message PowerParameter {
// PowerLayer computes outputs y = (shift + scale * x) ^ power. // PowerLayer computes outputs y = (shift + scale * x) ^ power.
optional float power = 1 [default = 1.0]; optional float power = 1 [default = 1.0];
......
...@@ -57,7 +57,7 @@ def get_device_type(mixed=False): ...@@ -57,7 +57,7 @@ def get_device_type(mixed=False):
Parameters Parameters
---------- ----------
mixed : bool, optional, default=False mixed : bool, optional, default=False
**True** to return ``mixed`` for gpu device. ``True`` to return ``mixed`` for gpu device.
Returns Returns
------- -------
......
...@@ -23,9 +23,9 @@ except ImportError: ...@@ -23,9 +23,9 @@ except ImportError:
TensorGPU = object TensorGPU = object
from dragon.core.device import cuda from dragon.core.device import cuda
from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import device_spec from dragon.core.framework import device_spec
from dragon.core.framework import workspace from dragon.core.framework import workspace
from dragon.core.framework.tensor import Tensor
from dragon.vm.dali.core.framework import types from dragon.vm.dali.core.framework import types
...@@ -152,7 +152,7 @@ class Iterator(object): ...@@ -152,7 +152,7 @@ class Iterator(object):
@staticmethod @staticmethod
def new_tensor(shape, dtype, device): def new_tensor(shape, dtype, device):
"""Return a new tensor abstraction.""" """Return a new tensor abstraction."""
return EagerTensor(shape=shape, dtype=dtype, device=device) return Tensor(shape=shape, dtype=dtype, device=device)
def __iter__(self): def __iter__(self):
"""Return the iterator self.""" """Return the iterator self."""
...@@ -206,12 +206,12 @@ class Iterator(object): ...@@ -206,12 +206,12 @@ class Iterator(object):
def _transfer_tensor(self, dali_tensor, target_tensor): def _transfer_tensor(self, dali_tensor, target_tensor):
"""Transfer the dali tensor to the target.""" """Transfer the dali tensor to the target."""
target_shape = dali_tensor.shape() target_shape = dali_tensor.shape()
device = target_tensor._device = \ device = self.new_device(
self.new_device( device_type='cuda' if isinstance(
device_type='cuda' if isinstance( dali_tensor, TensorGPU) else 'cpu',
dali_tensor, TensorGPU) else 'cpu', device_index=self._pipe.device_id)
device_index=self._pipe.device_id, if hasattr(target_tensor, '_device'):
) target_tensor._device = device
impl = target_tensor._impl impl = target_tensor._impl
if target_shape != list(target_tensor.shape): if target_shape != list(target_tensor.shape):
new_capacity = not impl.Reshape(target_shape) new_capacity = not impl.Reshape(target_shape)
......
...@@ -63,7 +63,7 @@ class ImageDecoder(object): ...@@ -63,7 +63,7 @@ class ImageDecoder(object):
""" """
if isinstance(output_type, six.string_types): if isinstance(output_type, six.string_types):
output_type = getattr(types, output_type) output_type = getattr(types, output_type)
return ops.ImageDecoder( return ops.decoders.Image(
output_type=output_type, output_type=output_type,
host_memory_padding=host_memory_padding, host_memory_padding=host_memory_padding,
device_memory_padding=device_memory_padding, device_memory_padding=device_memory_padding,
...@@ -124,7 +124,7 @@ class ImageDecoderRandomCrop(object): ...@@ -124,7 +124,7 @@ class ImageDecoderRandomCrop(object):
""" """
if isinstance(output_type, six.string_types): if isinstance(output_type, six.string_types):
output_type = getattr(types, output_type) output_type = getattr(types, output_type)
return ops.ImageDecoderRandomCrop( return ops.decoders.ImageRandomCrop(
output_type=output_type, output_type=output_type,
host_memory_padding=host_memory_padding, host_memory_padding=host_memory_padding,
device_memory_padding=device_memory_padding, device_memory_padding=device_memory_padding,
......
...@@ -316,7 +316,7 @@ class RandomBBoxCrop(object): ...@@ -316,7 +316,7 @@ class RandomBBoxCrop(object):
thresholds : Sequence[float], optional thresholds : Sequence[float], optional
The minimum IoU(s) to satisfy. The minimum IoU(s) to satisfy.
allow_no_crop : bool, optional, default=True allow_no_crop : bool, optional, default=True
**True** to include the no-cropping as an option. ``True`` to include the no-cropping as an option.
num_attempts : int, optional, default=10 num_attempts : int, optional, default=10
The max number of sampling trials. The max number of sampling trials.
bbox_layout : str, optional, default='xyXY' bbox_layout : str, optional, default='xyXY'
......
...@@ -47,7 +47,7 @@ class CoinFlip(object): ...@@ -47,7 +47,7 @@ class CoinFlip(object):
The operator. The operator.
""" """
return ops.CoinFlip(probability=probability, **kwargs) return ops.random.CoinFlip(probability=probability, **kwargs)
class Uniform(object): class Uniform(object):
...@@ -76,4 +76,4 @@ class Uniform(object): ...@@ -76,4 +76,4 @@ class Uniform(object):
The operator. The operator.
""" """
return ops.Uniform(range=range, **kwargs) return ops.random.Uniform(range=range, **kwargs)
...@@ -218,7 +218,7 @@ class TFRecordReader(object): ...@@ -218,7 +218,7 @@ class TFRecordReader(object):
""" """
path, index_path, features = cls.check_files(path) path, index_path, features = cls.check_files(path)
return ops.TFRecordReader( return ops.readers.TFRecord(
path=path, path=path,
index_path=index_path, index_path=index_path,
shard_id=shard_id, shard_id=shard_id,
......
...@@ -15,6 +15,10 @@ Buffer ...@@ -15,6 +15,10 @@ Buffer
###### ######
.. doxygenfunction:: dragon::Operator::Buffer .. doxygenfunction:: dragon::Operator::Buffer
DeriveFrom
##########
.. doxygenfunction:: dragon::Operator::DeriveFrom
Fuse Fuse
#### ####
.. doxygenfunction:: dragon::Operator::Fuse .. doxygenfunction:: dragon::Operator::Fuse
...@@ -55,10 +59,6 @@ Run ...@@ -55,10 +59,6 @@ Run
### ###
.. doxygenfunction:: dragon::Operator::Run .. doxygenfunction:: dragon::Operator::Run
UpdateFrom
##########
.. doxygenfunction:: dragon::Operator::UpdateFrom
data_format data_format
########### ###########
.. doxygenfunction:: dragon::Operator::data_format .. doxygenfunction:: dragon::Operator::data_format
......
...@@ -23,10 +23,6 @@ CreateTensor ...@@ -23,10 +23,6 @@ CreateTensor
############ ############
.. doxygenfunction:: dragon::Workspace::CreateTensor .. doxygenfunction:: dragon::Workspace::CreateTensor
GetFillerInfo
#############
.. doxygenfunction:: dragon::Workspace::GetFillerInfo
GetTensor GetTensor
######### #########
.. doxygenfunction:: dragon::Workspace::GetTensor .. doxygenfunction:: dragon::Workspace::GetTensor
...@@ -39,14 +35,6 @@ MergeFrom ...@@ -39,14 +35,6 @@ MergeFrom
######### #########
.. doxygenfunction:: dragon::Workspace::MergeFrom .. doxygenfunction:: dragon::Workspace::MergeFrom
RegisterAlias
#############
.. doxygenfunction:: dragon::Workspace::RegisterAlias
ResetTensor
###########
.. doxygenfunction:: dragon::Workspace::ResetTensor
RunGraph RunGraph
######## ########
.. doxygenfunction:: dragon::Workspace::RunGraph .. doxygenfunction:: dragon::Workspace::RunGraph
...@@ -55,6 +43,10 @@ RunOperator ...@@ -55,6 +43,10 @@ RunOperator
########### ###########
.. doxygenfunction:: dragon::Workspace::RunOperator .. doxygenfunction:: dragon::Workspace::RunOperator
SetAlias
########
.. doxygenfunction:: dragon::Workspace::SetAlias
TryGetTensor TryGetTensor
############ ############
.. doxygenfunction:: dragon::Workspace::TryGetTensor .. doxygenfunction:: dragon::Workspace::TryGetTensor
......
...@@ -29,10 +29,6 @@ params ...@@ -29,10 +29,6 @@ params
Methods Methods
------- -------
backward
########
.. automethod:: dragon.vm.caffe.Net.backward
copy_from copy_from
######### #########
.. automethod:: dragon.vm.caffe.Net.copy_from .. automethod:: dragon.vm.caffe.Net.copy_from
...@@ -41,10 +37,6 @@ forward ...@@ -41,10 +37,6 @@ forward
######### #########
.. automethod:: dragon.vm.caffe.Net.forward .. automethod:: dragon.vm.caffe.Net.forward
forward_backward
################
.. automethod:: dragon.vm.caffe.Net.forward_backward
save save
#### ####
.. automethod:: dragon.vm.caffe.Net.save .. automethod:: dragon.vm.caffe.Net.save
......
...@@ -88,14 +88,6 @@ vm.caffe.layers ...@@ -88,14 +88,6 @@ vm.caffe.layers
`class Reshape <layers/Reshape.html>`_ `class Reshape <layers/Reshape.html>`_
: Change the dimensions of input. : Change the dimensions of input.
`class ROIAlign <layers/ROIAlign.html>`_
: Apply the average roi align.
`[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
`class ROIPooling <layers/ROIPooling.html>`_
: Apply the max roi pooling.
`[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
`class Scale <layers/Scale.html>`_ `class Scale <layers/Scale.html>`_
: Compute the affine transformation along the given axes. : Compute the affine transformation along the given axes.
...@@ -149,8 +141,6 @@ vm.caffe.layers ...@@ -149,8 +141,6 @@ vm.caffe.layers
layers/Reduction layers/Reduction
layers/ReLU layers/ReLU
layers/Reshape layers/Reshape
layers/ROIAlign
layers/ROIPooling
layers/Scale layers/Scale
layers/Sigmoid layers/Sigmoid
layers/SigmoidCrossEntropyLoss layers/SigmoidCrossEntropyLoss
......
...@@ -34,7 +34,7 @@ extensions = [ ...@@ -34,7 +34,7 @@ extensions = [
'sphinx.ext.autodoc', 'sphinx.ext.autodoc',
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'sphinxcontrib.katex', 'sphinxcontrib.katex',
'sphinx_seeta_theme.ext.viewcode', # 'sphinx_seeta_theme.ext.viewcode',
] ]
napoleon_use_rtype = False napoleon_use_rtype = False
......
...@@ -6,17 +6,17 @@ dragon ...@@ -6,17 +6,17 @@ dragon
Classes Classes
------- -------
`class EagerTensor <dragon/EagerTensor.html>`_ `class DeviceSpec <dragon/DeviceSpec.html>`_
: Tensor abstraction for eager executing. : Describe a computation device.
`class GradientTape <dragon/GradientTape.html>`_ `class GradientTape <dragon/GradientTape.html>`_
: Record the operations for auto differentiation. : Record the operations for auto differentiation.
`class Tensor <dragon/Tensor.html>`_ `class Tensor <dragon/Tensor.html>`_
: Tensor abstraction for graph executing. : A multi-dimensional array for computation.
`class Workspace <dragon/Workspace.html>`_ `class Workspace <dragon/Workspace.html>`_
: Sandbox to isolate the resources and computations. : Standalone environment for resources and computations.
Functions Functions
--------- ---------
...@@ -27,6 +27,9 @@ dragon ...@@ -27,6 +27,9 @@ dragon
`assign(...) <dragon/assign.html>`_ `assign(...) <dragon/assign.html>`_
: Assign the value to input. : Assign the value to input.
`boolean_mask(...) <dragon/boolean_mask.html>`_
: Return the elements of input where mask is true.
`broadcast_to(...) <dragon/broadcast_to.html>`_ `broadcast_to(...) <dragon/broadcast_to.html>`_
: Broadcast input according to a given shape. : Broadcast input according to a given shape.
...@@ -34,13 +37,14 @@ dragon ...@@ -34,13 +37,14 @@ dragon
: Cast the data type of input. : Cast the data type of input.
`channel_affine(...) <dragon/channel_affine.html>`_ `channel_affine(...) <dragon/channel_affine.html>`_
: Apply affine transformation along the channels. : Apply affine transformation to each channel of input.
`channel_normalize(...) <dragon/channel_normalize.html>`_ `channel_normalize(...) <dragon/channel_normalize.html>`_
: Normalize channels with mean and standard deviation. : Apply normalization to each channel of input.
`channel_shuffle(...) <dragon/channel_shuffle.html>`_ `channel_shuffle(...) <dragon/channel_shuffle.html>`_
: Shuffle channels between a given number of groups. : Apply group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`concat(...) <dragon/concat.html>`_ `concat(...) <dragon/concat.html>`_
: Concatenate the inputs along the given axis. : Concatenate the inputs along the given axis.
...@@ -48,17 +52,14 @@ dragon ...@@ -48,17 +52,14 @@ dragon
`constant(...) <dragon/constant.html>`_ `constant(...) <dragon/constant.html>`_
: Return a tensor initialized from the value. : Return a tensor initialized from the value.
`create_function(...) <dragon/create_function.html>`_
: Create a callable graph from the specified outputs.
`device(...) <dragon/device.html>`_ `device(...) <dragon/device.html>`_
: Context-manager to nest the device spec. : Context-manager to nest the device spec.
`eager_mode(...) <dragon/eager_mode.html>`_ `eager_mode(...) <dragon/eager_mode.html>`_
: Context-manager set the eager execution mode. : Context-manager set the eager execution mode.
`eager_scope(...) <dragon/eager_mode.html>`_ `variable_scope(...) <dragon/eager_mode.html>`_
: Context-manager to nest the name for eager resources. : Context-manager to nest the namespace for variables.
`expand_dims(...) <dragon/expand_dims.html>`_ `expand_dims(...) <dragon/expand_dims.html>`_
: Expand the dimensions of input with size 1. : Expand the dimensions of input with size 1.
...@@ -78,14 +79,14 @@ dragon ...@@ -78,14 +79,14 @@ dragon
`function(...) <dragon/function.html>`_ `function(...) <dragon/function.html>`_
: Compile a function and return an executable. : Compile a function and return an executable.
`gather(...) <dragon/gather.html>`_
: Gather the elements along the given axis using index.
`get_num_threads(...) <dragon/get_num_threads.html>`_ `get_num_threads(...) <dragon/get_num_threads.html>`_
: Return the number of threads for cpu parallelism. : Return the number of threads for cpu parallelism.
`get_workspace(...) <dragon/get_workspace.html>`_ `get_workspace(...) <dragon/get_workspace.html>`_
: Return the current default workspace. : Return the default workspace.
`gradients(...) <dragon/gradients.html>`_
: Compute the symbolic derivatives of ``ys`` w.r.t. ``xs`` .
`graph_mode(...) <dragon/graph_mode.html>`_ `graph_mode(...) <dragon/graph_mode.html>`_
: Context-manager set the graph execution mode. : Context-manager set the graph execution mode.
...@@ -93,21 +94,12 @@ dragon ...@@ -93,21 +94,12 @@ dragon
`identity(...) <dragon/identity.html>`_ `identity(...) <dragon/identity.html>`_
: Return a tensor copied from the input. : Return a tensor copied from the input.
`index_select(...) <dragon/index_select.html>`_
: Select the elements according to the index along the given axis.
`linspace(...) <dragon/linspace.html>`_ `linspace(...) <dragon/linspace.html>`_
: Generate evenly spaced values within intervals along the given axis. : Generate evenly spaced values within intervals along the given axis.
`load_library(...) <dragon/load_library.html>`_ `load_library(...) <dragon/load_library.html>`_
: Load a shared library. : Load a shared library.
`masked_assign(...) <dragon/masked_assign.html>`_
: Assign the value to input where mask is 1.
`masked_select(...) <dragon/masked_select.html>`_
: Select the elements of input where mask is 1.
`name_scope(...) <dragon/name_scope.html>`_ `name_scope(...) <dragon/name_scope.html>`_
: Context-manager to nest the name as prefix for operations. : Context-manager to nest the name as prefix for operations.
...@@ -166,11 +158,17 @@ dragon ...@@ -166,11 +158,17 @@ dragon
: Return the identity of input with truncated gradient-flow. : Return the identity of input with truncated gradient-flow.
`tile(...) <dragon/tile.html>`_ `tile(...) <dragon/tile.html>`_
: Tile the input according to the given repeats. : Repeat elements along each axis of input.
`transpose(...) <dragon/transpose.html>`_ `transpose(...) <dragon/transpose.html>`_
: Permute the dimensions of input. : Permute the dimensions of input.
`tril(...) <dragon/tril.html>`_
: Return the lower triangular part of input.
`triu(...) <dragon/triu.html>`_
: Return the upper triangular part of input.
`unique(...) <dragon/unique.html>`_ `unique(...) <dragon/unique.html>`_
: Return the unique elements of input. : Return the unique elements of input.
...@@ -186,12 +184,13 @@ dragon ...@@ -186,12 +184,13 @@ dragon
.. toctree:: .. toctree::
:hidden: :hidden:
dragon/EagerTensor dragon/DeviceSpec
dragon/GradientTape dragon/GradientTape
dragon/Tensor dragon/Tensor
dragon/Workspace dragon/Workspace
dragon/argsort dragon/argsort
dragon/assign dragon/assign
dragon/boolean_mask
dragon/broadcast_to dragon/broadcast_to
dragon/cast dragon/cast
dragon/channel_affine dragon/channel_affine
...@@ -199,26 +198,21 @@ dragon ...@@ -199,26 +198,21 @@ dragon
dragon/channel_shuffle dragon/channel_shuffle
dragon/concat dragon/concat
dragon/constant dragon/constant
dragon/create_function
dragon/device dragon/device
dragon/eager_mode dragon/eager_mode
dragon/eager_scope
dragon/expand_dims dragon/expand_dims
dragon/eye dragon/eye
dragon/eye_like dragon/eye_like
dragon/fill dragon/fill
dragon/flatten dragon/flatten
dragon/function dragon/function
dragon/gather
dragon/get_num_threads dragon/get_num_threads
dragon/get_workspace dragon/get_workspace
dragon/gradients
dragon/graph_mode dragon/graph_mode
dragon/identity dragon/identity
dragon/index_select
dragon/linspace dragon/linspace
dragon/load_library dragon/load_library
dragon/masked_assign
dragon/masked_select
dragon/name_scope dragon/name_scope
dragon/nonzero dragon/nonzero
dragon/ones dragon/ones
...@@ -240,7 +234,10 @@ dragon ...@@ -240,7 +234,10 @@ dragon
dragon/stop_gradient dragon/stop_gradient
dragon/tile dragon/tile
dragon/transpose dragon/transpose
dragon/tril
dragon/triu
dragon/unique dragon/unique
dragon/variable_scope
dragon/where dragon/where
dragon/zeros dragon/zeros
dragon/zeros_like dragon/zeros_like
......
DeviceSpec
==========
.. autoclass:: dragon.DeviceSpec
__init__
--------
.. automethod:: dragon.DeviceSpec.__init__
Properties
----------
index
#####
.. autoattribute:: dragon.DeviceSpec.index
type
####
.. autoattribute:: dragon.DeviceSpec.type
Methods
-------
copy
####
.. automethod:: dragon.DeviceSpec.copy
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
EagerTensor
===========
.. autoclass:: dragon.EagerTensor
__init__
--------
.. automethod:: dragon.EagerTensor.__init__
Properties
----------
dtype
#####
.. autoattribute:: dragon.EagerTensor.dtype
id
##
.. autoattribute:: dragon.EagerTensor.id
name
####
.. autoattribute:: dragon.EagerTensor.name
ndim
####
.. autoattribute:: dragon.EagerTensor.ndim
shape
#####
.. autoattribute:: dragon.EagerTensor.shape
size
#####
.. autoattribute:: dragon.EagerTensor.size
Methods
-------
astype
######
.. automethod:: dragon.EagerTensor.astype
constant
########
.. automethod:: dragon.EagerTensor.constant
copy
####
.. automethod:: dragon.EagerTensor.copy
from_value
##########
.. automethod:: dragon.EagerTensor.from_value
get_value
#########
.. automethod:: dragon.EagerTensor.get_value
glorot_normal
#############
.. automethod:: dragon.EagerTensor.glorot_normal
glorot_uniform
##############
.. automethod:: dragon.EagerTensor.glorot_uniform
normal
######
.. automethod:: dragon.EagerTensor.normal
numpy
#####
.. automethod:: dragon.EagerTensor.numpy
reshape
#######
.. automethod:: dragon.EagerTensor.reshape
set_value
#########
.. automethod:: dragon.EagerTensor.set_value
truncated_normal
################
.. automethod:: dragon.EagerTensor.truncated_normal
uniform
#######
.. automethod:: dragon.EagerTensor.uniform
Overrides
---------
__add__
#######
.. automethod:: dragon.EagerTensor.__add__
__float__
#########
.. automethod:: dragon.EagerTensor.__float__
__ge__
######
.. automethod:: dragon.EagerTensor.__ge__
__getitem__
###########
.. automethod:: dragon.EagerTensor.__getitem__
__gt__
######
.. automethod:: dragon.EagerTensor.__gt__
__iadd__
########
.. automethod:: dragon.EagerTensor.__iadd__
__imul__
########
.. automethod:: dragon.EagerTensor.__imul__
__int__
#######
.. automethod:: dragon.EagerTensor.__int__
__isub__
########
.. automethod:: dragon.EagerTensor.__isub__
__itruediv__
############
.. automethod:: dragon.EagerTensor.__itruediv__
__le__
######
.. automethod:: dragon.EagerTensor.__le__
__lt__
######
.. automethod:: dragon.EagerTensor.__lt__
__mul__
#######
.. automethod:: dragon.EagerTensor.__mul__
__neg__
#######
.. automethod:: dragon.EagerTensor.__neg__
__radd__
########
.. automethod:: dragon.EagerTensor.__radd__
__rmul__
########
.. automethod:: dragon.EagerTensor.__rmul__
__rsub__
########
.. automethod:: dragon.EagerTensor.__rsub__
__rtruediv__
############
.. automethod:: dragon.EagerTensor.__rtruediv__
__setitem__
###########
.. automethod:: dragon.EagerTensor.__setitem__
__sub__
#######
.. automethod:: dragon.EagerTensor.__sub__
__truediv__
###########
.. automethod:: dragon.EagerTensor.__truediv__
.. _dragon.assign(...): assign.html
.. _dragon.cast(...): cast.html
.. _dragon.fill(...): fill.html
.. _dragon.identity(...): identity.html
.. _dragon.masked_assign(...): masked_assign.html
.. _dragon.masked_select(...): masked_select.html
.. _dragon.math.add(...): math/add.html
.. _dragon.math.div(...): math/div.html
.. _dragon.math.greater(...): math/greater.html
.. _dragon.math.greater_equal(...): math/greater_equal.html
.. _dragon.math.less(...): math/less.html
.. _dragon.math.less_equal(...): math/less_equal.html
.. _dragon.math.mul(...): math/mul.html
.. _dragon.math.negative(...): math/negative.html
.. _dragon.math.sub(...): math/sub.html
.. _dragon.random.glorot_normal(...): random/glorot_normal.html
.. _dragon.random.glorot_uniform(...): random/glorot_uniform.html
.. _dragon.random.normal(...): random/normal.html
.. _dragon.random.truncated_normal(...): random/truncated_normal.html
.. _dragon.random.uniform(...): random/uniform.html
.. _dragon.reshape(...): reshape.html
.. _dragon.slice(...): slice.html
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
...@@ -10,6 +10,10 @@ __init__ ...@@ -10,6 +10,10 @@ __init__
Properties Properties
---------- ----------
device
######
.. autoattribute:: dragon.Tensor.device
dtype dtype
##### #####
.. autoattribute:: dragon.Tensor.dtype .. autoattribute:: dragon.Tensor.dtype
...@@ -26,6 +30,10 @@ ndim ...@@ -26,6 +30,10 @@ ndim
#### ####
.. autoattribute:: dragon.Tensor.ndim .. autoattribute:: dragon.Tensor.ndim
requires_grad
#############
.. autoattribute:: dragon.Tensor.requires_grad
shape shape
##### #####
.. autoattribute:: dragon.Tensor.shape .. autoattribute:: dragon.Tensor.shape
...@@ -41,21 +49,13 @@ astype ...@@ -41,21 +49,13 @@ astype
###### ######
.. automethod:: dragon.Tensor.astype .. automethod:: dragon.Tensor.astype
constant
########
.. automethod:: dragon.Tensor.constant
copy copy
#### ####
.. automethod:: dragon.Tensor.copy .. automethod:: dragon.Tensor.copy
from_value fill
########## ####
.. automethod:: dragon.Tensor.from_value .. automethod:: dragon.Tensor.fill
get_value
##########
.. automethod:: dragon.Tensor.get_value
glorot_normal glorot_normal
############# #############
...@@ -69,14 +69,14 @@ normal ...@@ -69,14 +69,14 @@ normal
###### ######
.. automethod:: dragon.Tensor.normal .. automethod:: dragon.Tensor.normal
numpy
#####
.. automethod:: dragon.Tensor.numpy
reshape reshape
####### #######
.. automethod:: dragon.Tensor.reshape .. automethod:: dragon.Tensor.reshape
set_value
#########
.. automethod:: dragon.Tensor.set_value
truncated_normal truncated_normal
################ ################
.. automethod:: dragon.Tensor.truncated_normal .. automethod:: dragon.Tensor.truncated_normal
...@@ -92,6 +92,10 @@ __add__ ...@@ -92,6 +92,10 @@ __add__
####### #######
.. automethod:: dragon.Tensor.__add__ .. automethod:: dragon.Tensor.__add__
__and__
#######
.. automethod:: dragon.Tensor.__and__
__float__ __float__
######### #########
.. automethod:: dragon.Tensor.__float__ .. automethod:: dragon.Tensor.__float__
...@@ -108,10 +112,42 @@ __gt__ ...@@ -108,10 +112,42 @@ __gt__
###### ######
.. automethod:: dragon.Tensor.__gt__ .. automethod:: dragon.Tensor.__gt__
__iadd__
########
.. automethod:: dragon.Tensor.__iadd__
__iand__
########
.. automethod:: dragon.Tensor.__iand__
__imul__
########
.. automethod:: dragon.Tensor.__imul__
__int__ __int__
####### #######
.. automethod:: dragon.Tensor.__int__ .. automethod:: dragon.Tensor.__int__
__invert__
###########
.. automethod:: dragon.Tensor.__invert__
__ior__
#######
.. automethod:: dragon.Tensor.__ior__
__isub__
########
.. automethod:: dragon.Tensor.__isub__
__itruediv__
############
.. automethod:: dragon.Tensor.__itruediv__
__ixor__
########
.. automethod:: dragon.Tensor.__ixor__
__le__ __le__
###### ######
.. automethod:: dragon.Tensor.__le__ .. automethod:: dragon.Tensor.__le__
...@@ -128,18 +164,34 @@ __neg__ ...@@ -128,18 +164,34 @@ __neg__
####### #######
.. automethod:: dragon.Tensor.__neg__ .. automethod:: dragon.Tensor.__neg__
__or__
#######
.. automethod:: dragon.Tensor.__or__
__radd__ __radd__
######## ########
.. automethod:: dragon.Tensor.__radd__ .. automethod:: dragon.Tensor.__radd__
__rand__
########
.. automethod:: dragon.Tensor.__rand__
__rmul__ __rmul__
######## ########
.. automethod:: dragon.Tensor.__rmul__ .. automethod:: dragon.Tensor.__rmul__
__ror__
#######
.. automethod:: dragon.Tensor.__ror__
__rsub__ __rsub__
######## ########
.. automethod:: dragon.Tensor.__rsub__ .. automethod:: dragon.Tensor.__rsub__
__rxor__
########
.. automethod:: dragon.Tensor.__rxor__
__setitem__ __setitem__
########### ###########
.. automethod:: dragon.Tensor.__setitem__ .. automethod:: dragon.Tensor.__setitem__
...@@ -156,12 +208,18 @@ __truediv__ ...@@ -156,12 +208,18 @@ __truediv__
############ ############
.. automethod:: dragon.Tensor.__truediv__ .. automethod:: dragon.Tensor.__truediv__
__xor__
#######
.. automethod:: dragon.Tensor.__xor__
.. _dragon.assign(...): assign.html .. _dragon.assign(...): assign.html
.. _dragon.bitwise.bitwise_and(...): bitwise/bitwise_and.html
.. _dragon.bitwise.bitwise_or(...): bitwise/bitwise_or.html
.. _dragon.bitwise.bitwise_xor(...): bitwise/bitwise_xor.html
.. _dragon.bitwise.invert(...): bitwise/invert.html
.. _dragon.cast(...): cast.html .. _dragon.cast(...): cast.html
.. _dragon.fill(...): fill.html .. _dragon.fill(...): fill.html
.. _dragon.identity(...): identity.html .. _dragon.identity(...): identity.html
.. _dragon.masked_assign(...): masked_assign.html
.. _dragon.masked_select(...): masked_select.html
.. _dragon.math.add(...): math/add.html .. _dragon.math.add(...): math/add.html
.. _dragon.math.div(...): math/div.html .. _dragon.math.div(...): math/div.html
.. _dragon.math.greater(...): math/greater.html .. _dragon.math.greater(...): math/greater.html
...@@ -177,7 +235,6 @@ __truediv__ ...@@ -177,7 +235,6 @@ __truediv__
.. _dragon.random.truncated_normal(...): random/truncated_normal.html .. _dragon.random.truncated_normal(...): random/truncated_normal.html
.. _dragon.random.uniform(...): random/uniform.html .. _dragon.random.uniform(...): random/uniform.html
.. _dragon.reshape(...): reshape.html .. _dragon.reshape(...): reshape.html
.. _dragon.slice(...): slice.html
.. raw:: html .. raw:: html
......
...@@ -18,18 +18,6 @@ clear ...@@ -18,18 +18,6 @@ clear
##### #####
.. automethod:: dragon.Workspace.clear .. automethod:: dragon.Workspace.clear
feed_tensor
###########
.. automethod:: dragon.Workspace.feed_tensor
fetch_tensor
############
.. automethod:: dragon.Workspace.fetch_tensor
has_tensor
##########
.. automethod:: dragon.Workspace.has_tensor
memory_allocated memory_allocated
################ ################
.. automethod:: dragon.Workspace.memory_allocated .. automethod:: dragon.Workspace.memory_allocated
...@@ -38,10 +26,6 @@ merge_from ...@@ -38,10 +26,6 @@ merge_from
########## ##########
.. automethod:: dragon.Workspace.merge_from .. automethod:: dragon.Workspace.merge_from
reset_tensor
############
.. automethod:: dragon.Workspace.reset_tensor
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -6,9 +6,6 @@ dragon.autograph ...@@ -6,9 +6,6 @@ dragon.autograph
Functions Functions
--------- ---------
`set_execution(...) <autograph/set_execution.html>`_
: Set the execution mode for graph ir.
`set_optimization(...) <autograph/set_optimization.html>`_ `set_optimization(...) <autograph/set_optimization.html>`_
: Set the optimization for graph ir. : Set the optimization for graph ir.
...@@ -21,7 +18,6 @@ dragon.autograph ...@@ -21,7 +18,6 @@ dragon.autograph
.. toctree:: .. toctree::
:hidden: :hidden:
autograph/set_execution
autograph/set_optimization autograph/set_optimization
autograph/set_scheduler autograph/set_scheduler
autograph/set_verbosity autograph/set_verbosity
......
index_select boolean_mask
============ ============
.. autofunction:: dragon.index_select .. autofunction:: dragon.boolean_mask
.. raw:: html .. raw:: html
......
create_function gather
=============== ======
.. autofunction:: dragon.create_function .. autofunction:: dragon.gather
.. raw:: html .. raw:: html
......
...@@ -10,29 +10,27 @@ dragon.losses ...@@ -10,29 +10,27 @@ dragon.losses
: Compute the ctc loss with batched labels. : Compute the ctc loss with batched labels.
`l1_loss(...) <losses/l1_loss.html>`_ `l1_loss(...) <losses/l1_loss.html>`_
: Compute the element-wise absolute value difference. : Compute the loss of element-wise absolute value difference.
`l2_loss(...) <losses/l2_loss.html>`_ `l2_loss(...) <losses/l2_loss.html>`_
: Compute the element-wise squared error. : Compute the loss of element-wise squared error.
`nll_loss(...) <losses/nll_loss.html>`_ `nll_loss(...) <losses/nll_loss.html>`_
: Compute the negative likelihood loss with sparse labels. : Compute the loss of negative likelihood.
`sigmoid_cross_entropy(...) <losses/sigmoid_cross_entropy.html>`_ `sigmoid_cross_entropy_loss(...) <losses/sigmoid_cross_entropy_loss.html>`_
: Compute the sigmoid cross entropy with contiguous targets. : Compute the loss of sigmoid cross entropy.
`sigmoid_focal_loss(...) <losses/sigmoid_focal_loss.html>`_ `sigmoid_focal_loss(...) <losses/sigmoid_focal_loss.html>`_
: Compute the sigmoid focal loss with sparse labels. : Compute the focal loss of sigmoid cross entropy.
`[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`_.
`smooth_l1_loss(...) <losses/smooth_l1_loss.html>`_ `smooth_l1_loss(...) <losses/smooth_l1_loss.html>`_
: Compute the element-wise error transited from L1 and L2. : Compute the loss of element-wise error transited from L1 and L2.
`[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_. `[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
`softmax_cross_entropy(...) <losses/softmax_cross_entropy.html>`_ `softmax_cross_entropy_loss(...) <losses/softmax_cross_entropy_loss.html>`_
: Compute the softmax cross entropy with contiguous targets. : Compute the loss of softmax cross entropy.
`sparse_softmax_cross_entropy(...) <losses/sparse_softmax_cross_entropy.html>`_
: Compute the softmax cross entropy with sparse labels.
.. toctree:: .. toctree::
:hidden: :hidden:
...@@ -41,11 +39,10 @@ dragon.losses ...@@ -41,11 +39,10 @@ dragon.losses
losses/l1_loss losses/l1_loss
losses/l2_loss losses/l2_loss
losses/nll_loss losses/nll_loss
losses/sigmoid_cross_entropy losses/sigmoid_cross_entropy_loss
losses/sigmoid_focal_loss losses/sigmoid_focal_loss
losses/smooth_l1_loss losses/smooth_l1_loss
losses/softmax_cross_entropy losses/softmax_cross_entropy_loss
losses/sparse_softmax_cross_entropy
.. raw:: html .. raw:: html
......
sigmoid_cross_entropy sigmoid_cross_entropy_loss
===================== ==========================
.. autofunction:: dragon.losses.sigmoid_cross_entropy .. autofunction:: dragon.losses.sigmoid_cross_entropy_loss
.. raw:: html .. raw:: html
......
softmax_cross_entropy softmax_cross_entropy_loss
===================== ==========================
.. autofunction:: dragon.losses.softmax_cross_entropy .. autofunction:: dragon.losses.softmax_cross_entropy_loss
.. raw:: html .. raw:: html
......
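A minimal sketch of the renamed loss entry points. The two-element ``[input, target]`` argument form used elsewhere in the dragon API, the dtypes and the shapes are assumptions here, not verified against the new signatures::

    import numpy as np
    import dragon

    logits = np.random.randn(4, 3).astype('float32')
    labels = np.random.randint(0, 3, size=(4,)).astype('int64')

    # Formerly ``sparse_softmax_cross_entropy``; assumed two-input list form.
    loss = dragon.losses.softmax_cross_entropy_loss([logits, labels])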
...@@ -18,9 +18,6 @@ dragon.math ...@@ -18,9 +18,6 @@ dragon.math
`argmin(...) <math/argmin.html>`_ `argmin(...) <math/argmin.html>`_
: Compute the index of minimum elements along the given axis. : Compute the index of minimum elements along the given axis.
`axpby(...) <math/axpby.html>`_
: Compute the element-wise addition from input to output.
`ceil(...) <math/ceil.html>`_ `ceil(...) <math/ceil.html>`_
: Compute the smallest integer not less than input. : Compute the smallest integer not less than input.
...@@ -36,9 +33,6 @@ dragon.math ...@@ -36,9 +33,6 @@ dragon.math
`div(...) <math/div.html>`_ `div(...) <math/div.html>`_
: Compute the element-wise division. : Compute the element-wise division.
`dot(...) <math/dot.html>`_
: Compute the vector dot.
`equal(...) <math/equal.html>`_ `equal(...) <math/equal.html>`_
: Compute the element-wise equal comparison. : Compute the element-wise equal comparison.
...@@ -72,6 +66,18 @@ dragon.math ...@@ -72,6 +66,18 @@ dragon.math
`log(...) <math/log.html>`_ `log(...) <math/log.html>`_
: Compute the logarithm of input. : Compute the logarithm of input.
`logical_and(...) <math/logical_and.html>`_
: Compute the element-wise AND logical operation.
`logical_not(...) <math/logical_not.html>`_
: Compute the element-wise NOT logical operation.
`logical_or(...) <math/logical_or.html>`_
: Compute the element-wise OR logical operation.
`logical_xor(...) <math/logical_xor.html>`_
: Compute the element-wise XOR logical operation.
`lp_normalize(...) <math/lp_normalize.html>`_ `lp_normalize(...) <math/lp_normalize.html>`_
: Apply the lp normalization. : Apply the lp normalization.
...@@ -93,9 +99,6 @@ dragon.math ...@@ -93,9 +99,6 @@ dragon.math
`minimum(...) <math/minimum.html>`_ `minimum(...) <math/minimum.html>`_
: Compute the minimum value of given two inputs. : Compute the minimum value of given two inputs.
`moments(...) <math/moments.html>`_
: Compute the mean and variance of input along the given axes.
`mul(...) <math/mul.html>`_ `mul(...) <math/mul.html>`_
: Compute the element-wise multiplication. : Compute the element-wise multiplication.
...@@ -151,13 +154,11 @@ dragon.math ...@@ -151,13 +154,11 @@ dragon.math
math/add math/add
math/argmax math/argmax
math/argmin math/argmin
math/axpby
math/ceil math/ceil
math/clip math/clip
math/cos math/cos
math/cumsum math/cumsum
math/div math/div
math/dot
math/equal math/equal
math/exp math/exp
math/floor math/floor
...@@ -169,6 +170,10 @@ dragon.math ...@@ -169,6 +170,10 @@ dragon.math
math/less math/less
math/less_equal math/less_equal
math/log math/log
math/logical_and
math/logical_not
math/logical_or
math/logical_xor
math/lp_normalize math/lp_normalize
math/matmul math/matmul
math/max math/max
...@@ -176,7 +181,6 @@ dragon.math ...@@ -176,7 +181,6 @@ dragon.math
math/mean math/mean
math/min math/min
math/minimum math/minimum
math/moments
math/mul math/mul
math/negative math/negative
math/not_equal math/not_equal
......
eager_scope logical_and
=========== ===========
.. autofunction:: dragon.eager_scope .. autofunction:: dragon.math.logical_and
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon."; content: "dragon.math.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
axpby logical_not
===== ===========
.. autofunction:: dragon.math.axpby .. autofunction:: dragon.math.logical_not
.. raw:: html .. raw:: html
......
dot logical_or
=== ===========
.. autofunction:: dragon.math.dot .. autofunction:: dragon.math.logical_or
.. raw:: html .. raw:: html
......
logical_xor
===========
.. autofunction:: dragon.math.logical_xor
.. raw:: html
<style>
h1:before {
content: "dragon.math.";
color: #103d3e;
}
</style>
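A short sketch of the new element-wise logical ops under ``dragon.math``; it assumes array-like inputs are accepted directly::

    import numpy as np
    import dragon

    a = np.array([True, True, False])
    b = np.array([True, False, False])

    dragon.math.logical_and(a, b)  # [True, False, False]
    dragon.math.logical_or(a, b)   # [True, True, False]
    dragon.math.logical_xor(a, b)  # [False, True, False]
    dragon.math.logical_not(a)     # [False, False, True]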
...@@ -59,15 +59,15 @@ dragon.nn ...@@ -59,15 +59,15 @@ dragon.nn
: Rearrange depth data into spatial blocks. : Rearrange depth data into spatial blocks.
`dropout(...) <nn/dropout.html>`_ `dropout(...) <nn/dropout.html>`_
: Set the elements of the input to zero randomly. : Set the elements of input to zero randomly.
`[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_. `[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
`drop_block2d(...) <nn/drop_block2d.html>`_ `drop_block(...) <nn/drop_block.html>`_
: Set the spatial blocks over input to zero randomly. : Set the blocks over input to zero randomly.
`[Ghiasi et.al, 2018] <https://arxiv.org/abs/1810.12890>`_. `[Ghiasi et.al, 2018] <https://arxiv.org/abs/1810.12890>`_.
`drop_path(...) <nn/drop_path.html>`_ `drop_path(...) <nn/drop_path.html>`_
: Set the examples over the input to zero randomly. : Set the examples over input to zero randomly.
`[Larsson et.al, 2016] <https://arxiv.org/abs/1605.07648>`_. `[Larsson et.al, 2016] <https://arxiv.org/abs/1605.07648>`_.
`elu(...) <nn/elu.html>`_ `elu(...) <nn/elu.html>`_
...@@ -103,6 +103,9 @@ dragon.nn ...@@ -103,6 +103,9 @@ dragon.nn
`log_softmax(...) <nn/log_softmax.html>`_ `log_softmax(...) <nn/log_softmax.html>`_
: Compute the composite of logarithm and softmax. : Compute the composite of logarithm and softmax.
`moments(...) <nn/moments.html>`_
: Compute the mean and variance of input along the given axis.
`prelu(...) <nn/prelu.html>`_ `prelu(...) <nn/prelu.html>`_
: Apply the parametric rectified linear unit. : Apply the parametric rectified linear unit.
`[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_. `[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_.
...@@ -161,7 +164,7 @@ dragon.nn ...@@ -161,7 +164,7 @@ dragon.nn
nn/depthwise_conv2d nn/depthwise_conv2d
nn/depth_to_space nn/depth_to_space
nn/dropout nn/dropout
nn/drop_block2d nn/drop_block
nn/drop_path nn/drop_path
nn/elu nn/elu
nn/group_norm nn/group_norm
...@@ -172,6 +175,7 @@ dragon.nn ...@@ -172,6 +175,7 @@ dragon.nn
nn/leaky_relu nn/leaky_relu
nn/local_response_norm nn/local_response_norm
nn/log_softmax nn/log_softmax
nn/moments
nn/pool nn/pool
nn/pool1d nn/pool1d
nn/pool2d nn/pool2d
......
ROIPooling drop_block
========== ==========
.. autoclass:: dragon.vm.caffe.core.layers.ROIPooling .. autofunction:: dragon.nn.drop_block
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "dragon.nn.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
moments moments
======= =======
.. autofunction:: dragon.math.moments .. autofunction:: dragon.nn.moments
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.math."; content: "dragon.nn.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
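``moments`` now lives under ``dragon.nn`` and returns the mean and variance; the ``axis`` keyword below is an assumption about the new signature::

    import numpy as np
    import dragon

    x = np.random.randn(2, 3, 4).astype('float32')
    mean, var = dragon.nn.moments(x, axis=1)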
...@@ -37,14 +37,14 @@ Name Supported Reference ...@@ -37,14 +37,14 @@ Name Supported Reference
`Acos`_ `Acos`_
`Acosh`_ `Acosh`_
`Add`_ |v| :func:`dragon.math.add` `Add`_ |v| :func:`dragon.math.add`
`And`_ |v| :func:`dragon.bitwise.bitwise_and` `And`_ |v| :func:`dragon.math.logical_and`
`ArgMax`_ |v| :func:`dragon.math.argmax` `ArgMax`_ |v| :func:`dragon.math.argmax`
`ArgMin`_ |v| :func:`dragon.math.argmin` `ArgMin`_ |v| :func:`dragon.math.argmin`
`Asin`_ `Asin`_
`Asinh`_ `Asinh`_
`Atan`_ `Atan`_
`Atanh`_ `Atanh`_
`AveragePool`_ |v| :func:`dragon.nn.pool2d` `AveragePool`_ |v| :func:`dragon.nn.pool`
`BatchNormalization`_ |v| :func:`dragon.nn.batch_norm` `BatchNormalization`_ |v| :func:`dragon.nn.batch_norm`
`BitShift`_ `BitShift`_
`Cast`_ |v| :func:`dragon.cast` `Cast`_ |v| :func:`dragon.cast`
...@@ -55,9 +55,9 @@ Name Supported Reference ...@@ -55,9 +55,9 @@ Name Supported Reference
`ConcatFromSequence`_ `ConcatFromSequence`_
`Constant`_ `Constant`_
`ConstantOfShape`_ `ConstantOfShape`_
`Conv`_ |v| :func:`dragon.nn.conv2d` `Conv`_ |v| :func:`dragon.nn.conv`
`ConvInteger`_ `ConvInteger`_
`ConvTranspose`_ |v| :func:`dragon.nn.conv2d_transpose` `ConvTranspose`_ |v| :func:`dragon.nn.conv_transpose`
`Cos`_ |v| :func:`dragon.math.cos` `Cos`_ |v| :func:`dragon.math.cos`
`Cosh`_ `Cosh`_
`CumSum`_ |v| :func:`dragon.math.cumsum` `CumSum`_ |v| :func:`dragon.math.cumsum`
...@@ -76,13 +76,13 @@ Name Supported Reference ...@@ -76,13 +76,13 @@ Name Supported Reference
`Flatten`_ |v| :func:`dragon.flatten` `Flatten`_ |v| :func:`dragon.flatten`
`Floor`_ |v| :func:`dragon.math.floor` `Floor`_ |v| :func:`dragon.math.floor`
`GRU`_ |v| :func:`dragon.nn.GRU` `GRU`_ |v| :func:`dragon.nn.GRU`
`Gather`_ |v| :func:`dragon.index_select` `Gather`_ |v| :func:`dragon.gather`
`GatherElements`_ `GatherElements`_
`GatherND`_ `GatherND`_
`Gemm`_ |v| :func:`dragon.math.gemm` `Gemm`_ |v| :func:`dragon.math.gemm`
`GlobalAveragePool`_ |v| :func:`dragon.nn.pool2d` `GlobalAveragePool`_ |v| :func:`dragon.nn.pool`
`GlobalLpPool`_ `GlobalLpPool`_
`GlobalMaxPool`_ |v| :func:`dragon.nn.pool2d` `GlobalMaxPool`_ |v| :func:`dragon.nn.pool`
`Greater`_ |v| :func:`dragon.math.greater` `Greater`_ |v| :func:`dragon.math.greater`
`HardSigmoid`_ |v| :func:`dragon.nn.hardsigmoid` `HardSigmoid`_ |v| :func:`dragon.nn.hardsigmoid`
`Hardmax`_ `Hardmax`_
...@@ -103,7 +103,7 @@ Name Supported Reference ...@@ -103,7 +103,7 @@ Name Supported Reference
`MatMul`_ |v| :func:`dragon.math.matmul` `MatMul`_ |v| :func:`dragon.math.matmul`
`MatMulInteger`_ `MatMulInteger`_
`Max`_ |v| :func:`dragon.math.maximum` `Max`_ |v| :func:`dragon.math.maximum`
`MaxPool`_ |v| :func:`dragon.nn.pool2d` `MaxPool`_ |v| :func:`dragon.nn.pool`
`MaxRoiPool`_ |v| :func:`dragon.vision.roi_pool` `MaxRoiPool`_ |v| :func:`dragon.vision.roi_pool`
`MaxUnpool`_ `MaxUnpool`_
`Mean`_ |v| :func:`dragon.math.add` `Mean`_ |v| :func:`dragon.math.add`
...@@ -114,9 +114,9 @@ Name Supported Reference ...@@ -114,9 +114,9 @@ Name Supported Reference
`Neg`_ |v| :func:`dragon.math.negative` `Neg`_ |v| :func:`dragon.math.negative`
`NonMaxSuppression`_ `NonMaxSuppression`_
`NonZero`_ |v| :func:`dragon.nonzero` `NonZero`_ |v| :func:`dragon.nonzero`
`Not`_ |v| :func:`dragon.bitwise.invert` `Not`_ |v| :func:`dragon.math.logical_not`
`OneHot`_ |v| :func:`dragon.one_hot` `OneHot`_ |v| :func:`dragon.one_hot`
`Or`_ |v| :func:`dragon.bitwise.bitwise_or` `Or`_ |v| :func:`dragon.math.logical_or`
`PRelu`_ |v| :func:`dragon.nn.prelu` `PRelu`_ |v| :func:`dragon.nn.prelu`
`Pad`_ |v| :func:`dragon.pad` `Pad`_ |v| :func:`dragon.pad`
`Pow`_ |v| :func:`dragon.math.pow` `Pow`_ |v| :func:`dragon.math.pow`
...@@ -186,7 +186,7 @@ Name Supported Reference ...@@ -186,7 +186,7 @@ Name Supported Reference
`Unsqueeze`_ |v| :func:`dragon.expand_dims` `Unsqueeze`_ |v| :func:`dragon.expand_dims`
`Upsample`_ |v| :func:`dragon.vision.resize` `Upsample`_ |v| :func:`dragon.vision.resize`
`Where`_ |v| :func:`dragon.where` `Where`_ |v| :func:`dragon.where`
`Xor`_ |v| :func:`dragon.bitwise.bitwise_xor` `Xor`_ |v| :func:`dragon.math.logical_xor`
======================== ========= ======================================== ======================== ========= ========================================
.. toctree:: .. toctree::
......
...@@ -13,7 +13,7 @@ dragon.random ...@@ -13,7 +13,7 @@ dragon.random
: Return a tensor initialized from the glorot uniform distribution. : Return a tensor initialized from the glorot uniform distribution.
`multinomial(...) <random/multinomial.html>`_ `multinomial(...) <random/multinomial.html>`_
: Return a tensor with index sampled from multinomial distribution. : Return an index tensor sampled from the multinomial distribution.
`normal(...) <random/normal.html>`_ `normal(...) <random/normal.html>`_
: Return a tensor initialized from the normal distribution. : Return a tensor initialized from the normal distribution.
......
gradients tril
========= ====
.. autofunction:: dragon.gradients .. autofunction:: dragon.tril
.. raw:: html .. raw:: html
......
masked_assign triu
============= ====
.. autofunction:: dragon.masked_assign .. autofunction:: dragon.triu
.. raw:: html .. raw:: html
......
masked_select variable_scope
============= ==============
.. autofunction:: dragon.masked_select .. autofunction:: dragon.variable_scope
.. raw:: html .. raw:: html
......
...@@ -55,14 +55,11 @@ vm.tensorflow ...@@ -55,14 +55,11 @@ vm.tensorflow
: Return a tensor filled with the scalar value. : Return a tensor filled with the scalar value.
`gather(...) <tensorflow/gather.html>`_ `gather(...) <tensorflow/gather.html>`_
: Select the elements according to the index along the given axis. : Gather the elements along the given axis using index.
`function(...) <tensorflow/function.html>`_ `function(...) <tensorflow/function.html>`_
: Create a callable graph from the python function. : Create a callable graph from the python function.
`gradients(...) <tensorflow/gradients.html>`_
: Compute the symbolic derivatives of ``ys`` w.r.t. ``xs`` .
`identity(...) <tensorflow/identity.html>`_ `identity(...) <tensorflow/identity.html>`_
: Return a tensor copied from the input. : Return a tensor copied from the input.
...@@ -105,6 +102,9 @@ vm.tensorflow ...@@ -105,6 +102,9 @@ vm.tensorflow
`squeeze(...) <tensorflow/squeeze.html>`_ `squeeze(...) <tensorflow/squeeze.html>`_
: Remove the dimensions of input with size 1. : Remove the dimensions of input with size 1.
`tile(...) <tensorflow/tile.html>`_
: Tile input according to the given repeats.
`transpose(...) <tensorflow/transpose.html>`_ `transpose(...) <tensorflow/transpose.html>`_
: Permute the dimensions of input. : Permute the dimensions of input.
...@@ -140,7 +140,6 @@ vm.tensorflow ...@@ -140,7 +140,6 @@ vm.tensorflow
tensorflow/fill tensorflow/fill
tensorflow/function tensorflow/function
tensorflow/gather tensorflow/gather
tensorflow/gradients
tensorflow/identity tensorflow/identity
tensorflow/linspace tensorflow/linspace
tensorflow/name_scope tensorflow/name_scope
...@@ -155,6 +154,7 @@ vm.tensorflow ...@@ -155,6 +154,7 @@ vm.tensorflow
tensorflow/sort tensorflow/sort
tensorflow/split tensorflow/split
tensorflow/squeeze tensorflow/squeeze
tensorflow/tile
tensorflow/transpose tensorflow/transpose
tensorflow/unique tensorflow/unique
tensorflow/unique_with_counts tensorflow/unique_with_counts
......
...@@ -57,7 +57,7 @@ vm.tensorflow.nn ...@@ -57,7 +57,7 @@ vm.tensorflow.nn
: Rearrange depth data into spatial blocks. : Rearrange depth data into spatial blocks.
`dropout(...) <nn/dropout.html>`_ `dropout(...) <nn/dropout.html>`_
: Set the elements of the input to zero randomly. : Set the elements of input to zero randomly.
`[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_. `[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
`elu(...) <nn/elu.html>`_ `elu(...) <nn/elu.html>`_
......
gradients tile
========= ====
.. autofunction:: dragon.vm.tensorflow.gradients .. autofunction:: dragon.vm.tensorflow.tile
.. raw:: html .. raw:: html
......
...@@ -51,15 +51,18 @@ vm.torch ...@@ -51,15 +51,18 @@ vm.torch
`argsort(...) <torch/argsort.html>`_ `argsort(...) <torch/argsort.html>`_
: Return the index of sorted elements along the given dimension. : Return the index of sorted elements along the given dimension.
`axpby(...) <torch/axpby.html>`_
: Compute the element-wise addition from input to output.
`baddbmm(...) <torch/baddbmm.html>`_ `baddbmm(...) <torch/baddbmm.html>`_
: Add input to the result of batched matrix-matrix multiplication. : Add input to the result of batched matrix-matrix multiplication.
`bitwise_and(...) <torch/bitwise_and.html>`_
: Compute the element-wise AND bitwise operation.
`bitwise_not(...) <torch/bitwise_not.html>`_ `bitwise_not(...) <torch/bitwise_not.html>`_
: Compute the element-wise NOT bitwise operation. : Compute the element-wise NOT bitwise operation.
`bitwise_or(...) <torch/bitwise_or.html>`_
: Compute the element-wise OR bitwise operation.
`bitwise_xor(...) <torch/bitwise_xor.html>`_ `bitwise_xor(...) <torch/bitwise_xor.html>`_
: Compute the element-wise XOR bitwise operation. : Compute the element-wise XOR bitwise operation.
...@@ -73,13 +76,13 @@ vm.torch ...@@ -73,13 +76,13 @@ vm.torch
: Compute the smallest integer not less than input. : Compute the smallest integer not less than input.
`channel_affine(...) <torch/channel_affine.html>`_ `channel_affine(...) <torch/channel_affine.html>`_
: Apply affine transformation along the channels. : Apply affine transformation to each channel of input.
`channel_normalize(...) <torch/channel_normalize.html>`_ `channel_normalize(...) <torch/channel_normalize.html>`_
: Normalize channels with mean and standard deviation. : Apply normalization to each channel of input.
`channel_shuffle(...) <torch/channel_shuffle.html>`_ `channel_shuffle(...) <torch/channel_shuffle.html>`_
: Shuffle channels between a given number of groups. : Apply group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_. `[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`chunk(...) <torch/chunk.html>`_ `chunk(...) <torch/chunk.html>`_
...@@ -143,11 +146,23 @@ vm.torch ...@@ -143,11 +146,23 @@ vm.torch
: Compute the element-wise less-equal comparison. : Compute the element-wise less-equal comparison.
`linspace(...) <torch/linspace.html>`_ `linspace(...) <torch/linspace.html>`_
: Generate evenly spaced values within intervals along the given axis. : Generate evenly spaced values within intervals along the given dimension.
`log(...) <torch/log.html>`_ `log(...) <torch/log.html>`_
: Compute the natural logarithm of input. : Compute the natural logarithm of input.
`logical_and(...) <torch/logical_and.html>`_
: Compute the element-wise AND logical operation.
`logical_not(...) <torch/logical_not.html>`_
: Compute the element-wise NOT logical operation.
`logical_or(...) <torch/logical_or.html>`_
: Compute the element-wise OR logical operation.
`logical_xor(...) <torch/logical_xor.html>`_
: Compute the element-wise XOR logical operation.
`logsumexp(...) <torch/logsumexp.html>`_ `logsumexp(...) <torch/logsumexp.html>`_
: Apply the composite of log, sum, and exp to input. : Apply the composite of log, sum, and exp to input.
...@@ -155,7 +170,7 @@ vm.torch ...@@ -155,7 +170,7 @@ vm.torch
: Compute the element-wise less comparison. : Compute the element-wise less comparison.
`masked_select(...) <torch/masked_select.html>`_ `masked_select(...) <torch/masked_select.html>`_
: Select the input elements where mask is 1. : Select the input elements where mask is true.
`matmul(...) <torch/matmul.html>`_ `matmul(...) <torch/matmul.html>`_
: Compute the matrix multiplication. : Compute the matrix multiplication.
...@@ -223,9 +238,6 @@ vm.torch ...@@ -223,9 +238,6 @@ vm.torch
`reciprocal(...) <torch/reciprocal.html>`_ `reciprocal(...) <torch/reciprocal.html>`_
: Compute the reciprocal of input. : Compute the reciprocal of input.
`repeat(...) <torch/repeat.html>`_
: Repeat elements along the specified dimensions.
`reshape(...) <torch/reshape.html>`_ `reshape(...) <torch/reshape.html>`_
: Change the shape of input. : Change the shape of input.
...@@ -265,12 +277,21 @@ vm.torch ...@@ -265,12 +277,21 @@ vm.torch
`tensor(...) <torch/tensor.html>`_ `tensor(...) <torch/tensor.html>`_
: Create a tensor initializing the content from data. : Create a tensor initializing the content from data.
`tile(...) <torch/tile.html>`_
: Repeat elements along each dimension of input.
`topk(...) <torch/topk.html>`_ `topk(...) <torch/topk.html>`_
: Return the top-K largest or smallest elements along the given dimension. : Return the top-K largest or smallest elements along the given dimension.
`transpose(...) <torch/transpose.html>`_ `transpose(...) <torch/transpose.html>`_
: Return a new tensor with two dimensions swapped. : Return a new tensor with two dimensions swapped.
`tril(...) <torch/tril.html>`_
: Return the lower triangular part of input.
`triu(...) <torch/triu.html>`_
: Return the upper triangular part of input.
`unique(...) <torch/unique.html>`_ `unique(...) <torch/unique.html>`_
: Return the unique elements of input. : Return the unique elements of input.
...@@ -298,9 +319,10 @@ vm.torch ...@@ -298,9 +319,10 @@ vm.torch
torch/argmax torch/argmax
torch/argmin torch/argmin
torch/argsort torch/argsort
torch/axpby
torch/baddbmm torch/baddbmm
torch/bitwise_and
torch/bitwise_not torch/bitwise_not
torch/bitwise_or
torch/bitwise_xor torch/bitwise_xor
torch/bmm torch/bmm
torch/cat torch/cat
...@@ -333,6 +355,10 @@ vm.torch ...@@ -333,6 +355,10 @@ vm.torch
torch/le torch/le
torch/linspace torch/linspace
torch/log torch/log
torch/logical_and
torch/logical_not
torch/logical_or
torch/logical_xor
torch/logsumexp torch/logsumexp
torch/lt torch/lt
torch/masked_select torch/masked_select
...@@ -359,7 +385,6 @@ vm.torch ...@@ -359,7 +385,6 @@ vm.torch
torch/randn torch/randn
torch/randperm torch/randperm
torch/reciprocal torch/reciprocal
torch/repeat
torch/reshape torch/reshape
torch/round torch/round
torch/rsqrt torch/rsqrt
...@@ -374,8 +399,11 @@ vm.torch ...@@ -374,8 +399,11 @@ vm.torch
torch/sub torch/sub
torch/sum torch/sum
torch/tensor torch/tensor
torch/tile
torch/topk torch/topk
torch/transpose torch/transpose
torch/tril
torch/triu
torch/unique torch/unique
torch/unsqueeze torch/unsqueeze
torch/where torch/where
......
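A rough sketch of the new ``vm.torch`` element-wise and shape ops listed above. The call forms are assumed to mirror their PyTorch counterparts (bool inputs for the logical ops, a repeats tuple for ``tile``) and are not verified::

    from dragon.vm import torch

    a = torch.tensor([False, True, True])
    b = torch.tensor([True, True, False])
    torch.logical_and(a, b)

    x = torch.tensor([[1., 2.], [3., 4.]])
    torch.tril(x)           # lower triangular part
    torch.triu(x)           # upper triangular part
    torch.tile(x, (2, 1))   # repeat along each dimension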
...@@ -81,6 +81,14 @@ baddbmm\_ ...@@ -81,6 +81,14 @@ baddbmm\_
######### #########
.. automethod:: dragon.vm.torch.Tensor.baddbmm_ .. automethod:: dragon.vm.torch.Tensor.baddbmm_
bitwise_and
###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_and
bitwise_and\_
#############
.. automethod:: dragon.vm.torch.Tensor.bitwise_and_
bitwise_not bitwise_not
########### ###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_not .. automethod:: dragon.vm.torch.Tensor.bitwise_not
...@@ -89,6 +97,14 @@ bitwise_not\_ ...@@ -89,6 +97,14 @@ bitwise_not\_
############# #############
.. automethod:: dragon.vm.torch.Tensor.bitwise_not_ .. automethod:: dragon.vm.torch.Tensor.bitwise_not_
bitwise_or
###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_or
bitwise_or\_
############
.. automethod:: dragon.vm.torch.Tensor.bitwise_or_
bitwise_xor bitwise_xor
########### ###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_xor .. automethod:: dragon.vm.torch.Tensor.bitwise_xor
...@@ -281,6 +297,22 @@ log ...@@ -281,6 +297,22 @@ log
### ###
.. automethod:: dragon.vm.torch.Tensor.log .. automethod:: dragon.vm.torch.Tensor.log
logical_and
###########
.. automethod:: dragon.vm.torch.Tensor.logical_and
logical_not
###########
.. automethod:: dragon.vm.torch.Tensor.logical_not
logical_or
##########
.. automethod:: dragon.vm.torch.Tensor.logical_or
logical_xor
###########
.. automethod:: dragon.vm.torch.Tensor.logical_xor
logsumexp logsumexp
######### #########
.. automethod:: dragon.vm.torch.Tensor.logsumexp .. automethod:: dragon.vm.torch.Tensor.logsumexp
...@@ -513,6 +545,22 @@ transpose ...@@ -513,6 +545,22 @@ transpose
######### #########
.. automethod:: dragon.vm.torch.Tensor.transpose .. automethod:: dragon.vm.torch.Tensor.transpose
tril
####
.. automethod:: dragon.vm.torch.Tensor.tril
tril\_
######
.. automethod:: dragon.vm.torch.Tensor.tril_
triu
####
.. automethod:: dragon.vm.torch.Tensor.triu
triu\_
######
.. automethod:: dragon.vm.torch.Tensor.triu_
type type
#### ####
.. automethod:: dragon.vm.torch.Tensor.type .. automethod:: dragon.vm.torch.Tensor.type
...@@ -560,7 +608,9 @@ zero\_ ...@@ -560,7 +608,9 @@ zero\_
.. _torch.argmin(...): argmin.html .. _torch.argmin(...): argmin.html
.. _torch.argsort(...): argsort.html .. _torch.argsort(...): argsort.html
.. _torch.baddbmm(...): baddbmm.html .. _torch.baddbmm(...): baddbmm.html
.. _torch.bitwise_and(...): bitwise_and.html
.. _torch.bitwise_not(...): bitwise_not.html .. _torch.bitwise_not(...): bitwise_not.html
.. _torch.bitwise_or(...): bitwise_or.html
.. _torch.bitwise_xor(...): bitwise_xor.html .. _torch.bitwise_xor(...): bitwise_xor.html
.. _torch.bmm(...): bmm.html .. _torch.bmm(...): bmm.html
.. _torch.ceil(...): ceil.html .. _torch.ceil(...): ceil.html
...@@ -579,6 +629,12 @@ zero\_ ...@@ -579,6 +629,12 @@ zero\_
.. _torch.isinf(...): isinf.html .. _torch.isinf(...): isinf.html
.. _torch.isnan(...): isnan.html .. _torch.isnan(...): isnan.html
.. _torch.le(...): le.html .. _torch.le(...): le.html
.. _torch.log(...): log.html
.. _torch.logical_and(...): logical_and.html
.. _torch.logical_not(...): logical_not.html
.. _torch.logical_or(...): logical_or.html
.. _torch.logical_xor(...): logical_xor.html
.. _torch.logsumexp(...): logsumexp.html
.. _torch.lt(...): lt.html .. _torch.lt(...): lt.html
.. _torch.matmul(...): matmul.html .. _torch.matmul(...): matmul.html
.. _torch.max(...): max.html .. _torch.max(...): max.html
...@@ -606,6 +662,8 @@ zero\_ ...@@ -606,6 +662,8 @@ zero\_
.. _torch.sum(...): sum.html .. _torch.sum(...): sum.html
.. _torch.topk(...): topk.html .. _torch.topk(...): topk.html
.. _torch.transpose(...): transpose.html .. _torch.transpose(...): transpose.html
.. _torch.tril(...): tril.html
.. _torch.triu(...): triu.html
.. _torch.unique(...): unique.html .. _torch.unique(...): unique.html
.. _torch.unsqueeze(...): unsqueeze.html .. _torch.unsqueeze(...): unsqueeze.html
.. _torch.where(...): where.html .. _torch.where(...): where.html
......
...@@ -6,9 +6,6 @@ vm.torch.autograd ...@@ -6,9 +6,6 @@ vm.torch.autograd
Classes Classes
------- -------
`class Function <autograd/Function.html>`_
: Dispatch the tensor operation.
Functions Functions
--------- ---------
...@@ -18,7 +15,6 @@ vm.torch.autograd ...@@ -18,7 +15,6 @@ vm.torch.autograd
.. toctree:: .. toctree::
:hidden: :hidden:
autograd/Function
autograd/backward autograd/backward
.. raw:: html .. raw:: html
......
Function
========
.. autoclass:: dragon.vm.torch.autograd.Function
__init__
--------
.. automethod:: dragon.vm.torch.autograd.Function.__init__
Methods
-------
alloc
#####
.. automethod:: dragon.vm.torch.autograd.Function.alloc
apply
#####
.. automethod:: dragon.vm.torch.autograd.Function.apply
attributes
##########
.. automethod:: dragon.vm.torch.autograd.Function.attributes
dispatch
########
.. automethod:: dragon.vm.torch.autograd.Function.dispatch
instantiate
###########
.. automethod:: dragon.vm.torch.autograd.Function.instantiate
forward
#######
.. automethod:: dragon.vm.torch.autograd.Function.forward
.. raw:: html
<style>
h1:before {
content: "torch.autograd.";
color: #103d3e;
}
</style>
axpby bitwise_and
===== ===========
.. autofunction:: dragon.vm.torch.axpby .. autofunction:: dragon.vm.torch.bitwise_and
.. raw:: html .. raw:: html
......
repeat bitwise_or
====== ==========
.. autofunction:: dragon.vm.torch.bitwise_or
.. autofunction:: dragon.vm.torch.repeat
.. raw:: html .. raw:: html
......
...@@ -7,6 +7,24 @@ __init__ ...@@ -7,6 +7,24 @@ __init__
-------- --------
.. automethod:: dragon.vm.torch.device.__init__ .. automethod:: dragon.vm.torch.device.__init__
Properties
----------
index
#####
.. autoattribute:: dragon.vm.torch.device.index
type
####
.. autoattribute:: dragon.vm.torch.device.type
Methods
-------
copy
####
.. automethod:: dragon.vm.torch.device.copy
.. raw:: html .. raw:: html
<style> <style>
......
ROIAlign logical_and
======== ===========
.. autofunction:: dragon.vm.torch.logical_and
.. autoclass:: dragon.vm.caffe.core.layers.ROIAlign
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "torch.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
set_execution logical_not
============= ===========
.. autofunction:: dragon.vm.torch.logical_not
.. autofunction:: dragon.autograph.set_execution
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.autograph."; content: "torch.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
logical_or
==========
.. autofunction:: dragon.vm.torch.logical_or
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
drop_block2d logical_xor
============ ===========
.. autofunction:: dragon.vm.torch.logical_xor
.. autofunction:: dragon.nn.drop_block2d
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.nn."; content: "torch.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -142,6 +142,10 @@ vm.torch.nn ...@@ -142,6 +142,10 @@ vm.torch.nn
`class Linear <nn/Linear.html>`_ `class Linear <nn/Linear.html>`_
: Apply the linear transformation. : Apply the linear transformation.
`class LayerNorm <nn/LayerNorm.html>`_
: Apply the layer normalization.
`[Ba et.al, 2016] <https://arxiv.org/abs/1607.06450>`_
`class LocalResponseNorm <nn/LocalResponseNorm.html>`_ `class LocalResponseNorm <nn/LocalResponseNorm.html>`_
: Apply the local response normalization. : Apply the local response normalization.
`[Krizhevsky et.al, 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_. `[Krizhevsky et.al, 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_.
...@@ -169,9 +173,16 @@ vm.torch.nn ...@@ -169,9 +173,16 @@ vm.torch.nn
`class Module <nn/Module.html>`_ `class Module <nn/Module.html>`_
: The base class of modules. : The base class of modules.
`class ModuleList <nn/ModuleList.html>`_
: The list module container.
`class MSELoss <nn/MSELoss.html>`_ `class MSELoss <nn/MSELoss.html>`_
: Compute the element-wise squared error. : Compute the element-wise squared error.
`class MultiheadAttention <nn/MultiheadAttention.html>`_
: Apply the multihead attention.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class NLLLoss <nn/NLLLoss.html>`_ `class NLLLoss <nn/NLLLoss.html>`_
: Compute the negative likelihood loss with sparse labels. : Compute the negative likelihood loss with sparse labels.
...@@ -216,6 +227,9 @@ vm.torch.nn ...@@ -216,6 +227,9 @@ vm.torch.nn
: Apply the scaled exponential linear unit. : Apply the scaled exponential linear unit.
`[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_. `[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`class Sequential <nn/Sequential.html>`_
: The sequential module container.
`class Sigmoid <nn/Sigmoid.html>`_ `class Sigmoid <nn/Sigmoid.html>`_
: Apply the sigmoid function. : Apply the sigmoid function.
...@@ -237,6 +251,22 @@ vm.torch.nn ...@@ -237,6 +251,22 @@ vm.torch.nn
`class Tanh <nn/Tanh.html>`_ `class Tanh <nn/Tanh.html>`_
: Apply the tanh function. : Apply the tanh function.
`class TransformerDecoder <nn/TransformerDecoder.html>`_
: Standard transformer decoder.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class TransformerDecoderLayer <nn/TransformerDecoderLayer.html>`_
: Layer for a standard transformer decoder.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class TransformerEncoder <nn/TransformerEncoder.html>`_
: Standard transformer encoder.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class TransformerEncoderLayer <nn/TransformerEncoderLayer.html>`_
: Layer for a standard transformer encoder.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class SyncBatchNorm <nn/SyncBatchNorm.html>`_ `class SyncBatchNorm <nn/SyncBatchNorm.html>`_
: Apply the sync batch normalization over input. : Apply the sync batch normalization over input.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_. `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
...@@ -295,6 +325,7 @@ vm.torch.nn ...@@ -295,6 +325,7 @@ vm.torch.nn
nn/Identity nn/Identity
nn/KLDivLoss nn/KLDivLoss
nn/L1Loss nn/L1Loss
nn/LayerNorm
nn/LeakyReLU nn/LeakyReLU
nn/Linear nn/Linear
nn/LocalResponseNorm nn/LocalResponseNorm
...@@ -305,7 +336,9 @@ vm.torch.nn ...@@ -305,7 +336,9 @@ vm.torch.nn
nn/MaxPool2d nn/MaxPool2d
nn/MaxPool3d nn/MaxPool3d
nn/Module nn/Module
nn/ModuleList
nn/MSELoss nn/MSELoss
nn/MultiheadAttention
nn/NLLLoss nn/NLLLoss
nn/Parameter nn/Parameter
nn/PReLU nn/PReLU
...@@ -319,12 +352,17 @@ vm.torch.nn ...@@ -319,12 +352,17 @@ vm.torch.nn
nn/ReplicationPad3d nn/ReplicationPad3d
nn/RNN nn/RNN
nn/SELU nn/SELU
nn/Sequential
nn/Sigmoid nn/Sigmoid
nn/SigmoidFocalLoss nn/SigmoidFocalLoss
nn/SmoothL1Loss nn/SmoothL1Loss
nn/Softmax nn/Softmax
nn/Swish nn/Swish
nn/Tanh nn/Tanh
nn/TransformerDecoder
nn/TransformerDecoderLayer
nn/TransformerEncoder
nn/TransformerEncoderLayer
nn/SyncBatchNorm nn/SyncBatchNorm
nn/Upsample nn/Upsample
nn/UpsamplingBilinear2d nn/UpsamplingBilinear2d
......
LayerNorm
=========
.. autoclass:: dragon.vm.torch.nn.LayerNorm
__init__
--------
.. automethod:: dragon.vm.torch.nn.LayerNorm.__init__
.. _torch.nn.functional.layer_norm(...): functional/layer_norm.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
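A minimal usage sketch for the new ``LayerNorm`` module; the PyTorch-style ``normalized_shape`` argument is assumed::

    from dragon.vm import torch
    from dragon.vm.torch import nn

    m = nn.LayerNorm(8)
    x = torch.randn(4, 8)
    y = m(x)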
ModuleList
==========
.. autoclass:: dragon.vm.torch.nn.ModuleList
__init__
--------
.. automethod:: dragon.vm.torch.nn.ModuleList.__init__
Methods
-------
add_module
##########
.. automethod:: dragon.vm.torch.nn.Module.add_module
:noindex:
append
######
.. automethod:: dragon.vm.torch.nn.ModuleList.append
apply
#####
.. automethod:: dragon.vm.torch.nn.Module.apply
:noindex:
buffers
#######
.. automethod:: dragon.vm.torch.nn.Module.buffers
:noindex:
children
########
.. automethod:: dragon.vm.torch.nn.Module.children
:noindex:
cpu
###
.. automethod:: dragon.vm.torch.nn.Module.cpu
:noindex:
cuda
####
.. automethod:: dragon.vm.torch.nn.Module.cuda
:noindex:
double
######
.. automethod:: dragon.vm.torch.nn.Module.double
:noindex:
eval
####
.. automethod:: dragon.vm.torch.nn.Module.eval
:noindex:
extend
######
.. automethod:: dragon.vm.torch.nn.ModuleList.extend
float
#####
.. automethod:: dragon.vm.torch.nn.Module.float
:noindex:
half
####
.. automethod:: dragon.vm.torch.nn.Module.half
:noindex:
insert
######
.. automethod:: dragon.vm.torch.nn.ModuleList.insert
load_state_dict
###############
.. automethod:: dragon.vm.torch.nn.Module.load_state_dict
:noindex:
modules
#######
.. automethod:: dragon.vm.torch.nn.Module.modules
:noindex:
named_buffers
#############
.. automethod:: dragon.vm.torch.nn.Module.named_buffers
:noindex:
named_children
##############
.. automethod:: dragon.vm.torch.nn.Module.named_children
:noindex:
named_modules
#############
.. automethod:: dragon.vm.torch.nn.Module.named_modules
:noindex:
named_parameters
################
.. automethod:: dragon.vm.torch.nn.Module.named_parameters
:noindex:
parameters
##########
.. automethod:: dragon.vm.torch.nn.Module.parameters
:noindex:
register_buffer
###############
.. automethod:: dragon.vm.torch.nn.Module.register_buffer
:noindex:
register_forward_hook
#####################
.. automethod:: dragon.vm.torch.nn.Module.register_forward_hook
:noindex:
register_parameter
##################
.. automethod:: dragon.vm.torch.nn.Module.register_parameter
:noindex:
state_dict
##########
.. automethod:: dragon.vm.torch.nn.Module.state_dict
:noindex:
train
#####
.. automethod:: dragon.vm.torch.nn.Module.train
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.nn.Module.zero_grad
:noindex:
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
MultiheadAttention
==================
.. autoclass:: dragon.vm.torch.nn.MultiheadAttention
__init__
--------
.. automethod:: dragon.vm.torch.nn.MultiheadAttention.__init__
.. _torch.nn.functional.multi_head_attention_forward(...): functional/multi_head_attention_forward.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
Sequential
==========
.. autoclass:: dragon.vm.torch.nn.Sequential
__init__
--------
.. automethod:: dragon.vm.torch.nn.Sequential.__init__
Methods
-------
add_module
##########
.. automethod:: dragon.vm.torch.nn.Module.add_module
:noindex:
apply
#####
.. automethod:: dragon.vm.torch.nn.Module.apply
:noindex:
buffers
#######
.. automethod:: dragon.vm.torch.nn.Module.buffers
:noindex:
children
########
.. automethod:: dragon.vm.torch.nn.Module.children
:noindex:
cpu
###
.. automethod:: dragon.vm.torch.nn.Module.cpu
:noindex:
cuda
####
.. automethod:: dragon.vm.torch.nn.Module.cuda
:noindex:
double
######
.. automethod:: dragon.vm.torch.nn.Module.double
:noindex:
eval
####
.. automethod:: dragon.vm.torch.nn.Module.eval
:noindex:
float
#####
.. automethod:: dragon.vm.torch.nn.Module.float
:noindex:
forward
#######
.. automethod:: dragon.vm.torch.nn.Sequential.forward
half
####
.. automethod:: dragon.vm.torch.nn.Module.half
:noindex:
load_state_dict
###############
.. automethod:: dragon.vm.torch.nn.Module.load_state_dict
:noindex:
modules
#######
.. automethod:: dragon.vm.torch.nn.Module.modules
:noindex:
named_buffers
#############
.. automethod:: dragon.vm.torch.nn.Module.named_buffers
:noindex:
named_children
##############
.. automethod:: dragon.vm.torch.nn.Module.named_children
:noindex:
named_modules
#############
.. automethod:: dragon.vm.torch.nn.Module.named_modules
:noindex:
named_parameters
################
.. automethod:: dragon.vm.torch.nn.Module.named_parameters
:noindex:
parameters
##########
.. automethod:: dragon.vm.torch.nn.Module.parameters
:noindex:
register_buffer
###############
.. automethod:: dragon.vm.torch.nn.Module.register_buffer
:noindex:
register_forward_hook
#####################
.. automethod:: dragon.vm.torch.nn.Module.register_forward_hook
:noindex:
register_parameter
##################
.. automethod:: dragon.vm.torch.nn.Module.register_parameter
:noindex:
state_dict
##########
.. automethod:: dragon.vm.torch.nn.Module.state_dict
:noindex:
train
#####
.. automethod:: dragon.vm.torch.nn.Module.train
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.nn.Module.zero_grad
:noindex:
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
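A small sketch of the new container modules; the ``Linear`` and ``ReLU`` constructors are assumed to follow their PyTorch signatures::

    from dragon.vm.torch import nn

    block = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
    layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])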
TransformerDecoder
==================
.. autoclass:: dragon.vm.torch.nn.TransformerDecoder
__init__
--------
.. automethod:: dragon.vm.torch.nn.TransformerDecoder.__init__
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
TransformerDecoderLayer
=======================
.. autoclass:: dragon.vm.torch.nn.TransformerDecoderLayer
__init__
--------
.. automethod:: dragon.vm.torch.nn.TransformerDecoderLayer.__init__
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
TransformerEncoder
==================
.. autoclass:: dragon.vm.torch.nn.TransformerEncoder
__init__
--------
.. automethod:: dragon.vm.torch.nn.TransformerEncoder.__init__
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
TransformerEncoderLayer
=======================
.. autoclass:: dragon.vm.torch.nn.TransformerEncoderLayer
__init__
--------
.. automethod:: dragon.vm.torch.nn.TransformerEncoderLayer.__init__
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
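A hedged sketch of the new transformer modules; the ``d_model``/``nhead``/``num_layers`` constructor arguments and the input layout are assumed from the PyTorch API they mirror::

    from dragon.vm import torch
    from dragon.vm.torch import nn

    layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
    encoder = nn.TransformerEncoder(layer, num_layers=6)
    x = torch.randn(10, 4, 512)
    y = encoder(x)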
...@@ -41,22 +41,22 @@ vm.torch.nn.functional ...@@ -41,22 +41,22 @@ vm.torch.nn.functional
: Compute the sigmoid cross entropy with contiguous target. : Compute the sigmoid cross entropy with contiguous target.
`conv1d(...) <functional/conv1d.html>`_ `conv1d(...) <functional/conv1d.html>`_
: Apply 1d convolution to the input. : Apply the 1d convolution to input.
`conv2d(...) <functional/conv2d.html>`_ `conv2d(...) <functional/conv2d.html>`_
: Apply 2d convolution to the input. : Apply the 2d convolution to input.
`conv3d(...) <functional/conv3d.html>`_ `conv3d(...) <functional/conv3d.html>`_
: Apply 3d convolution to the input. : Apply the 3d convolution to input.
`conv_transpose1d(...) <functional/conv_transpose1d.html>`_ `conv_transpose1d(...) <functional/conv_transpose1d.html>`_
: Apply 1d deconvolution to the input. : Apply the 1d deconvolution to input.
`conv_transpose2d(...) <functional/conv_transpose2d.html>`_ `conv_transpose2d(...) <functional/conv_transpose2d.html>`_
: Apply 2d deconvolution to the input. : Apply the 2d deconvolution to input.
`conv_transpose3d(...) <functional/conv_transpose3d.html>`_ `conv_transpose3d(...) <functional/conv_transpose3d.html>`_
: Apply 3d deconvolution to the input. : Apply the 3d deconvolution to input.
`cross_entropy(...) <functional/cross_entropy.html>`_ `cross_entropy(...) <functional/cross_entropy.html>`_
: Compute the softmax cross entropy with sparse labels. : Compute the softmax cross entropy with sparse labels.
...@@ -66,11 +66,11 @@ vm.torch.nn.functional ...@@ -66,11 +66,11 @@ vm.torch.nn.functional
`[Graves & Gomez, 2006] <http://www.cs.utoronto.ca/~graves/icml_2006.pdf>`_. `[Graves & Gomez, 2006] <http://www.cs.utoronto.ca/~graves/icml_2006.pdf>`_.
`depthwise_conv2d(...) <functional/depthwise_conv2d.html>`_ `depthwise_conv2d(...) <functional/depthwise_conv2d.html>`_
: Apply 2d depthwise convolution to the input. : Apply the 2d depthwise convolution to input.
`[Chollet, 2016] <https://arxiv.org/abs/1610.02357>`_. `[Chollet, 2016] <https://arxiv.org/abs/1610.02357>`_.
`drop_block2d(...) <functional/drop_block2d.html>`_ `drop_block2d(...) <functional/drop_block2d.html>`_
: Set the spatial blocks over input to zero randomly. : Set the blocks over input to zero randomly.
`[Ghiasi et.al, 2018] <https://arxiv.org/abs/1810.12890>`_. `[Ghiasi et.al, 2018] <https://arxiv.org/abs/1810.12890>`_.
`drop_path(...) <functional/drop_path.html>`_ `drop_path(...) <functional/drop_path.html>`_
...@@ -78,7 +78,7 @@ vm.torch.nn.functional ...@@ -78,7 +78,7 @@ vm.torch.nn.functional
`[Larsson et.al, 2016] <https://arxiv.org/abs/1605.07648>`_. `[Larsson et.al, 2016] <https://arxiv.org/abs/1605.07648>`_.
`dropout(...) <functional/dropout.html>`_ `dropout(...) <functional/dropout.html>`_
: Set the elements of the input to zero randomly. : Set the elements of input to zero randomly.
`[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_. `[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
`elu(...) <functional/elu.html>`_ `elu(...) <functional/elu.html>`_
...@@ -102,6 +102,10 @@ vm.torch.nn.functional ...@@ -102,6 +102,10 @@ vm.torch.nn.functional
`l1_loss(...) <functional/l1_loss.html>`_ `l1_loss(...) <functional/l1_loss.html>`_
: Compute the element-wise absolute value difference. : Compute the element-wise absolute value difference.
`layer_norm(...) <functional/layer_norm.html>`_
: Apply the layer normalization to input.
`[Ba et.al, 2016] <https://arxiv.org/abs/1607.06450>`_
`leaky_relu(...) <functional/leaky_relu.html>`_ `leaky_relu(...) <functional/leaky_relu.html>`_
: Apply the leaky rectified linear unit to input. : Apply the leaky rectified linear unit to input.
...@@ -130,6 +134,10 @@ vm.torch.nn.functional ...@@ -130,6 +134,10 @@ vm.torch.nn.functional
`mse_loss(...) <functional/mse_loss.html>`_ `mse_loss(...) <functional/mse_loss.html>`_
: Compute the element-wise squared error. : Compute the element-wise squared error.
`multi_head_attention_forward(...) <functional/multi_head_attention_forward.html>`_
: Apply the multihead attention to input.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`nll_loss(...) <functional/nll_loss.html>`_ `nll_loss(...) <functional/nll_loss.html>`_
: Compute the negative likelihood loss with sparse labels. : Compute the negative likelihood loss with sparse labels.
...@@ -216,6 +224,7 @@ vm.torch.nn.functional ...@@ -216,6 +224,7 @@ vm.torch.nn.functional
functional/l1_loss functional/l1_loss
functional/leaky_relu functional/leaky_relu
functional/linear functional/linear
functional/layer_norm
functional/local_response_norm functional/local_response_norm
functional/log_softmax functional/log_softmax
functional/interpolate functional/interpolate
...@@ -223,6 +232,7 @@ vm.torch.nn.functional ...@@ -223,6 +232,7 @@ vm.torch.nn.functional
functional/max_pool2d functional/max_pool2d
functional/max_pool3d functional/max_pool3d
functional/mse_loss functional/mse_loss
functional/multi_head_attention_forward
functional/nll_loss functional/nll_loss
functional/normalize functional/normalize
functional/pad functional/pad
......
layer_norm
==========
.. autofunction:: dragon.vm.torch.nn.functional.layer_norm
.. _torch.nn.LayerNorm(...): ../LayerNorm.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
sparse_softmax_cross_entropy multi_head_attention_forward
============================ ============================
.. autofunction:: dragon.losses.sparse_softmax_cross_entropy .. autofunction:: dragon.vm.torch.nn.functional.multi_head_attention_forward
.. _torch.nn.MultiheadAttention(...): ../MultiheadAttention.html
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.losses."; content: "torch.nn.functional.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -27,6 +27,9 @@ vm.torch.nn.init ...@@ -27,6 +27,9 @@ vm.torch.nn.init
`normal_(...) <init/normal_.html>`_ `normal_(...) <init/normal_.html>`_
: Fill tensor from a normal distribution. : Fill tensor from a normal distribution.
`ones_(...) <init/ones_.html>`_
: Fill tensor with ones.
`uniform_(...) <init/uniform_.html>`_ `uniform_(...) <init/uniform_.html>`_
: Fill tensor from a uniform distribution. : Fill tensor from a uniform distribution.
...@@ -36,6 +39,9 @@ vm.torch.nn.init ...@@ -36,6 +39,9 @@ vm.torch.nn.init
`xavier_uniform_(...) <init/xavier_uniform_.html>`_ `xavier_uniform_(...) <init/xavier_uniform_.html>`_
: Fill tensor from a xavier uniform distribution. : Fill tensor from a xavier uniform distribution.
`zeros_(...) <init/zeros_.html>`_
: Fill tensor with zeros.
.. toctree:: .. toctree::
:hidden: :hidden:
...@@ -46,9 +52,11 @@ vm.torch.nn.init ...@@ -46,9 +52,11 @@ vm.torch.nn.init
init/kaiming_normal_ init/kaiming_normal_
init/kaiming_uniform_ init/kaiming_uniform_
init/normal_ init/normal_
init/ones_
init/uniform_ init/uniform_
init/xavier_normal_ init/xavier_normal_
init/xavier_uniform_ init/xavier_uniform_
init/zeros_
.. raw:: html .. raw:: html
......
ones\_
======
.. autofunction:: dragon.vm.torch.nn.init.ones_
.. raw:: html
<style>
h1:before {
content: "torch.nn.init.";
color: #103d3e;
}
</style>
zeros\_
=======
.. autofunction:: dragon.vm.torch.nn.init.zeros_
.. raw:: html
<style>
h1:before {
content: "torch.nn.init.";
color: #103d3e;
}
</style>
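The new fillers follow the usual in-place ``init`` convention of mutating the given tensor; a minimal sketch::

    from dragon.vm import torch
    from dragon.vm.torch.nn import init

    w = torch.randn(3, 3)
    init.ones_(w)
    init.zeros_(w)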
...@@ -10,11 +10,6 @@ __init__ ...@@ -10,11 +10,6 @@ __init__
Methods Methods
------- -------
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex:
add_param_group add_param_group
############### ###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group .. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
...@@ -25,6 +20,11 @@ step ...@@ -25,6 +20,11 @@ step
.. automethod:: dragon.vm.torch.optim.Optimizer.step .. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex: :noindex:
sum_grad
########
.. automethod:: dragon.vm.torch.optim.Optimizer.sum_grad
:noindex:
zero_grad zero_grad
######### #########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad .. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
......
...@@ -10,10 +10,6 @@ __init__ ...@@ -10,10 +10,6 @@ __init__
Methods Methods
------- -------
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
add_param_group add_param_group
############### ###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group .. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
...@@ -22,6 +18,10 @@ step ...@@ -22,6 +18,10 @@ step
#### ####
.. automethod:: dragon.vm.torch.optim.Optimizer.step .. automethod:: dragon.vm.torch.optim.Optimizer.step
sum_grad
########
.. automethod:: dragon.vm.torch.optim.Optimizer.sum_grad
zero_grad zero_grad
######### #########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad .. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
......
...@@ -10,11 +10,6 @@ __init__ ...@@ -10,11 +10,6 @@ __init__
Methods Methods
------- -------
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex:
add_param_group add_param_group
############### ###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group .. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
...@@ -25,6 +20,11 @@ step ...@@ -25,6 +20,11 @@ step
.. automethod:: dragon.vm.torch.optim.Optimizer.step .. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex: :noindex:
sum_grad
########
.. automethod:: dragon.vm.torch.optim.Optimizer.sum_grad
:noindex:
zero_grad zero_grad
######### #########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad .. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
......
...@@ -10,11 +10,6 @@ __init__ ...@@ -10,11 +10,6 @@ __init__
Methods Methods
------- -------
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex:
add_param_group add_param_group
############### ###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group .. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
...@@ -25,6 +20,11 @@ step ...@@ -25,6 +20,11 @@ step
.. automethod:: dragon.vm.torch.optim.Optimizer.step .. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex: :noindex:
sum_grad
########
.. automethod:: dragon.vm.torch.optim.Optimizer.sum_grad
:noindex:
zero_grad zero_grad
######### #########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad .. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
......
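``accumulate`` is replaced by ``sum_grad`` across the optimizers. A sketch of the assumed accumulation loop; the exact interaction between ``sum_grad`` and ``step`` is an assumption here::

    from dragon.vm import torch
    from dragon.vm.torch import nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(2):
        x = torch.randn(8, 4)
        loss = model(x).sum()
        loss.backward()
        optimizer.sum_grad()   # replaces the old ``accumulate``
    optimizer.step()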
tile
====
.. autofunction:: dragon.vm.torch.tile
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
tril
====
.. autofunction:: dragon.vm.torch.tril
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
triu
====
.. autofunction:: dragon.vm.torch.triu
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_CNML_H_
#define DRAGON_CORE_CONTEXT_CNML_H_
#include "dragon/core/common.h"
struct cnrtStream;
struct cnmlCpuTensor;
struct cnmlTensor;
struct cnmlFusionOp;
typedef struct cnrtStream* cnrtStream_t;
typedef struct cnmlCpuTensor* cnmlCpuTensor_t;
typedef struct cnmlTensor* cnmlTensor_t;
typedef struct cnmlFusionOp* cnmlFusionOp_t;
namespace dragon {
/*!
* \brief The cnml device context.
*/
class CNMLContext {
public:
/*! \brief Default constructor */
CNMLContext() : device_id_(0), random_seed_(DEFAULT_RNG_SEED) {}
/*! \brief Constructor with the device index */
explicit CNMLContext(int device)
: device_id_(device), random_seed_(DEFAULT_RNG_SEED) {}
/*! \brief Constructor with the device option */
explicit CNMLContext(const DeviceOption& option)
: device_id_(option.device_id()),
random_seed_(
option.has_random_seed() ? option.random_seed()
: DEFAULT_RNG_SEED) {
CHECK_EQ(option.device_type(), PROTO_CNML);
}
/*! \brief Allocate a block of memory */
static void* New(size_t size) {
return nullptr;
}
/*! \brief Set a memory block to the given value */
static void Memset(size_t n, void* ptr, int value) {}
/*! \brief Set a memory block to the given value asynchronously */
void MemsetAsync(size_t n, void* ptr, int value) {
Memset(n, ptr, value);
}
/*! \brief Copy a memory block to the destination */
template <class DestContext, class SrcContext>
static void Memcpy(size_t n, void* dest, const void* src) {}
/*! \brief Copy a memory block to the destination asynchronously */
template <class DestContext, class SrcContext>
void MemcpyAsync(size_t n, void* dest, const void* src) {
Memcpy<DestContext, SrcContext>(n, dest, src);
}
/*! \brief Deallocate a memory block */
static void Delete(void* ptr) {}
/*! \brief Switch to the device in current thread */
void SwitchToDevice() {
SwitchToDevice(0);
}
/*! \brief Switch to the device and select given stream in current thread */
void SwitchToDevice(int stream) {}
/*! \brief Wait for the dispatched computation to complete */
void FinishDeviceComputation() {}
/*! \brief Return the cnrt stream */
cnrtStream_t cnrt_stream() {
return cnrt_stream(device_id_, stream_id_);
}
/*! \brief Return the specified cnrt stream */
static cnrtStream_t cnrt_stream(int device_id, int stream_id) {
return (cnrtStream_t) nullptr;
}
/*! \brief Return the device index */
int device() const {
return device_id_;
}
/*! \brief Return the stream index */
int stream() const {
return stream_id_;
}
private:
int device_id_, stream_id_ = 1, random_seed_;
unique_ptr<std::mt19937> rand_generator_;
};
} // namespace dragon
#endif // DRAGON_CORE_CONTEXT_CNML_H_
#include "dragon/core/gradient.h"
namespace dragon {
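// Build gradient defs for the recorded forward defs.
//
// The forward pass counts how many operators consume each input. The reverse
// pass then asks the registered gradient maker of every operator for its grad
// defs, seeds the gradients of the requested targets with "GradientFill", and
// splits/sums the gradient of any input consumed by several operators through
// "GradientGather".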
void GradientTape::CreateGradientDefs(
const vector<OperatorDef*>& op_defs,
const vector<string>& targets,
const vector<string>& grad_targets) {
def_.clear_op();
Set<string> split_grads;
Map<string, string> sources_to_grads;
Map<string, string> targets_to_grads;
Map<string, int> inputs_count, splits_count;
// Function to check if grad op can be removed.
auto IsNoGradient = [&](const OperatorDef& op,
vector<pair<string, int>>& init_grads) mutable {
if (NoGradientRegistry()->Has(op.type())) {
return true;
}
bool maybe_skip = false;
for (int i = 0; i < op.output_size(); ++i) {
const auto& y = op.output(i);
if (!sources_to_grads.count(y)) {
maybe_skip = true;
if (targets_to_grads.count(y)) {
init_grads.push_back({y, i});
sources_to_grads[y] = y + "_grad";
}
}
}
return maybe_skip && init_grads.empty() && op.output_size() == 1;
};
// Set the gradient of targets.
for (int i = 0; i < targets.size(); ++i) {
targets_to_grads[targets[i]] =
i < grad_targets.size() ? grad_targets[i] : "";
}
// PLAY for the forward.
for (auto* op : op_defs) {
if (NoGradientRegistry()->Has(op->type())) continue;
for (const auto& x : op->input()) {
if (std::find(op->output().begin(), op->output().end(), x) ==
op->output().end()) {
inputs_count[x]++;
}
}
}
// PLAY for the backward.
for (auto op_iter = op_defs.rbegin(); op_iter != op_defs.rend(); op_iter++) {
const auto& op = *op_iter;
vector<pair<string, int>> init_grads;
if (IsNoGradient(*op, init_grads)) continue;
vector<string> grad_ys;
for (const auto& y : op->output()) {
const auto& iter = sources_to_grads.find(y);
grad_ys.emplace_back(iter != sources_to_grads.end() ? iter->second : "");
}
CHECK(GradientRegistry()->Has(op->type()))
<< "\nMissing gradient maker for " << op->type() << ".";
unique_ptr<GradientMakerBase> maker(
GradientRegistry()->Create(op->type(), *op, grad_ys));
maker->Make();
vector<OperatorDef> gather_defs;
for (auto& grad_def : maker->grad_defs()) {
for (int i = 0; i < grad_def.output_size(); ++i) {
const auto& grad_ys = grad_def.input();
const auto& grad_x = grad_def.output(i);
if (std::find(grad_ys.begin(), grad_ys.end(), grad_x) != grad_ys.end())
continue;
int x_index = -1;
for (int j = 0; j < maker->grad_inputs().size(); ++j) {
if (grad_x == maker->grad_inputs()[j]) x_index = j;
}
if (x_index == -1) continue;
const auto& x = op->input(x_index);
if (inputs_count[x] <= 1) continue;
auto split_prefix = grad_x + "_split_";
auto grad_x_split = split_prefix + str::to(splits_count[grad_x]++);
split_grads.insert(grad_x_split);
if (splits_count[grad_x] == inputs_count[x]) {
gather_defs.emplace_back(CreateOperatorDef(
"GradientGather",
"",
vector<string>({}),
vector<string>({grad_x}),
vector<Argument>(),
grad_def.device_option()));
for (int j = 0; j < splits_count[grad_x]; ++j) {
auto iter = split_grads.find(split_prefix + str::to(j));
if (iter != split_grads.end()) {
gather_defs.back().add_input(*iter);
}
}
}
grad_def.set_output(i, grad_x_split);
}
}
for (int i = 0; i < op->input_size(); ++i) {
sources_to_grads[op->input(i)] = maker->grad_inputs()[i];
}
if (init_grads.size() > 0) {
Argument values;
values.set_name("values");
vector<string> inputs, outputs;
auto fills = maker->grad_defaults();
for (auto& iter : init_grads) {
const auto& grad = targets_to_grads[iter.first];
inputs.emplace_back(grad.empty() ? iter.first : grad);
outputs.emplace_back(iter.first + "_grad");
values.add_floats(grad.empty() ? fills[iter.second] : -100.f);
}
def_.add_op()->CopyFrom(CreateOperatorDef(
"GradientFill",
"",
inputs,
outputs,
vector<Argument>({values}),
op->device_option()));
}
for (const auto& grad_def : maker->grad_defs()) {
def_.add_op()->CopyFrom(grad_def);
}
for (const auto& gather_def : gather_defs) {
def_.add_op()->CopyFrom(gather_def);
}
}
}
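To make the naming scheme above concrete, here is a minimal, hypothetical sketch; it assumes the Dragon core headers and protobuf types, and that "Mul" has a registered gradient maker of the generic form. All op and tensor names are illustrative only.

GraphDef forward;
auto* mul = forward.add_op();
mul->set_type("Mul");
mul->add_input("x");
mul->add_input("x");  // "x" is consumed twice, so its gradient will be split.
mul->add_output("y");

GradientTape tape({forward.mutable_op(0)}, /*targets=*/{"y"}, /*grad_targets=*/{"dy"});
// tape.def() now holds, roughly:
//   GradientFill:   "dy" -> "y_grad"
//   MulGradient:    writes "x_grad_split_0" and "x_grad_split_1"
//   GradientGather: gathers the splits into "x_grad"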
void GradientTape::Optimize(const vector<string>& sources) {
Set<int> noop_indices;
Set<string> required_grads;
Map<string, int> inputs_count;
Map<string, string> grads_to_buffers;
Map<string, pair<int, string>> splits;
for (int op_index = 0; op_index < def_.op_size(); ++op_index) {
const auto& op = def_.op(op_index);
if (!str::find(op.type(), "Gradient")) continue;
// Count noops.
if (op.type() == "GradientGather") noop_indices.insert(op_index);
// Count inputs.
for (const auto& input : op.input()) {
inputs_count[input] += 1;
}
}
// Initialize the required grads before optimization.
for (const auto& input : sources) {
required_grads.insert(input + "_grad");
}
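// Resolve each gather op: drop it if its output is unused, otherwise record its split inputs.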
for (auto op_index : noop_indices) {
const auto& op = def_.op(op_index);
if (op.type() == "GradientGather") {
if (inputs_count.count(op.output(0)) == 0 &&
required_grads.count(op.output(0)) == 0) {
for (const auto& input : op.input()) {
inputs_count.erase(input);
}
} else {
string first_input;
for (const auto& input : op.input()) {
if (!input.empty()) {
if (first_input.empty()) first_input = input;
splits[input] = {op_index, first_input};
}
}
}
}
}
optimized_def_ = def_;
optimized_def_.clear_op();
for (int op_index = 0; op_index < def_.op_size(); ++op_index) {
if (noop_indices.count(op_index)) continue;
const auto& op = def_.op(op_index);
optimized_def_.add_op()->CopyFrom(op);
if (!str::find(op.type(), "Gradient")) continue;
for (const auto& output : op.output()) {
// Decouple the gathering of split grads.
const auto& split_iter = splits.find(output);
if (split_iter != splits.end()) {
auto& gather_op = def_.op(split_iter->second.first);
auto* decouple_op = optimized_def_.add_op();
decouple_op->CopyFrom(gather_op);
decouple_op->clear_input();
if (output != split_iter->second.second) {
decouple_op->set_type("GradientAdd");
decouple_op->add_input(gather_op.output(0));
const auto& count_iter = inputs_count.find(gather_op.output(0));
if (count_iter != inputs_count.end()) count_iter->second++;
}
decouple_op->add_input(output);
if (!op.arg().empty()) {
const auto& arg = *(op.arg().end() - 1);
if (arg.name() == "cache_key") {
auto* new_arg = decouple_op->add_arg();
const auto& dev = decouple_op->device_option();
new_arg->set_name("cache_key");
new_arg->set_s(
decouple_op->type() + "/" + str::to(dev.device_type()) + ":" +
str::to(dev.device_id()));
}
}
}
}
}
// Prepare the pool
int buffer_index = 0;
std::deque<string> pool;
auto get_buffer = [&]() mutable {
if (pool.empty()) {
return "shared/buffer/grad:" + str::to(++buffer_index);
} else {
auto buffer = pool.back();
pool.pop_back();
return buffer;
}
};
for (int op_index = 0; op_index < optimized_def_.op_size(); ++op_index) {
auto* op = optimized_def_.mutable_op(op_index);
if (!str::find(op->type(), "Gradient")) continue;
// Check output aliases.
vec32_t output_aliases(op->output_size(), -1);
for (int i = 0; i < op->output_size(); ++i) {
for (int j = 0; j < op->input_size(); ++j) {
if (op->output(i) != op->input(j)) continue;
output_aliases[i] = j;
break;
}
}
// Rewrite inputs.
vector<string> dead_buffers;
for (int i = 0; i < op->input_size(); ++i) {
const string& input = op->input(i);
const auto& count_iter = inputs_count.find(input);
if (count_iter == inputs_count.end()) continue;
count_iter->second--;
const auto& buffer_iter = grads_to_buffers.find(input);
if (buffer_iter == grads_to_buffers.end()) continue;
if (count_iter->second == 0) {
dead_buffers.emplace_back(buffer_iter->second);
}
op->set_input(i, buffer_iter->second);
}
// Rewrite outputs.
for (int i = 0; i < op->output_size(); ++i) {
const string& output = op->output(i);
if (output.empty() || required_grads.count(output) > 0) continue;
if (inputs_count.count(output) == 0) {
op->mutable_output(i)->clear();
continue;
}
if (output_aliases[i] >= 0) {
op->set_output(i, op->input(output_aliases[i]));
} else {
*op->mutable_output(i) = grads_to_buffers[output] = get_buffer();
}
}
// Update pool.
for (auto& buffer : dead_buffers) {
pool.emplace_back(buffer);
}
}
}
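The buffer reuse above hinges on a small LIFO pool. A self-contained sketch of just that allocator follows; the names mirror the "shared/buffer/grad:N" scheme used here, everything else is illustrative.

#include <deque>
#include <iostream>
#include <string>

int main() {
  int buffer_index = 0;
  std::deque<std::string> pool;
  // Reuse the most recently released buffer (LIFO), otherwise create a new one.
  auto get_buffer = [&]() {
    if (pool.empty()) return "shared/buffer/grad:" + std::to_string(++buffer_index);
    auto buffer = pool.back();
    pool.pop_back();
    return buffer;
  };
  auto b1 = get_buffer();             // "shared/buffer/grad:1"
  auto b2 = get_buffer();             // "shared/buffer/grad:2"
  pool.push_back(b1);                 // b1's gradient is dead; return its buffer
  std::cout << get_buffer() << "\n";  // reuses "shared/buffer/grad:1"
  std::cout << b2 << "\n";
}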
DEFINE_REGISTRY(
GradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<string>&);
DEFINE_REGISTRY(
NoGradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<string>&);
} // namespace dragon
...@@ -10,32 +10,21 @@ ...@@ -10,32 +10,21 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_ #ifndef DRAGON_CORE_GRADIENT_H_
#define DRAGON_CORE_OPERATOR_GRADIENT_H_ #define DRAGON_CORE_GRADIENT_H_
#include "dragon/core/common.h" #include "dragon/core/common.h"
#include "dragon/core/operator.h"
#include "dragon/core/registry.h" #include "dragon/core/registry.h"
#include "dragon/utils/proto_utils.h" #include "dragon/utils/proto_utils.h"
namespace dragon { namespace dragon {
struct GradientPack {
GradientPack(
const vector<OperatorDef>& grad_defs,
const vector<string>& grad_inputs,
const vector<float>& defaults)
: grad_defs(grad_defs), grad_inputs(grad_inputs), defaults(defaults) {}
vector<OperatorDef> grad_defs;
vector<string> grad_inputs;
vector<float> defaults;
};
class GradientMakerBase { class GradientMakerBase {
public: public:
GradientMakerBase(const OperatorDef& def, const vector<string>& grad_outputs) GradientMakerBase(const OperatorDef& def, const vector<string>& grad_outputs)
: def(def), grad_inputs_(def.input_size()), grad_outputs_(grad_outputs) {} : def_(def),
grad_outputs_(grad_outputs),
grad_inputs_(def.input_size()) {}
virtual ~GradientMakerBase() {} virtual ~GradientMakerBase() {}
...@@ -49,45 +38,54 @@ class GradientMakerBase { ...@@ -49,45 +38,54 @@ class GradientMakerBase {
return true; return true;
} }
virtual GradientPack Make() { virtual void Make() {
auto new_defs = MakeDef(); CreateGradientDefs();
if (def.has_cache_key()) { string cache_key;
// Attach the handle to name if having cache key if (!def_.arg().empty()) {
for (size_t i = 0; i < new_defs.size(); i++) { const auto& arg = *(def_.arg().end() - 1);
new_defs[i].set_name(def.name()); if (arg.name() == "cache_key") cache_key = arg.s();
}
Argument new_arg;
new_arg.set_name("handle");
new_arg.set_s(def_.name());
for (auto& grad_def : grad_defs_) {
if (CopyDeviceOption() && def_.has_device_option()) {
grad_def.mutable_device_option()->CopyFrom(def_.device_option());
} }
} else { if (CopyArguments() && !def_.arg().empty()) {
// Otherwise, just put it into the arguments grad_def.mutable_arg()->MergeFrom(def_.arg());
Argument arg; if (!cache_key.empty()) grad_def.mutable_arg()->RemoveLast();
arg.set_name("handle"); }
arg.set_s(def.name()); grad_def.add_arg()->CopyFrom(new_arg);
for (size_t i = 0; i < new_defs.size(); i++) { }
new_defs[i].add_arg()->CopyFrom(arg); if (!cache_key.empty()) {
cache_key += "/grad";
new_arg.set_name("cache_key");
for (int i = 0; i < grad_defs_.size(); ++i) {
new_arg.set_s(cache_key + (i > 0 ? ("/" + str::to(i)) : ""));
grad_defs_[i].add_arg()->CopyFrom(new_arg);
} }
} }
return GradientPack(new_defs, grad_inputs_, defaults());
}; };
virtual vector<OperatorDef> MakeDef() { virtual void CreateGradientDefs() {}
return vector<OperatorDef>();
}
template <class... Args> template <class... Args>
static vector<OperatorDef> SingleDef(const Args&... args) { void AddGradientDef(const Args&... args) {
return vector<OperatorDef>{MakeOperatorDef(args...)}; grad_defs_.emplace_back(CreateOperatorDef(args...));
} }
const string I(const int i) const { const string I(const int i) const {
return i < int(def.input_size()) ? def.input(i) : ""; return i < int(def_.input_size()) ? def_.input(i) : "";
} }
const string O(const int i) const { const string O(const int i) const {
return i < int(def.output_size()) ? def.output(i) : ""; return i < int(def_.output_size()) ? def_.output(i) : "";
} }
string GI(const int i) { string GI(const int i) {
if (i >= int(grad_inputs_.size())) return ""; if (i >= int(grad_inputs_.size())) return "";
grad_inputs_[i] = def.input(i) + "_grad"; grad_inputs_[i] = def_.input(i) + "_grad";
return grad_inputs_[i]; return grad_inputs_[i];
} }
...@@ -95,82 +93,89 @@ class GradientMakerBase { ...@@ -95,82 +93,89 @@ class GradientMakerBase {
return i < int(grad_outputs_.size()) ? grad_outputs_[i] : ""; return i < int(grad_outputs_.size()) ? grad_outputs_[i] : "";
} }
virtual vector<float> defaults() { const OperatorDef& def() const {
return def_;
}
vector<OperatorDef>& grad_defs() {
return grad_defs_;
}
vector<string>& grad_inputs() {
return grad_inputs_;
}
virtual vector<float> grad_defaults() {
return vector<float>(grad_outputs_.size(), 1.f); return vector<float>(grad_outputs_.size(), 1.f);
} }
protected: protected:
const OperatorDef& def; const OperatorDef& def_;
vector<string> grad_inputs_; vector<OperatorDef> grad_defs_;
const vector<string>& grad_outputs_; const vector<string>& grad_outputs_;
vector<string> grad_inputs_;
}; };
DRAGON_API GradientPack #define GRADIENT_MAKER_CTOR(name) \
MakeGradientForOp(const OperatorDef& op_def, const vector<string>& g_outputs); name(const OperatorDef& def, const vector<string>& grad_outputs) \
: GradientMakerBase(def, grad_outputs) {}
#define GRADIENT_MAKER_CTOR(name) \
name(const OperatorDef& def, const vector<string>& g_output) \
: GradientMakerBase(def, g_output) {}
class NoGradient : public GradientMakerBase { class NoGradient : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(NoGradient); GRADIENT_MAKER_CTOR(NoGradient);
vector<OperatorDef> MakeDef() override {
return vector<OperatorDef>();
}
}; };
namespace { namespace {
// Here we define some common gradient makers
// Reuse them to make the codes cleaner
class GenericGradientMaker final : public GradientMakerBase { class GenericGradientMaker final : public GradientMakerBase {
public: public:
/*!
* Inputs: X1, X2, ..., Xn, dY1, dY2, ..., dYm
* Outputs: dX1, dX2, ..., dXn
*/
GRADIENT_MAKER_CTOR(GenericGradientMaker); GRADIENT_MAKER_CTOR(GenericGradientMaker);
vector<OperatorDef> MakeDef() override { void CreateGradientDefs() override {
vector<string> inputs, outputs; /*!
for (const auto& input : def.input()) * X1, X2, ..., Xn, dY1, dY2, ..., dYm
inputs.push_back(input); * dX1, dX2, ..., dXn
for (int i = 0; i < def.output_size(); ++i) */
inputs.push_back(GO(i)); vector<string> inputs({def().input().begin(), def().input().end()});
for (int i = 0; i < def.input_size(); ++i) vector<string> outputs;
outputs.push_back(GI(i)); for (int i = 0; i < def().output_size(); ++i) {
return SingleDef(def.type() + "Gradient", "", inputs, outputs); inputs.emplace_back(GO(i));
}
for (int i = 0; i < def().input_size(); ++i) {
outputs.emplace_back(GI(i));
}
AddGradientDef(def().type() + "Gradient", "", inputs, outputs);
} }
}; };
class SimpleGradientMaker final : public GradientMakerBase { class SimpleGradientMaker final : public GradientMakerBase {
public: public:
/*!
* Inputs: dY1, dY2, ..., dYm
* Outputs: dX1, dX2, ..., dXn
*/
GRADIENT_MAKER_CTOR(SimpleGradientMaker); GRADIENT_MAKER_CTOR(SimpleGradientMaker);
vector<OperatorDef> MakeDef() override { void CreateGradientDefs() override {
/*!
* dY1, dY2, ..., dYm
* dX1, dX2, ..., dXn
*/
vector<string> inputs, outputs; vector<string> inputs, outputs;
for (int i = 0; i < def.output_size(); ++i) for (int i = 0; i < def().output_size(); ++i) {
inputs.push_back(GO(i)); inputs.emplace_back(GO(i));
for (int i = 0; i < def.input_size(); ++i) }
outputs.push_back(GI(i)); for (int i = 0; i < def().input_size(); ++i) {
return SingleDef(def.type() + "Gradient", "", inputs, outputs); outputs.emplace_back(GI(i));
}
AddGradientDef(def().type() + "Gradient", "", inputs, outputs);
} }
}; };
class InplaceGradientMaker final : public GradientMakerBase { class InplaceGradientMaker final : public GradientMakerBase {
public: public:
/*!
* Inputs: Y, dY
* Outputs: dX
*/
GRADIENT_MAKER_CTOR(InplaceGradientMaker); GRADIENT_MAKER_CTOR(InplaceGradientMaker);
vector<OperatorDef> MakeDef() override { void CreateGradientDefs() override {
return SingleDef( /*!
def.type() + "Gradient", * Y, dY
* dX
*/
AddGradientDef(
def().type() + "Gradient",
"", "",
vector<string>({O(0), GO(0)}), vector<string>({O(0), GO(0)}),
vector<string>({GI(0)})); vector<string>({GI(0)}));
...@@ -179,6 +184,39 @@ class InplaceGradientMaker final : public GradientMakerBase { ...@@ -179,6 +184,39 @@ class InplaceGradientMaker final : public GradientMakerBase {
} // namespace } // namespace
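As an illustration of the maker API above, a hedged sketch of an operator-specific maker written against it; the "Tanh" names are hypothetical and not part of this header:

class TanhGradientMaker final : public GradientMakerBase {
 public:
  GRADIENT_MAKER_CTOR(TanhGradientMaker);
  void CreateGradientDefs() override {
    // Y, dY -> dX, the same pattern as InplaceGradientMaker.
    AddGradientDef(
        "TanhGradient",
        "",
        vector<string>({O(0), GO(0)}),
        vector<string>({GI(0)}));
  }
};
// Typically paired with: REGISTER_GRADIENT(Tanh, TanhGradientMaker);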
class DRAGON_API GradientTape {
public:
GradientTape() {}
GradientTape(const GraphDef& def) : def_(def) {}
GradientTape(
const vector<OperatorDef*>& op_defs,
const vector<string>& targets,
const vector<string>& grad_targets) {
CreateGradientDefs(op_defs, targets, grad_targets);
}
/*! \brief Create gradient defs */
void CreateGradientDefs(
const vector<OperatorDef*>& op_defs,
const vector<string>& targets,
const vector<string>& grad_targets);
/*! \brief Optimize gradient computations */
void Optimize(const vector<string>& sources = vector<string>());
/*! \brief Return gradient defs */
const GraphDef& def() {
if (optimized_def_.op_size() > 0) {
return optimized_def_;
}
return def_;
}
private:
GraphDef def_;
GraphDef optimized_def_;
};
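A minimal usage sketch for the tape; "forward_ops", "loss" and "weight" are hypothetical, and with an empty grad_targets the target gradient is filled with ones by default (see grad_defaults above):

// Given a hypothetical list of executed forward ops:
// vector<OperatorDef*> forward_ops = ...;
GradientTape tape(forward_ops, /*targets=*/{"loss"}, /*grad_targets=*/{});
tape.Optimize(/*sources=*/{"weight"});   // share grad buffers, but keep "weight_grad"
const GraphDef& backward = tape.def();   // returns the optimized defs once Optimize() has run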
DECLARE_REGISTRY( DECLARE_REGISTRY(
GradientRegistry, GradientRegistry,
GradientMakerBase, GradientMakerBase,
...@@ -191,7 +229,6 @@ DECLARE_REGISTRY( ...@@ -191,7 +229,6 @@ DECLARE_REGISTRY(
const OperatorDef&, const OperatorDef&,
const vector<string>&); const vector<string>&);
// Defined in the operator.cc
#define REGISTER_GRADIENT(name, ...) \ #define REGISTER_GRADIENT(name, ...) \
REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__) REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
...@@ -201,4 +238,4 @@ DECLARE_REGISTRY( ...@@ -201,4 +238,4 @@ DECLARE_REGISTRY(
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_OPERATOR_GRADIENT_H_ #endif // DRAGON_CORE_GRADIENT_H_
#include <regex> #include <regex>
#include "dragon/core/graph.h" #include "dragon/core/graph.h"
#include "dragon/core/graph_gradient.h"
#include "dragon/core/graph_optimizer.h" #include "dragon/core/graph_optimizer.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
...@@ -9,99 +8,78 @@ namespace dragon { ...@@ -9,99 +8,78 @@ namespace dragon {
GraphBase::GraphBase(const GraphDef& def, Workspace* ws) GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
: def_(def), ws_(ws), name_(def.name()), phase_("TEST") { : def_(def), ws_(ws), name_(def.name()), phase_("TEST") {
// Collect arguments // Collect arguments.
for (auto& arg : def_.arg()) { for (auto& arg : def_.arg()) {
CHECK_GT(arg.name().size(), 0); CHECK_GT(arg.name().size(), 0);
CHECK_EQ(args_.count(arg.name()), 0); CHECK_EQ(args_.count(arg.name()), 0);
args_[arg.name()] = &arg; args_[arg.name()] = &arg;
if (arg.name() == "phase") phase_ = arg.s(); if (arg.name() == "phase") phase_ = arg.s();
} }
// Check inputs.
// Collect outputs
Set<string> outputs; Set<string> outputs;
for (const auto& op : def.op()) { for (const auto& op : def.op()) {
for (const auto& input : op.input()) for (const auto& input : op.input())
CHECK(outputs.count(input) || ws_->HasTensor(input)) CHECK(outputs.count(input) || ws_->HasTensor(input))
<< "\nThe input <" << input << "> is not in graph."; << "\nInput " << input << " is not in the graph.";
for (const auto& output : op.output()) { for (const auto& output : op.output()) {
outputs.insert(output); outputs.insert(output);
} }
} }
// Check outputs.
// Check targets for (const auto& output : def.output()) {
Set<string> targets; CHECK(outputs.count(output) || ws_->HasTensor(output))
for (const auto& target : def.output()) { << "\nOutput " << output << " is not in the graph.";
CHECK(outputs.count(target) || ws_->HasTensor(target))
<< "\nThe output <" << target << "> is not in graph.";
targets.insert(target);
}
// Check gradients
for (const auto& grad_info : def.grad_info()) {
const auto& y = grad_info.y();
CHECK_GT(targets.count(y), 0)
<< "\nThe derivative target <" << y << "> is not in outputs.";
for (const auto& x : grad_info.xs()) {
CHECK(outputs.count(x) || ws_->HasTensor(x))
<< "\nThe differentiated input <" << x << "> is not in graph.";
}
} }
} }
bool Graph::Create(const GraphDef& def) { bool Graph::Create(const GraphDef& def) {
this->optimized_def_ = def; // Store for debugging this->optimized_def_ = def;
bool has_device_option = def.has_device_option(); bool has_device_option = def.has_device_option();
for (int i = 0; i < def.op_size(); i++) { for (int i = 0; i < def.op_size(); i++) {
auto op_def(def.op(i)); auto op_def(def.op(i));
LOG(DEBUG) << "Create Operator " << op_def.name() << ": " << op_def.type(); // Inherit device if not provided.
// Inherit device option if necessary
if (!op_def.has_device_option() && has_device_option) { if (!op_def.has_device_option() && has_device_option) {
op_def.mutable_device_option()->CopyFrom(def.device_option()); op_def.mutable_device_option()->CopyFrom(def.device_option());
} }
Argument arg; LOG(DEBUG) << "Create: " << op_def.name() << " [" << op_def.type() << "]";
// For the last operator, enforce the synchronization ops_.push_back(OperatorBase::New(op_def, ws_));
if (i == def.op_size() - 1) { ops_.back()->set_output_aliases(output_aliases_);
arg.set_name("do_sync");
arg.set_i(1);
op_def.add_arg()->CopyFrom(arg);
}
cached_ops_.push_back(NewOperator(op_def, ws_));
cached_ops_.back()->set_output_aliases(output_aliases_);
} }
return true; return true;
} }
Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) { Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
// Apply the optimizations // Apply the optimizations.
GraphDef def_v2(def); GraphDef def_v2(def);
GraphOptimizer graph_optimizer(ws); GraphOptimizer optimizer(ws);
GraphGradientMaker gradient_maker;
Map<string, vec32_t> subgraph_indices; Map<string, vec32_t> subgraph_indices;
int opt = 3; // default: O3 int opt = 1;
if (args().count("optimization")) opt = arg("optimization").i(); if (args().count("optimization")) opt = arg("optimization").i();
if (opt >= 1) def_v2 = graph_optimizer.EliminateUnused(def); if (opt >= 2) optimizer.PlanInplace(def_v2, output_aliases_);
if (opt >= 2) graph_optimizer.PlanInplace(def_v2, output_aliases_);
if (opt >= 3) { if (opt >= 3) {
if (phase() == "TRAIN") { if (phase() == "TRAIN") {
def_v2 = graph_optimizer.PlanCheckpoint(def_v2, subgraph_indices); def_v2 = optimizer.PlanCheckpoint(def_v2, subgraph_indices);
def_v2 = gradient_maker.Optimize(def_v2); if (args().count("grad_sources")) {
GradientTape tape(def_v2);
auto& grad_sources = args_["grad_sources"]->strings();
tape.Optimize({grad_sources.begin(), grad_sources.end()});
def_v2 = tape.def();
}
} else { } else {
def_v2 = graph_optimizer.SimulateGC(def_v2); def_v2 = optimizer.EliminateIntermediates(def_v2);
} }
} }
// Create graph.
// Create
Create(def_v2); Create(def_v2);
// Create subgraphs.
// Recomputation and SubGraph
if (subgraph_indices.size() > 0) { if (subgraph_indices.size() > 0) {
Map<string, vector<OperatorBase*>> subgraph; Map<string, vector<OperatorBase*>> subgraph;
for (const auto& it : subgraph_indices) { for (const auto& it : subgraph_indices) {
subgraph[it.first] = vector<OperatorBase*>(); subgraph[it.first] = vector<OperatorBase*>();
for (auto op_idx : subgraph_indices[it.first]) for (auto op_idx : subgraph_indices[it.first])
subgraph[it.first].push_back(cached_ops_[op_idx]); subgraph[it.first].push_back(ops_[op_idx]);
} }
for (auto* op : cached_ops_) { for (auto* op : ops_) {
op->set_subgraph(subgraph); op->set_subgraph(subgraph);
} }
} }
...@@ -111,27 +89,28 @@ bool Graph::Run(int stream, const string& include, const string& exclude) { ...@@ -111,27 +89,28 @@ bool Graph::Run(int stream, const string& include, const string& exclude) {
unique_ptr<std::regex> regex_incl, regex_excl; unique_ptr<std::regex> regex_incl, regex_excl;
if (!include.empty()) regex_incl.reset(new std::regex(include)); if (!include.empty()) regex_incl.reset(new std::regex(include));
if (!exclude.empty()) regex_excl.reset(new std::regex(exclude)); if (!exclude.empty()) regex_excl.reset(new std::regex(exclude));
LOG(DEBUG) << "Run Graph: " << name(); LOG(DEBUG) << "Run: " << name();
for (auto* op : cached_ops_) { for (auto* op : ops_) {
if (regex_incl && !regex_match(op->type(), *regex_incl)) continue; if (regex_incl && !regex_match(op->type(), *regex_incl)) continue;
if (regex_excl && regex_match(op->type(), *regex_excl)) continue; if (regex_excl && regex_match(op->type(), *regex_excl)) continue;
op->SwitchToPhase(phase()); op->SwitchToPhase(phase());
LOG(DEBUG) << "Run Op: " << op->name(); LOG(DEBUG) << "Run: " << op->name();
op->Run(stream); op->Run(stream);
LOG(DEBUG) << "Finish Op: " << op->name(); LOG(DEBUG) << "Finish: " << op->name();
} }
LOG(DEBUG) << "Finish: " << name();
return true; return true;
} }
GraphBase* NewGraph(const GraphDef& def, Workspace* ws) { GraphBase* GraphBase::New(const GraphDef& def, Workspace* ws) {
if (!def.has_graph_type() || def.graph_type().empty()) { if (!def.has_type() || def.type().empty()) {
return new Graph(def, ws); // Sequential scheduler // Sequential scheduler.
return new Graph(def, ws);
} }
return GraphRegistry()->Create(def.graph_type(), def, ws); return GraphRegistry()->Create(def.type(), def, ws);
} }
/* Graph Registry */ /* Graph Registry */
DEFINE_REGISTRY(GraphRegistry, GraphBase, const GraphDef&, Workspace*); DEFINE_REGISTRY(GraphRegistry, GraphBase, const GraphDef&, Workspace*);
} // namespace dragon } // namespace dragon
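Putting the pieces together, a hedged sketch of building and running a graph through the factory above; the workspace "ws" and the operator defs are assumed to exist elsewhere:

GraphDef def;
def.set_name("demo");
auto* arg = def.add_arg();
arg->set_name("optimization");
arg->set_i(3);  // >= 2 plans in-place aliases, >= 3 adds checkpointing / buffer reuse
// ... append the forward (and backward) operators to def ...
GraphBase* graph = GraphBase::New(def, ws);  // a sequential Graph when def.type() is empty
graph->Run(/*stream=*/0, /*include=*/"", /*exclude=*/"");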
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
namespace dragon { namespace dragon {
class Workspace;
/*! /*!
* \brief The base graph class. * \brief The base graph class.
*/ */
...@@ -29,6 +31,9 @@ class DRAGON_API GraphBase { ...@@ -29,6 +31,9 @@ class DRAGON_API GraphBase {
/*! \brief Destructor */ /*! \brief Destructor */
virtual ~GraphBase() {} virtual ~GraphBase() {}
/*! \brief Create a new graph */
static GraphBase* New(const GraphDef& def, Workspace* ws);
/*! \brief Create graph in the workspace */ /*! \brief Create graph in the workspace */
virtual bool Create(const GraphDef& def) = 0; virtual bool Create(const GraphDef& def) = 0;
...@@ -102,8 +107,8 @@ class Graph : public GraphBase { ...@@ -102,8 +107,8 @@ class Graph : public GraphBase {
/*! \brief Destructor */ /*! \brief Destructor */
virtual ~Graph() { virtual ~Graph() {
for (auto* cached_op : cached_ops_) { for (auto* op : ops_) {
delete cached_op; delete op;
} }
} }
...@@ -117,16 +122,13 @@ class Graph : public GraphBase { ...@@ -117,16 +122,13 @@ class Graph : public GraphBase {
const string& exclude = "") override; const string& exclude = "") override;
protected: protected:
/*! \brief The cached operators */ /*! \brief The created operators */
vector<OperatorBase*> cached_ops_; vector<OperatorBase*> ops_;
/*! \brief The candidate output aliases */ /*! \brief The output aliases */
Map<string, Set<string>> output_aliases_; Map<string, Set<string>> output_aliases_;
}; };
/*! \brief Create a graph from the raw def */
GraphBase* NewGraph(const GraphDef&, Workspace*);
/* Macros */ /* Macros */
DECLARE_REGISTRY(GraphRegistry, GraphBase, const GraphDef&, Workspace*); DECLARE_REGISTRY(GraphRegistry, GraphBase, const GraphDef&, Workspace*);
......
#include "dragon/core/graph_gradient.h"
#include "dragon/core/operator.h"
namespace dragon {
bool GraphGradientMaker::CheckGrad(
const OperatorDef& op,
const Set<string>& targets,
vector<pair<string, int>>& gen_grads) {
if (NoGradientRegistry()->Has(op.type())) {
return true;
}
bool maybe_skip = false;
for (int i = 0; i < op.output_size(); ++i) {
const auto& out = op.output(i);
if (!inputs_to_grads_.count(out)) {
maybe_skip = true;
if (targets.count(out)) {
gen_grads.push_back({out, i});
inputs_to_grads_[out] = out + "_grad";
}
}
}
return maybe_skip && gen_grads.empty() && op.output_size() == 1;
}
void GraphGradientMaker::Make(
const vector<OperatorDef*>& ops,
const vector<string>& targets,
const vector<string>& input_grads,
GraphDef& graph) {
Set<string> split_grads, targets_v2;
Map<string, int> inputs_count, grads_count;
// Play the forward ops to count the input references
for (auto* op : ops) {
if (NoGradientRegistry()->Has(op->type())) continue;
for (const auto& input : op->input()) {
bool input_in_outputs = false;
for (auto& output : op->output())
if (output == input) {
input_in_outputs = true;
break;
}
// Avoid counting the duplicate input (i.e., the in-place output)
if (!input_in_outputs) inputs_count[input]++;
}
}
// Set the gradient of targets
for (int i = 0; i < targets.size(); ++i) {
if (i < input_grads.size()) {
inputs_to_grads_[targets[i]] = input_grads[i];
}
targets_v2.insert(targets[i]);
}
// Play the ops in reverse to make the gradient defs
for (int op_idx = (int)ops.size() - 1; op_idx >= 0; --op_idx) {
const auto& op = *ops[op_idx];
// Generate def by registered gradient maker
vector<pair<string, int>> gen_grads;
vector<string> grad_outputs;
bool is_skip = CheckGrad(op, targets_v2, gen_grads);
for (const auto& out : op.output()) {
string grad_out = "";
const auto& it = inputs_to_grads_.find(out);
if (it != inputs_to_grads_.end()) grad_out = it->second;
grad_outputs.push_back(grad_out);
}
auto pack = MakeGradientForOp(op, grad_outputs);
// Split and gather gradients for inputs that are used multiple times
vector<OperatorDef> gather_ops;
for (auto& grad_def : pack.grad_defs) {
if (!grad_def.has_name()) {
grad_def.set_name(GetOperatorName());
}
for (int i = 0; i < grad_def.output_size(); ++i) {
const auto& grad_name = grad_def.output(i);
int original_index = -1;
for (int j = 0; j < pack.grad_inputs.size(); ++j) {
if (grad_name == pack.grad_inputs[j]) {
original_index = j;
}
}
if (original_index == -1) continue;
bool output_in_inputs = false;
for (const auto& name : grad_def.input()) {
if (grad_name == name) {
output_in_inputs = true;
break;
}
}
if (output_in_inputs) continue;
// Detect a split branch
const auto& original_name = op.input(original_index);
if (inputs_count[original_name] > 1) {
auto grad_name_v2 =
grad_name + "_autosplit_" + str::to(grads_count[grad_name]++);
if (!is_skip) split_grads.insert(grad_name_v2);
if (grads_count[grad_name] == inputs_count[original_name]) {
auto gather_op = MakeOperatorDef(
"GradientGather",
GetOperatorName(),
vector<string>({}),
vector<string>({grad_name}));
if (grad_def.has_device_option()) {
gather_op.mutable_device_option()->CopyFrom(
grad_def.device_option());
}
for (int j = 0; j < grads_count[grad_name]; j++) {
auto name = grad_name + "_autosplit_" + str::to(j);
if (split_grads.count(name)) gather_op.add_input(name);
}
gather_ops.push_back(gather_op);
}
*grad_def.mutable_output(i) = grad_name_v2;
}
}
}
// Add gradient ops
if (!is_skip) {
for (int i = 0; i < op.input_size(); ++i) {
inputs_to_grads_[op.input(i)] = pack.grad_inputs[i];
}
// Add ``GradientGenerateOp``
if (gen_grads.size() > 0) {
vector<string> inputs, outputs;
Argument arg_defaults;
arg_defaults.set_name("defaults");
for (auto& gen_grad : gen_grads) {
inputs.push_back(gen_grad.first);
outputs.emplace_back(gen_grad.first + "_grad");
arg_defaults.add_floats(pack.defaults[gen_grad.second]);
}
auto gen_op = MakeOperatorDef(
"GradientGenerate",
GetOperatorName(),
inputs,
outputs,
vector<Argument>({arg_defaults}));
if (op.has_device_option()) {
gen_op.mutable_device_option()->CopyFrom(op.device_option());
}
graph.add_op()->CopyFrom(gen_op);
}
// Add ``GradientOp``
for (const auto& grad_def : pack.grad_defs) {
graph.add_op()->CopyFrom(grad_def);
}
}
// Add ``GradientGatherOp``
for (const auto& gather_op : gather_ops) {
graph.add_op()->CopyFrom(gather_op);
}
}
}
GraphDef GraphGradientMaker::Optimize(const GraphDef& graph) {
Set<int> invalid_ops;
Map<string, int> ref_count;
Map<string, pair<int, string>> gather_map;
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = graph.op(op_idx);
if (!str::find(op.type(), "Gradient")) continue;
// Flag the gathering gradients
if (op.type() == "GradientGather") {
invalid_ops.insert(op_idx);
if (empty_grads_.count(op.output(0))) {
for (const auto& input : op.input()) {
empty_grads_.insert(input);
}
continue;
} else {
string first_input;
for (const auto& input : op.input()) {
if (!input.empty()) {
if (first_input.empty()) first_input = input;
gather_map[input] = {op_idx, first_input};
}
}
}
}
// Count the references to detect leaves
for (const auto& input : op.input()) {
if (str::endswith(input, "_grad")) {
ref_count[input] += 1;
}
}
}
// Decompose the <GradientGather> into <GradientAdd>.
// This trick accumulates each split into the target right after it is computed,
// which helps to reduce the total number of buffers.
auto graph_v2(graph);
graph_v2.clear_op();
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
if (invalid_ops.count(op_idx)) continue;
const auto& op = graph.op(op_idx);
graph_v2.add_op()->CopyFrom(op);
if (!str::find(op.type(), "Gradient")) continue;
for (const auto& output : op.output()) {
const auto& find_iter = gather_map.find(output);
if (find_iter != gather_map.end()) {
const auto& gather_op = graph.op(find_iter->second.first);
auto add_op(gather_op);
add_op.clear_input();
if (output != find_iter->second.second) {
add_op.set_type("GradientAdd");
// Make it in-place to avoid allocating a new buffer
add_op.add_input(gather_op.output(0));
const auto& ref_iter = ref_count.find(gather_op.output(0));
if (ref_iter != ref_count.end()) ref_iter->second++;
}
add_op.add_input(output);
graph_v2.add_op()->CopyFrom(add_op);
}
}
}
// Prepare the pool
int buffer_idx = 0;
std::deque<string> pool;
Map<string, string> grad_to_buffer;
auto get_buffer = [&]() mutable {
if (pool.empty()) {
return "/share/buffer/grad:" + str::to(buffer_idx++);
} else {
/*!
 * LIFO is usually more memory-efficient than FIFO,
 * because the larger gradients are brought out later.
 *
 * The memory distribution also turns out to be uniform
 * if the early temporary tensors are selected first.
 */
auto buffer = pool.back();
pool.pop_back();
return buffer;
}
};
for (int op_idx = 0; op_idx < graph_v2.op_size(); ++op_idx) {
auto* op = graph_v2.mutable_op(op_idx);
// Ignore the non-gradient ops
if (!str::find(op->type(), "Gradient")) continue;
// Check if output is an alias of input
vec32_t inplace_flags;
for (int i = 0; i < op->output_size(); ++i) {
int flag = -1;
for (int j = 0; j < op->input_size(); ++j)
if (op->output(i) == op->input(j)) {
flag = j;
break;
}
inplace_flags.emplace_back(flag);
}
// Besides, we need to collect the dead buffers
// and reuse them once the current operator is done
vector<string> dead_buffers;
// Rewrite input gradients
for (int i = 0; i < op->input_size(); ++i) {
const string& in = op->input(i);
if (ref_count.count(in) > 0) {
ref_count[in] -= 1; // Decref
if (grad_to_buffer.count(in) == 0) continue;
string in_v2 = grad_to_buffer[in];
if (ref_count[in] == 0) {
dead_buffers.emplace_back(in_v2);
}
*op->mutable_input(i) = in_v2;
}
}
// Rewrite output gradients
for (int i = 0; i < op->output_size(); ++i) {
if (str::startswith(op->type(), "Python")) continue;
const string& out = op->output(i);
if (out.empty() || str::startswith(out, "/share/buffer")) continue;
if (empty_grads_.count(out) > 0) {
*op->mutable_output(i) = "";
continue;
}
// Protection for leaves
if (ref_count.count(out) == 0) continue;
// Protection for sources and leaves
if (retained_grads_.count(out) > 0) continue;
string out_v2 = out;
if (inplace_flags[i] >= 0) {
out_v2 = op->input(inplace_flags[i]);
} else {
grad_to_buffer[out] = out_v2 = get_buffer();
}
*op->mutable_output(i) = out_v2;
}
// Update the pool
for (auto& buffer : dead_buffers) {
pool.emplace_back(buffer);
}
}
return graph_v2;
}
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_GRAPH_GRADIENT_H_
#define DRAGON_CORE_GRAPH_GRADIENT_H_
#include "dragon/core/common.h"
namespace dragon {
class DRAGON_API GraphGradientMaker {
public:
/*! \brief Generate graph from the executed ops */
void Make(
const vector<OperatorDef*>& ops,
const vector<string>& targets,
const vector<string>& input_grads,
GraphDef& graph);
/*! \brief Eliminate the unused gradients and share the output buffers */
GraphDef Optimize(const GraphDef& graph);
/*! \brief Add an empty gradient */
void add_empty_grad(const string& name) {
empty_grads_.insert(name);
}
/*! \brief Add a retained gradient */
void add_retained_grad(const string& name) {
retained_grads_.insert(name);
}
/*! \brief Set the prefix of backward op name */
void set_op_prefix(const string& prefix) {
op_prefix_ = prefix;
}
private:
/*! \brief Check the missing grads */
bool CheckGrad(
const OperatorDef& op,
const Set<string>& targets,
vector<pair<string, int>>& gen_grads);
/*! \brief Return a dummy operator name */
string GetOperatorName() {
if (op_prefix_.empty()) return "GradientOp";
return op_prefix_ + str::to(op_idx_++);
}
/*! \brief The mapping from input to grad */
Map<string, string> inputs_to_grads_;
/*! \brief The gradients that should be retained */
Set<string> retained_grads_;
/*! \brief The gradients that should be set to empty */
Set<string> empty_grads_;
/*! \brief The prefix of op name */
string op_prefix_;
/*! \brief The counter of op name */
int64_t op_idx_ = 0;
};
} // namespace dragon
#endif // DRAGON_CORE_GRAPH_GRADIENT_H_
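A minimal sketch of driving this maker; the forward ops and tensor names are illustrative, and the surrounding Dragon headers are assumed:

GraphDef backward;
GraphGradientMaker maker;
maker.add_retained_grad("x_grad");  // keep this gradient out of the shared buffer pool
maker.Make(forward_ops, /*targets=*/{"y"}, /*input_grads=*/{"dy"}, backward);
backward = maker.Optimize(backward);  // decompose gathers and share grad buffers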
#include "dragon/core/graph_optimizer.h" #include "dragon/core/graph_optimizer.h"
#include "dragon/core/graph_gradient.h"
#include "dragon/core/operator_schema.h" #include "dragon/core/operator_schema.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
...@@ -9,11 +8,11 @@ namespace dragon { ...@@ -9,11 +8,11 @@ namespace dragon {
void GraphOptimizer::BuildDAG(const GraphDef& graph) { void GraphOptimizer::BuildDAG(const GraphDef& graph) {
nodes_.clear(); nodes_.clear();
reference_count_.clear(); inputs_count_.clear();
for (int i = 0; i < graph.op_size(); ++i) { for (int i = 0; i < graph.op_size(); ++i) {
const auto& op = graph.op(i); const auto& op = graph.op(i);
for (const auto& in : op.input()) { for (const auto& in : op.input()) {
reference_count_[in] += 1; inputs_count_[in] += 1;
} }
for (const auto& out : op.output()) { for (const auto& out : op.output()) {
if (out.empty()) continue; if (out.empty()) continue;
...@@ -32,95 +31,24 @@ void GraphOptimizer::BuildDAG(const GraphDef& graph) { ...@@ -32,95 +31,24 @@ void GraphOptimizer::BuildDAG(const GraphDef& graph) {
} }
} }
GraphDef GraphOptimizer::EliminateUnused(const GraphDef& graph) {
// Initialization
BuildDAG(graph);
used_.clear();
// Eliminate the unused nodes
for (const auto& out : graph.output()) {
EliminateUnusedNode(out);
}
for (const auto& grad_info : graph.grad_info()) {
const auto grad_y = grad_info.y() + "_grad";
for (const auto& x : grad_info.xs()) {
visited_.clear();
EliminateUnusedNode(grad_y, x + "_grad");
}
}
// Select the used operators
set<int> selected_op_indices;
for (auto it : used_) {
if (nodes_[it.first].op_idx == -1) continue;
selected_op_indices.insert(nodes_[it.first].op_idx);
}
// Prepare the registered placeholders
Set<string> outputs;
for (const auto& name : ws_->tensors()) {
outputs.insert(name);
}
// Rewrite graph
GraphDef graph_v2(graph);
graph_v2.clear_op();
for (auto op_idx : selected_op_indices) {
const auto& op = graph.op(op_idx);
auto* op_v2 = graph_v2.add_op();
op_v2->CopyFrom(op);
// Rewrite inputs
for (int i = 0; i < op.input_size(); ++i) {
const auto& in = op.input(i);
if (!used_[in] || outputs.count(in) == 0) {
*op_v2->mutable_input(i) = "";
}
}
// Rewrite outputs
for (int i = 0; i < op.output_size(); ++i) {
const auto& out = op.output(i);
if (!used_[out]) {
*op_v2->mutable_output(i) = "";
} else {
outputs.insert(out);
}
}
// Rewrite hand-crafted cases
if (op.type() == "AffineGradient") {
if (op_v2->output(1).empty()) *op_v2->mutable_input(0) = "";
} else if (op.type() == "MulGradient") {
if (op_v2->output(0).empty()) *op_v2->mutable_input(1) = "";
if (op_v2->output(1).empty()) *op_v2->mutable_input(0) = "";
} else if (op.type() == "DivGradient") {
if (op_v2->output(1).empty()) {
*op_v2->mutable_input(0) = "";
if (op_v2->output(0).empty()) *op_v2->mutable_input(1) = "";
}
}
}
return graph_v2;
}
void GraphOptimizer::PlanInplace( void GraphOptimizer::PlanInplace(
const GraphDef& graph, const GraphDef& graph,
Map<string, Set<string>>& output_aliases) { Map<string, Set<string>>& output_aliases) {
// Initialization // Initialization.
BuildDAG(graph); BuildDAG(graph);
// Generate aliases map to apply in-place.
// Generate aliases map to apply in-place for (const auto& iter : inputs_count_) {
for (const auto& iter : reference_count_) { if (iter.second > 1 || iter.first.empty()) continue;
const auto& in = iter.first; const auto& input = iter.first;
if (iter.second == 1 && !in.empty() && nodes_[in].childs.size() > 0) { const auto& input_node = nodes_[input];
const auto& op = nodes_[nodes_[in].childs[0]].op_def; if (input_node.childs.empty() || input_node.parents.empty()) continue;
const auto* schema = OpSchemaRegistry::Schema(op.type()); const auto& op = nodes_[input_node.childs[0]].op_def;
for (int i = 0; i < op.input_size(); ++i) { const auto* schema = OpSchemaRegistry::Schema(op.type());
if (op.input(i) == in) { for (int i = 0; i < op.input_size(); ++i) {
for (int j = 0; j < op.output_size(); ++j) { if (op.input(i) != input) continue;
if (schema->CheckInplace(i, j)) { for (int j = 0; j < op.output_size(); ++j) {
output_aliases[op.output(j)].insert(in); if (!schema->CheckInplace(i, j)) continue;
} output_aliases[op.output(j)].insert(input);
}
}
} }
} }
} }
...@@ -134,7 +62,7 @@ GraphDef GraphOptimizer::PlanCheckpoint( ...@@ -134,7 +62,7 @@ GraphDef GraphOptimizer::PlanCheckpoint(
Map<string, string> rename_map; Map<string, string> rename_map;
Map<string, int> versions; Map<string, int> versions;
// Check the mirror stage setting // Check the mirror stage setting.
for (const auto& op : graph.op()) { for (const auto& op : graph.op()) {
if (str::find(op.type(), "Gradient")) continue; if (str::find(op.type(), "Gradient")) continue;
bool mirror_stage = false; bool mirror_stage = false;
...@@ -144,12 +72,12 @@ GraphDef GraphOptimizer::PlanCheckpoint( ...@@ -144,12 +72,12 @@ GraphDef GraphOptimizer::PlanCheckpoint(
} }
} }
if (mirror_stage) { if (mirror_stage) {
// We only assume X(0) can be recomputed // We only assume X(0) can be recomputed.
rename_map[op.input(0)] = "placeholder"; rename_map[op.input(0)] = "placeholder";
} }
} }
// Allocate the temporal buffers // Allocate the temporal buffers.
string v2_name, version_name; string v2_name, version_name;
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) { for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = graph.op(op_idx); const auto& op = graph.op(op_idx);
...@@ -173,7 +101,7 @@ GraphDef GraphOptimizer::PlanCheckpoint( ...@@ -173,7 +101,7 @@ GraphDef GraphOptimizer::PlanCheckpoint(
continue; continue;
} }
for (int j = 0; j < GRAPH_TEMPORAL_OUTPUT_MAX_SIZE; ++j) { for (int j = 0; j < GRAPH_TEMPORAL_OUTPUT_MAX_SIZE; ++j) {
v2_name = "/share/buffer/symbol:" + str::to(j); v2_name = "shared/buffer/output:" + str::to(j);
for (const auto& buffer : used_buffers) for (const auto& buffer : used_buffers)
if (str::find(buffer, v2_name)) { if (str::find(buffer, v2_name)) {
v2_name.clear(); v2_name.clear();
...@@ -221,18 +149,18 @@ GraphDef GraphOptimizer::PlanCheckpoint( ...@@ -221,18 +149,18 @@ GraphDef GraphOptimizer::PlanCheckpoint(
return graph_v2; return graph_v2;
} }
GraphDef GraphOptimizer::SimulateGC(const GraphDef& graph) { GraphDef GraphOptimizer::EliminateIntermediates(const GraphDef& graph) {
Set<string> blacklist = {""}; Set<string> required_outputs;
Map<string, int> ref_count; Map<string, int> inputs_count;
Map<string, string> rename_map; Map<string, string> outputs_to_buffers;
static Set<string> star_ops = {"Shape"}; static Set<string> skip_ops = {"Shape"};
// Prepare the pool // Prepare pool.
int buffer_idx = 0; int buffer_idx = 0;
std::deque<string> pool; std::deque<string> pool;
auto get_buffer = [&]() mutable { auto get_buffer = [&]() mutable {
if (pool.empty()) { if (pool.empty()) {
return "/share/buffer/output:" + str::to(buffer_idx++); return "shared/buffer/output:" + str::to(++buffer_idx);
} else { } else {
auto buffer = pool.back(); auto buffer = pool.back();
pool.pop_back(); pool.pop_back();
...@@ -240,56 +168,58 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& graph) { ...@@ -240,56 +168,58 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& graph) {
} }
}; };
// Count the references // Count inputs.
for (const auto& op : graph.op()) { for (const auto& op : graph.op()) {
for (const auto& in : op.input()) { for (const auto& input : op.input()) {
ref_count[in] += 1; inputs_count[input] += 1;
} }
} }
// Preserve the graph outputs // Initialize the required outputs before optimization.
for (auto& out : graph.output()) { for (const auto& output : graph.output()) {
blacklist.insert(out); required_outputs.insert(output);
} }
// Rewrite the inputs and outputs // Rewrite the inputs and outputs.
auto graph_v2(graph); auto graph_v2(graph);
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) { for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = graph.op(op_idx); const auto& op = graph.op(op_idx);
auto* op_v2 = graph_v2.mutable_op(op_idx);
// Ignore the init ops
if (op.input_size() == 0) continue; if (op.input_size() == 0) continue;
// We need to collect the dead buffers. auto* op_v2 = graph_v2.mutable_op(op_idx);
// Reuse them when current operator is done. // Check output aliases.
vec32_t output_aliases(op.output_size(), -1);
for (int i = 0; i < op.output_size(); ++i) {
for (int j = 0; j < op.input_size(); ++j) {
if (op.output(i) != op.input(j)) continue;
output_aliases[i] = j;
break;
}
}
// Rewrite inputs.
vector<string> dead_buffers; vector<string> dead_buffers;
// Rewrite inputs
for (int i = 0; i < op.input_size(); ++i) { for (int i = 0; i < op.input_size(); ++i) {
const auto& name = op.input(i); const auto& input = op.input(i);
if (rename_map.count(name)) { const auto& count_iter = inputs_count.find(input);
*op_v2->mutable_input(i) = rename_map[name]; count_iter->second--;
} const auto& buffer_iter = outputs_to_buffers.find(input);
ref_count[name]--; if (buffer_iter == outputs_to_buffers.end()) continue;
if (ref_count[name] == 0 && if (count_iter->second == 0) {
str::startswith(op_v2->input(i), "/share/buffer/output:")) { dead_buffers.emplace_back(buffer_iter->second);
dead_buffers.push_back(op_v2->input(i));
} }
op_v2->set_input(i, buffer_iter->second);
} }
// Rewrite outputs if (skip_ops.count(op.type())) continue;
if (!star_ops.count(op.type())) { // Rewrite outputs.
for (int i = 0; i < op.output_size(); ++i) { for (int i = 0; i < op.output_size(); ++i) {
const auto& name = op.output(i); const auto& output = op.output(i);
bool inplace_flag = false; if (output.empty() || required_outputs.count(output) > 0) continue;
if (blacklist.count(name)) continue; if (output_aliases[i] >= 0) {
for (const auto& input : op.input()) op_v2->set_output(i, op_v2->input(output_aliases[i]));
if (name == input) inplace_flag = true; } else {
if (inplace_flag) { *op_v2->mutable_output(i) = outputs_to_buffers[output] = get_buffer();
*op_v2->mutable_output(i) = op_v2->input(i);
} else {
rename_map[name] = *op_v2->mutable_output(i) = get_buffer();
}
} }
} }
// Update the pool // Update pool.
for (auto& buffer : dead_buffers) { for (auto& buffer : dead_buffers) {
pool.emplace_back(buffer); pool.emplace_back(buffer);
} }
...@@ -297,36 +227,4 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& graph) { ...@@ -297,36 +227,4 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& graph) {
return graph_v2; return graph_v2;
} }
void GraphOptimizer::EliminateUnusedNode(
const string& source,
const string& sink) {
if (visited_.count(source)) return;
visited_[source] = false;
for (const auto& next : nodes_[source].childs) {
if (next == sink) {
visited_[next] = used_[next] = true;
visited_[source] = used_[source] = true;
return;
}
EliminateUnusedNode(next, sink);
if (visited_[next]) {
visited_[source] = used_[source] = true;
}
}
}
void GraphOptimizer::EliminateUnusedNode(const string& sink) {
std::queue<const string*> q;
q.push(&sink);
while (!q.empty()) {
const auto& source = *q.front();
q.pop();
used_[source] = true;
for (const auto& last : nodes_[source].parents) {
if (used_.count(last)) continue;
q.push(&last);
}
}
}
} // namespace dragon } // namespace dragon
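A hedged sketch of consuming the in-place plan produced by PlanInplace above; the graph def, workspace and op names are illustrative:

Map<string, Set<string>> output_aliases;
GraphOptimizer optimizer(ws);
optimizer.PlanInplace(graph_def, output_aliases);
// For an op "Relu(x) -> y" whose schema passes CheckInplace(0, 0), and whose input "x"
// is produced by an earlier op and consumed only once, output_aliases["y"] now contains "x".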