Commit f431756f by Ting PAN

Update with the new frontend API

Summary:
The new frontend unifies the two execution modes and starts from a single
tensor class. In addition, it routes operator execution through a common
path that works for both dragon and torch. A minimal usage sketch is given
below.
1 parent 6bfe3e73
Showing with 1652 additions and 1568 deletions
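To make the unified frontend concrete, here is a minimal sketch of the common execution path, using the `dragon.constant`, `dragon.math.add`, `dragon.Tensor.numpy` and `dragon.graph_mode` entry points documented elsewhere in this diff. The exact signatures are an assumption; treat the snippet as an illustration rather than part of the commit.

```python
import dragon

# Eager execution (default): operators on dragon.Tensor run immediately.
a = dragon.constant([1., 2., 3.])
b = dragon.constant([4., 5., 6.])
c = dragon.math.add([a, b])  # dispatched through the common operator path
print(c.numpy())             # expected: [5. 7. 9.]

# Graph execution: the same tensor class and the same operator call;
# only the execution mode changes how the operator is emitted.
with dragon.graph_mode():
    d = dragon.math.add([a, b])
```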
......@@ -16,10 +16,8 @@ from __future__ import print_function
import numpy
from dragon.core.autograph.tensor import TensorRef
from dragon.core.eager import context as eager_context
from dragon.core.framework import context
from dragon.core.util import logging
from dragon.core.autograph import context
from dragon.core.framework.tensor import Tensor
from dragon.vm.caffe.core.proto import caffe_pb2
......@@ -36,20 +34,10 @@ class Layer(object):
"""
self._proto = layer_param
self._name = layer_param.name
self._arguments, self.arguments = {'name': 'output'}, {}
# Store the inputs, outputs and trainable parameters.
self._bottom, self._top, self._blobs = [], [], []
for blob in layer_param.bottom:
self._bottom.append(blob)
for blob in layer_param.top:
self._top.append(blob)
# Store the loss weight to apply gradients.
self._loss_weight = layer_param.loss_weight \
if len(layer_param.loss_weight) > 0 else None
# Optional mirror stage argument for memory optimization.
if layer_param.HasField('mirror_stage'):
self._arguments['mirror_stage'] = layer_param.mirror_stage
self._bottom_names = [name for name in layer_param.bottom]
self._top_names = [name for name in layer_param.top]
self._blobs = []
self._call_layer = None
@property
def blobs(self):
......@@ -59,36 +47,37 @@ class Layer(object):
@property
def bottom(self):
"""Return the bottom names."""
return self._bottom
@property
def loss_weight(self):
"""Return the loss weight."""
return self._loss_weight
return self._bottom_names
@property
def name(self):
"""Return the layer name."""
return self._name
return self._proto.name
@property
def top(self):
"""Return the top names."""
return self._top
return self._top_names
def add_blob(self, value=None, filler=None, no_grad=False):
def add_blob(self, shape, filler, requires_grad=True):
"""Add a blob into this layer."""
# Set the name for reference explicitly.
data_name = context.get_name_scope() + 'param:{}'.format(len(self._blobs))
data, diff = TensorRef(data_name), TensorRef(data_name + '_grad')
if filler is not None:
data._register_as(**filler)
data = Tensor(shape, name='blob%d' % (len(self._blobs) + 1))
if filler.type == 'constant':
data.fill(filler.value)
elif filler.type == 'gaussian':
data.normal(filler.mean, filler.std)
elif filler.type == 'uniform':
data.uniform(filler.min, filler.max)
elif filler.type == 'xavier':
norm_modes = {0: 'fan_in', 1: 'fan_out', 2: 'fan_avg'}
data.glorot_uniform(norm_modes[filler.variance_norm])
elif filler.type == 'msra':
norm_modes = {0: 'fan_in', 1: 'fan_out', 2: 'fan_avg'}
data.glorot_normal(norm_modes[filler.variance_norm])
else:
# Register a constant filler by default.
value = value if value else 0
data.constant(value=value)
# Append to the blobs.
self._blobs.append({'data': data, 'diff': None if no_grad else diff})
raise ValueError('Unknown filler type: ' + filler.type)
data.requires_grad = requires_grad
self._blobs.append({'data': data, 'diff': None})
def from_proto(self, proto):
"""Deserialize from the proto.
......@@ -110,16 +99,14 @@ class Layer(object):
raise ValueError('Neither <data> nor <double_data> in blob proto.')
if len(blob_proto.shape.dim) > 0:
value = value.reshape([dim for dim in blob_proto.shape.dim])
self._blobs[i]['data'].set_value(value)
logging.info('Blob({}/param:{}) loaded, shape: {}, size: {}'
.format(self._name, i, value.shape, value.size))
self._blobs[i]['data']._impl.FromNumpy(value, False)
def setup(self, bottom):
"""Setup the layer."""
self.arguments = dict(self.arguments, **self._arguments)
bottom = bottom[0] if len(bottom) == 1 else bottom
with eager_context.graph_mode():
return self.__call__(bottom)
with context.graph_mode():
call_layer = self._call_layer or self
return call_layer.__call__(bottom)
def to_proto(self):
"""Serialize to the proto.
......@@ -133,7 +120,7 @@ class Layer(object):
proto = caffe_pb2.LayerParameter()
proto.CopyFrom(self._proto)
for blob in self._blobs:
value = blob['data'].get_value()
value = blob['data'].numpy()
if str(value.dtype) == 'float32':
blob_proto = caffe_pb2.BlobProto(
data=value.flatten(),
......@@ -147,21 +134,6 @@ class Layer(object):
proto.blobs.extend([blob_proto])
return proto
@staticmethod
def get_filler(proto, filler_name):
"""Return the filler from proto."""
if proto.HasField(filler_name):
filler = getattr(proto, filler_name)
return {
'type': filler.type.lower(),
'value': filler.value,
'low': filler.min,
'high': filler.max,
'mean': filler.mean,
'std': filler.std,
}
return None
def __call__(self, bottom):
"""Define the forward pipeline."""
raise NotImplementedError
......@@ -48,7 +48,5 @@ from dragon.vm.caffe.core.layers.vision import Convolution
from dragon.vm.caffe.core.layers.vision import Deconvolution
from dragon.vm.caffe.core.layers.vision import LRN
from dragon.vm.caffe.core.layers.vision import Pooling
from dragon.vm.caffe.core.layers.vision import ROIAlign
from dragon.vm.caffe.core.layers.vision import ROIPooling
__all__ = [_s for _s in dir() if not _s.startswith('_')]
......@@ -23,18 +23,18 @@ from dragon.vm.caffe.core.layer import Layer
class _DataPlugin(object):
"""Embedded plugin for **Data** layer."""
"""Embedded plugin for data layer."""
def setup(self, inputs, outputs):
kwargs = eval(self.kwargs_str)
self.iterator = vision.DataIterator(
dataset=KPLRecordDataset, **kwargs)
default_ws = workspace.get_workspace()
self.outputs = [default_ws.get_tensor(output) for output in outputs]
self.iterator = vision.DataIterator(dataset=KPLRecordDataset, **kwargs)
def forward(self, inputs, outputs):
blobs = self.iterator.next()
current_ws = workspace.get_workspace()
for i, blob in enumerate(blobs):
current_ws.feed_tensor(outputs[i], blob)
self.outputs[i].FromNumpy(blob)
class Data(Layer):
......@@ -118,8 +118,8 @@ class Data(Layer):
'num_outputs': 2,
}
data, label = framework_ops.python_plugin([], **args)
data.shape = (self.data_args['batch_size'],
data._shape = (self.data_args['batch_size'],
None, None, len(self.norm_args['mean']))
label.shape = (self.data_args['batch_size'], None)
label._shape = (self.data_args['batch_size'], None)
data = array_ops.channel_normalize(data, **self.norm_args)
return data, label
......@@ -51,16 +51,17 @@ class EuclideanLoss(Layer):
reduction = 'mean'
else:
reduction = norm_dict[param.normalization]
self.arguments = {'reduction': reduction}
self.call_args = {'reduction': reduction}
self.loss_weight = (layer_param.loss_weight or [1])[0]
def __call__(self, bottom):
loss = loss_ops.l2_loss(bottom, **self.arguments)
loss = loss_ops.l2_loss(bottom, **self.call_args)
loss_weight = 1. if self.loss_weight is None else self.loss_weight
return loss * (loss_weight * 0.5)
class SigmoidCrossEntropyLoss(Layer):
r"""Compute the sigmoid cross entropy with contiguous targets.
"""Compute the loss of sigmoid cross entropy.
Examples:
......@@ -88,11 +89,12 @@ class SigmoidCrossEntropyLoss(Layer):
reduction = 'batch_mean'
else:
reduction = norm_dict[param.normalization]
self.arguments = {'reduction': reduction}
self.call_args = {'reduction': reduction}
self.loss_weight = (layer_param.loss_weight or [1])[0]
def __call__(self, bottom):
loss = loss_ops.sigmoid_cross_entropy(bottom, **self.arguments)
if self.loss_weight is not None:
loss = loss_ops.sigmoid_cross_entropy_loss(bottom, **self.call_args)
if self.loss_weight != 1:
loss *= self.loss_weight
return loss
......@@ -131,24 +133,18 @@ class SmoothL1Loss(Layer):
else:
reduction = norm_dict[param.normalization]
sigma2 = smooth_l1_param.sigma * smooth_l1_param.sigma
self.arguments = {
'beta': float(1. / sigma2),
'reduction': reduction,
}
self.call_args = {'beta': float(1. / sigma2), 'reduction': reduction}
self.loss_weight = (layer_param.loss_weight or [1])[0]
def __call__(self, bottom):
loss = loss_ops.smooth_l1_loss(bottom, **self.arguments)
if self.loss_weight is not None:
loss = loss_ops.smooth_l1_loss(bottom, **self.call_args)
if self.loss_weight != 1:
loss *= self.loss_weight
return loss
class SoftmaxWithLoss(Layer):
r"""Compute the softmax cross entropy with sparse labels.
The **CrossEntropy** function is defined as:
.. math:: \text{CrossEntropy}(p_{t}) = -\log(p_{t})
"""Compute the loss of softmax cross entropy.
Examples:
......@@ -181,15 +177,16 @@ class SoftmaxWithLoss(Layer):
reduction = 'batch_mean'
else:
reduction = norm_dict[param.normalization]
self.arguments = {
self.call_args = {
'axis': softmax_param.axis,
'reduction': reduction,
'ignore_index': param.ignore_label
if param.HasField('ignore_label') else None,
}
self.loss_weight = (layer_param.loss_weight or [1])[0]
def __call__(self, bottom):
loss = loss_ops.sparse_softmax_cross_entropy(bottom, **self.arguments)
if self.loss_weight is not None:
loss = loss_ops.softmax_cross_entropy_loss(bottom, **self.call_args)
if self.loss_weight != 1:
loss *= self.loss_weight
return loss
......@@ -17,6 +17,7 @@ from __future__ import print_function
from dragon.core.ops import activation_ops
from dragon.core.ops import math_ops
from dragon.vm.caffe.core.layer import Layer
from dragon.vm.caffe.core.proto import caffe_pb2
class Dropout(Layer):
......@@ -47,10 +48,10 @@ class Dropout(Layer):
param = layer_param.dropout_param
if not param.scale_train:
raise ValueError('Unscaled dropout is not supported.')
self.arguments = {'ratio': param.dropout_ratio}
self.call_args = {'ratio': param.dropout_ratio}
def __call__(self, bottom):
return activation_ops.dropout(bottom, **self.arguments)
return activation_ops.dropout(bottom, **self.call_args)
class ELU(Layer):
......@@ -83,10 +84,10 @@ class ELU(Layer):
def __init__(self, layer_param):
super(ELU, self).__init__(layer_param)
self.arguments = {'alpha': float(layer_param.elu_param.alpha)}
self.call_args = {'alpha': float(layer_param.elu_param.alpha)}
def __call__(self, bottom):
return activation_ops.elu(bottom, **self.arguments)
return activation_ops.elu(bottom, **self.call_args)
class Power(Layer):
......@@ -123,7 +124,7 @@ class Power(Layer):
bottom = bottom * self.scale
if self.shift != 0:
bottom = bottom + self.shift
return math_ops.pow([bottom, self.power], **self.arguments)
return math_ops.pow([bottom, self.power])
class PReLU(Layer):
......@@ -163,15 +164,24 @@ class PReLU(Layer):
def __init__(self, layer_param):
super(PReLU, self).__init__(layer_param)
param = layer_param.prelu_param
self.arguments = {
'channel_shared': param.channel_shared,
'data_format': 'NCHW',
}
self.add_blob(filler=self.get_filler(param, 'filler'), value=0.25)
self.filler = caffe_pb2.FillerParameter(type='constant', value=0.25)
self.filler = param.filler if param.HasField('filler') else self.filler
self.channel_shared = param.channel_shared
def build(self, bottom):
if self.channel_shared:
weight_shape = [1]
elif len(bottom.shape) > 1:
weight_shape = [bottom.shape[1]]
else:
weight_shape = [bottom.shape[0]]
self.add_blob(weight_shape, self.filler)
def __call__(self, bottom):
if len(self.blobs) == 0:
self.build(bottom)
inputs = [bottom] + [blob['data'] for blob in self._blobs]
return activation_ops.prelu(inputs, **self.arguments)
return activation_ops.prelu(inputs)
class ReLU(Layer):
......@@ -205,11 +215,12 @@ class ReLU(Layer):
def __init__(self, layer_param):
super(ReLU, self).__init__(layer_param)
param = layer_param.relu_param
if param.HasField('negative_slope'):
self.arguments = {'alpha': param.negative_slope}
self.negative_slope = param.negative_slope
def __call__(self, bottom):
return activation_ops.relu(bottom, **self.arguments)
if self.negative_slope > 0:
return activation_ops.leaky_relu(bottom, self.negative_slope)
return activation_ops.relu(bottom)
class Sigmoid(Layer):
......@@ -235,7 +246,7 @@ class Sigmoid(Layer):
super(Sigmoid, self).__init__(layer_param)
def __call__(self, bottom):
return activation_ops.sigmoid(bottom, **self.arguments)
return activation_ops.sigmoid(bottom)
class TanH(Layer):
......@@ -261,4 +272,4 @@ class TanH(Layer):
super(TanH, self).__init__(layer_param)
def __call__(self, bottom):
return activation_ops.tanh(bottom, **self.arguments)
return activation_ops.tanh(bottom)
......@@ -53,33 +53,39 @@ class Convolution(Layer):
def __init__(self, layer_param):
super(Convolution, self).__init__(layer_param)
param = layer_param.convolution_param
self.arguments = {
'out_channels': param.num_output,
'kernel_shape': [int(e) for e in param.kernel_size],
'strides': [int(e) for e in param.stride] if len(param.stride) > 0 else [1],
'pads': [int(e) for e in param.pad] if len(param.pad) > 0 else [0],
'dilations': [int(e) for e in param.dilation] if len(param.dilation) > 0 else [1],
'group': int(param.group),
'padding': 'VALID',
'data_format': 'NCHW',
}
if param.HasField('kernel_h'):
assert param.HasField('kernel_w')
self.arguments['kernel_shape'] = [param.kernel_h, param.kernel_w]
if param.HasField('stride_h'):
assert param.HasField('stride_w')
self.arguments['strides'] = [param.stride_h, param.stride_w]
if param.HasField('pad_h'):
assert param.HasField('pad_w')
self.arguments['pads'] = [param.pad_h, param.pad_w]
self.add_blob(filler=self.get_filler(param, 'weight_filler'))
if param.bias_term:
self.add_blob(filler=self.get_filler(param, 'bias_filler'))
self.kernel_shape = param.kernel_size or [1]
self.strides = param.stride or [1]
self.pads = param.pad or [0]
self.dilations = param.dilation or [1]
self.out_channels = param.num_output
self.weight_filler = param.weight_filler
self.bias_filler = param.bias_filler
self.bias_term = param.bias_term
self.call_args = {'group': param.group}
def build(self, bottom):
num_axes = len(bottom.shape) - 2
if num_axes < 1:
raise ValueError(
'Bottom 0 of layer "{}" is {}d, expected 3d/4d/5d.'
.format(self.name, len(bottom.shape)))
in_channels = bottom.shape[1]
for k in ('kernel_shape', 'strides', 'pads', 'dilations'):
self.call_args[k] = [int(d) for d in getattr(self, k)]
if len(self.call_args[k]) < num_axes:
reps = num_axes - len(self.call_args[k])
self.call_args[k] += [self.call_args[k][-1]] * reps
weight_shape = [self.out_channels, in_channels] + self.call_args['kernel_shape']
self.add_blob(weight_shape, self.weight_filler)
if self.bias_term:
self.add_blob([self.out_channels], self.bias_filler)
def __call__(self, bottom):
if len(self.blobs) == 0:
self.build(bottom)
inputs = [bottom] + [blob['data'] for blob in self._blobs]
conv_op = 'conv{}d'.format(len(bottom.shape) - 2)
return getattr(vision_ops, conv_op)(inputs, **self.arguments)
conv_op = 'conv{}d'.format(len(self.call_args['kernel_shape']))
return getattr(vision_ops, conv_op)(inputs, **self.call_args)
class Deconvolution(Convolution):
......@@ -116,10 +122,29 @@ class Deconvolution(Convolution):
def __init__(self, layer_param):
super(Deconvolution, self).__init__(layer_param)
def build(self, bottom):
num_axes = len(bottom.shape) - 2
if num_axes < 1:
raise ValueError(
'Bottom 0 of layer "{}" is {}d, expected 3d/4d/5d.'
.format(self.name, len(bottom.shape)))
in_channels = bottom.shape[1]
for k in ('kernel_shape', 'strides', 'pads', 'dilations'):
self.call_args[k] = [int(d) for d in getattr(self, k)]
if len(self.call_args[k]) < num_axes:
reps = num_axes - len(self.call_args[k])
self.call_args[k] += [self.call_args[k][-1]] * reps
weight_shape = [in_channels, self.out_channels] + self.call_args['kernel_shape']
self.add_blob(weight_shape, self.weight_filler)
if self.bias_term:
self.add_blob([self.out_channels], self.bias_filler)
def __call__(self, bottom):
if len(self.blobs) == 0:
self.build(bottom)
inputs = [bottom] + [blob['data'] for blob in self._blobs]
conv_op = 'conv{}d_transpose'.format(len(bottom.shape) - 2)
return getattr(vision_ops, conv_op)(inputs, **self.arguments)
conv_op = 'conv{}d_transpose'.format(len(self.call_args['kernel_shape']))
return getattr(vision_ops, conv_op)(inputs, **self.call_args)
class LRN(Layer):
......@@ -148,17 +173,14 @@ class LRN(Layer):
super(LRN, self).__init__(layer_param)
param = layer_param.lrn_param
if param.norm_region > 0:
raise NotImplementedError('WITHIN_CHANNEL mode is not implemented.')
self.arguments = {
'size': param.local_size,
raise NotImplementedError('<WITHIN_CHANNEL> mode is not implemented.')
self.op_args = {'size': param.local_size,
'alpha': param.alpha,
'beta': param.beta,
'bias': param.k,
'data_format': 'NCHW',
}
'bias': param.k}
def __call__(self, bottom):
return normalization_ops.local_response_norm(bottom, **self.arguments)
return normalization_ops.local_response_norm(bottom, **self.op_args)
class Pooling(Layer):
......@@ -184,93 +206,26 @@ class Pooling(Layer):
def __init__(self, layer_param):
super(Pooling, self).__init__(layer_param)
param = layer_param.pooling_param
self.arguments = {
self.kernel_shape = [param.kernel_size or 1]
self.strides = [param.stride or 1]
self.pads = [param.pad or 0]
self.call_args = {
'ceil_mode': True,
'mode': {0: 'MAX', 1: 'AVG'}[param.pool],
'data_format': 'NCHW',
'global_pool': param.global_pooling,
}
if not param.HasField('kernel_h'):
self.arguments['kernel_shape'] = [param.kernel_size]
else:
self.arguments['kernel_shape'] = [param.kernel_h, param.kernel_w]
if not param.HasField('pad_h'):
self.arguments['pads'] = [param.pad]
else:
self.arguments['pads'] = [param.pad_h, param.pad_w]
if not param.HasField('stride_h'):
self.arguments['strides'] = [param.stride]
else:
self.arguments['strides'] = [param.stride_h, param.stride_w]
def __call__(self, bottom):
pool_op = 'pool{}d'.format(len(bottom.shape) - 2)
return getattr(vision_ops, pool_op)(bottom, **self.arguments)
class ROIAlign(Layer):
r"""Apply the average roi align.
`[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
Examples:
```python
layer {
type: "ROIAlign"
bottom: "conv5_3"
top: "roi_pool4"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625
}
}
```
"""
def __init__(self, layer_param):
super(ROIAlign, self).__init__(layer_param)
param = layer_param.roi_pooling_param
self.arguments = {
'pool_h': int(param.pooled_h),
'pool_w': int(param.pooled_w),
'spatial_scale': param.spatial_scale,
}
def __call__(self, bottom):
return vision_ops.roi_align(bottom, **self.arguments)
class ROIPooling(Layer):
r"""Apply the max roi pooling.
`[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
Examples:
```python
layer {
type: "ROIPooling"
bottom: "conv5_3"
top: "roi_pool4"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625
}
}
```
"""
def __init__(self, layer_param):
super(ROIPooling, self).__init__(layer_param)
param = layer_param.roi_pooling_param
self.arguments = {
'pool_h': int(param.pooled_h),
'pool_w': int(param.pooled_w),
'spatial_scale': param.spatial_scale,
}
def __call__(self, bottom):
return vision_ops.roi_pool(bottom, **self.arguments)
num_axes = len(bottom.shape) - 2
if num_axes < 1:
raise ValueError(
'Bottom 0 of layer "{}" is {}d, expected 3d/4d/5d.'
.format(self.name, len(bottom.shape)))
call_args = self.call_args.copy()
for k in ('kernel_shape', 'strides', 'pads'):
call_args[k] = [int(d) for d in getattr(self, k)]
if len(call_args[k]) < num_axes:
reps = num_axes - len(call_args[k])
call_args[k] += [call_args[k][-1]] * reps
pool_op = 'pool{}d'.format(num_axes)
return getattr(vision_ops, pool_op)(bottom, **call_args)
......@@ -409,11 +409,10 @@ message LayerParameter {
optional ThresholdParameter threshold_param = 128;
optional TileParameter tile_param = 138;
optional WindowDataParameter window_data_param = 129;
// Following parameters are extended on BVLC-Caffe.
optional ROIPoolingParameter roi_pooling_param = 151;
optional SmoothL1LossParameter smooth_l1_loss_param = 152;
optional PermuteParameter permute_param = 153;
optional NormalizeParameter normalize_param = 154;
// Following parameters are extended on BVLC/caffe.
optional SmoothL1LossParameter smooth_l1_loss_param = 151;
optional PermuteParameter permute_param = 152;
optional NormalizeParameter normalize_param = 153;
}
// Message that stores parameters used to apply transformation
......@@ -931,17 +930,6 @@ message PoolingParameter {
optional bool global_pooling = 12 [default = false];
}
// Message that stores parameters used by ROIPoolingLayer
message ROIPoolingParameter {
// Pad, kernel size, and stride are all given as a single value for equal
// dimensions in height and width or as Y, X pairs.
optional uint32 pooled_h = 1 [default = 0]; // The pooled output height
optional uint32 pooled_w = 2 [default = 0]; // The pooled output width
// Multiplicative spatial scale factor to translate ROI coords from their
// input scale to the scale used when pooling
optional float spatial_scale = 3 [default = 1];
}
message PowerParameter {
// PowerLayer computes outputs y = (shift + scale * x) ^ power.
optional float power = 1 [default = 1.0];
......
......@@ -57,7 +57,7 @@ def get_device_type(mixed=False):
Parameters
----------
mixed : bool, optional, default=False
**True** to return ``mixed`` for gpu device.
``True`` to return ``mixed`` for gpu device.
Returns
-------
......
......@@ -23,9 +23,9 @@ except ImportError:
TensorGPU = object
from dragon.core.device import cuda
from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import device_spec
from dragon.core.framework import workspace
from dragon.core.framework.tensor import Tensor
from dragon.vm.dali.core.framework import types
......@@ -152,7 +152,7 @@ class Iterator(object):
@staticmethod
def new_tensor(shape, dtype, device):
"""Return a new tensor abstraction."""
return EagerTensor(shape=shape, dtype=dtype, device=device)
return Tensor(shape=shape, dtype=dtype, device=device)
def __iter__(self):
"""Return the iterator self."""
......@@ -206,12 +206,12 @@ class Iterator(object):
def _transfer_tensor(self, dali_tensor, target_tensor):
"""Transfer the dali tensor to the target."""
target_shape = dali_tensor.shape()
device = target_tensor._device = \
self.new_device(
device = self.new_device(
device_type='cuda' if isinstance(
dali_tensor, TensorGPU) else 'cpu',
device_index=self._pipe.device_id,
)
device_index=self._pipe.device_id)
if hasattr(target_tensor, '_device'):
target_tensor._device = device
impl = target_tensor._impl
if target_shape != list(target_tensor.shape):
new_capacity = not impl.Reshape(target_shape)
......
......@@ -63,7 +63,7 @@ class ImageDecoder(object):
"""
if isinstance(output_type, six.string_types):
output_type = getattr(types, output_type)
return ops.ImageDecoder(
return ops.decoders.Image(
output_type=output_type,
host_memory_padding=host_memory_padding,
device_memory_padding=device_memory_padding,
......@@ -124,7 +124,7 @@ class ImageDecoderRandomCrop(object):
"""
if isinstance(output_type, six.string_types):
output_type = getattr(types, output_type)
return ops.ImageDecoderRandomCrop(
return ops.decoders.ImageRandomCrop(
output_type=output_type,
host_memory_padding=host_memory_padding,
device_memory_padding=device_memory_padding,
......
......@@ -316,7 +316,7 @@ class RandomBBoxCrop(object):
thresholds : Sequence[float], optional
The minimum IoU(s) to satisfy.
allow_no_crop : bool, optional, default=True
**True** to include the no-cropping as an option.
``True`` to include the no-cropping as an option.
num_attempts : int, optional, default=10
The max number of sampling trials.
bbox_layout : str, optional, default='xyXY'
......
......@@ -47,7 +47,7 @@ class CoinFlip(object):
The operator.
"""
return ops.CoinFlip(probability=probability, **kwargs)
return ops.random.CoinFlip(probability=probability, **kwargs)
class Uniform(object):
......@@ -76,4 +76,4 @@ class Uniform(object):
The operator.
"""
return ops.Uniform(range=range, **kwargs)
return ops.random.Uniform(range=range, **kwargs)
......@@ -218,7 +218,7 @@ class TFRecordReader(object):
"""
path, index_path, features = cls.check_files(path)
return ops.TFRecordReader(
return ops.readers.TFRecord(
path=path,
index_path=index_path,
shard_id=shard_id,
......
......@@ -15,6 +15,10 @@ Buffer
######
.. doxygenfunction:: dragon::Operator::Buffer
DeriveFrom
##########
.. doxygenfunction:: dragon::Operator::DeriveFrom
Fuse
####
.. doxygenfunction:: dragon::Operator::Fuse
......@@ -55,10 +59,6 @@ Run
###
.. doxygenfunction:: dragon::Operator::Run
UpdateFrom
##########
.. doxygenfunction:: dragon::Operator::UpdateFrom
data_format
###########
.. doxygenfunction:: dragon::Operator::data_format
......
......@@ -23,10 +23,6 @@ CreateTensor
############
.. doxygenfunction:: dragon::Workspace::CreateTensor
GetFillerInfo
#############
.. doxygenfunction:: dragon::Workspace::GetFillerInfo
GetTensor
#########
.. doxygenfunction:: dragon::Workspace::GetTensor
......@@ -39,14 +35,6 @@ MergeFrom
#########
.. doxygenfunction:: dragon::Workspace::MergeFrom
RegisterAlias
#############
.. doxygenfunction:: dragon::Workspace::RegisterAlias
ResetTensor
###########
.. doxygenfunction:: dragon::Workspace::ResetTensor
RunGraph
########
.. doxygenfunction:: dragon::Workspace::RunGraph
......@@ -55,6 +43,10 @@ RunOperator
###########
.. doxygenfunction:: dragon::Workspace::RunOperator
SetAlias
########
.. doxygenfunction:: dragon::Workspace::SetAlias
TryGetTensor
############
.. doxygenfunction:: dragon::Workspace::TryGetTensor
......
......@@ -29,10 +29,6 @@ params
Methods
-------
backward
########
.. automethod:: dragon.vm.caffe.Net.backward
copy_from
#########
.. automethod:: dragon.vm.caffe.Net.copy_from
......@@ -41,10 +37,6 @@ forward
#########
.. automethod:: dragon.vm.caffe.Net.forward
forward_backward
################
.. automethod:: dragon.vm.caffe.Net.forward_backward
save
####
.. automethod:: dragon.vm.caffe.Net.save
......
......@@ -88,14 +88,6 @@ vm.caffe.layers
`class Reshape <layers/Reshape.html>`_
: Change the dimensions of input.
`class ROIAlign <layers/ROIAlign.html>`_
: Apply the average roi align.
`[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
`class ROIPooling <layers/ROIPooling.html>`_
: Apply the max roi pooling.
`[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
`class Scale <layers/Scale.html>`_
: Compute the affine transformation along the given axes.
......@@ -149,8 +141,6 @@ vm.caffe.layers
layers/Reduction
layers/ReLU
layers/Reshape
layers/ROIAlign
layers/ROIPooling
layers/Scale
layers/Sigmoid
layers/SigmoidCrossEntropyLoss
......
......@@ -34,7 +34,7 @@ extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.napoleon',
'sphinxcontrib.katex',
'sphinx_seeta_theme.ext.viewcode',
# 'sphinx_seeta_theme.ext.viewcode',
]
napoleon_use_rtype = False
......
......@@ -6,17 +6,17 @@ dragon
Classes
-------
`class EagerTensor <dragon/EagerTensor.html>`_
: Tensor abstraction for eager executing.
`class DeviceSpec <dragon/DeviceSpec.html>`_
: Describe a computation device.
`class GradientTape <dragon/GradientTape.html>`_
: Record the operations for auto differentiation.
`class Tensor <dragon/Tensor.html>`_
: Tensor abstraction for graph executing.
: A multi-dimensional array for computation.
`class Workspace <dragon/Workspace.html>`_
: Sandbox to isolate the resources and computations.
: Standalone environment for resources and computations.
Functions
---------
......@@ -27,6 +27,9 @@ dragon
`assign(...) <dragon/assign.html>`_
: Assign the value to input.
`boolean_mask(...) <dragon/boolean_mask.html>`_
: Return the elements of input where mask is true.
`broadcast_to(...) <dragon/broadcast_to.html>`_
: Broadcast input according to a given shape.
......@@ -34,13 +37,14 @@ dragon
: Cast the data type of input.
`channel_affine(...) <dragon/channel_affine.html>`_
: Apply affine transformation along the channels.
: Apply affine transformation to each channel of input.
`channel_normalize(...) <dragon/channel_normalize.html>`_
: Normalize channels with mean and standard deviation.
: Apply normalization to each channel of input.
`channel_shuffle(...) <dragon/channel_shuffle.html>`_
: Shuffle channels between a given number of groups.
: Apply group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`concat(...) <dragon/concat.html>`_
: Concatenate the inputs along the given axis.
......@@ -48,17 +52,14 @@ dragon
`constant(...) <dragon/constant.html>`_
: Return a tensor initialized from the value.
`create_function(...) <dragon/create_function.html>`_
: Create a callable graph from the specified outputs.
`device(...) <dragon/device.html>`_
: Context-manager to nest the device spec.
`eager_mode(...) <dragon/eager_mode.html>`_
: Context-manager to set the eager execution mode.
`eager_scope(...) <dragon/eager_mode.html>`_
: Context-manager to nest the name for eager resources.
`variable_scope(...) <dragon/variable_scope.html>`_
: Context-manager to nest the namespace for variables.
`expand_dims(...) <dragon/expand_dims.html>`_
: Expand the dimensions of input with size 1.
......@@ -78,14 +79,14 @@ dragon
`function(...) <dragon/function.html>`_
: Compile a function and return an executable.
`gather(...) <dragon/gather.html>`_
: Gather the elements along the given axis using index.
`get_num_threads(...) <dragon/get_num_threads.html>`_
: Return the number of threads for cpu parallelism.
`get_workspace(...) <dragon/get_workspace.html>`_
: Return the current default workspace.
`gradients(...) <dragon/gradients.html>`_
: Compute the symbolic derivatives of ``ys`` w.r.t. ``xs`` .
: Return the default workspace.
`graph_mode(...) <dragon/graph_mode.html>`_
: Context-manager to set the graph execution mode.
......@@ -93,21 +94,12 @@ dragon
`identity(...) <dragon/identity.html>`_
: Return a tensor copied from the input.
`index_select(...) <dragon/index_select.html>`_
: Select the elements according to the index along the given axis.
`linspace(...) <dragon/linspace.html>`_
: Generate evenly spaced values within intervals along the given axis.
`load_library(...) <dragon/load_library.html>`_
: Load a shared library.
`masked_assign(...) <dragon/masked_assign.html>`_
: Assign the value to input where mask is 1.
`masked_select(...) <dragon/masked_select.html>`_
: Select the elements of input where mask is 1.
`name_scope(...) <dragon/name_scope.html>`_
: Context-manager to nest the name as prefix for operations.
......@@ -166,11 +158,17 @@ dragon
: Return the identity of input with truncated gradient-flow.
`tile(...) <dragon/tile.html>`_
: Tile the input according to the given repeats.
: Repeat elements along each axis of input.
`transpose(...) <dragon/transpose.html>`_
: Permute the dimensions of input.
`tril(...) <dragon/tril.html>`_
: Return the lower triangular part of input.
`triu(...) <dragon/triu.html>`_
: Return the upper triangular part of input.
`unique(...) <dragon/unique.html>`_
: Return the unique elements of input.
......@@ -186,12 +184,13 @@ dragon
.. toctree::
:hidden:
dragon/EagerTensor
dragon/DeviceSpec
dragon/GradientTape
dragon/Tensor
dragon/Workspace
dragon/argsort
dragon/assign
dragon/boolean_mask
dragon/broadcast_to
dragon/cast
dragon/channel_affine
......@@ -199,26 +198,21 @@ dragon
dragon/channel_shuffle
dragon/concat
dragon/constant
dragon/create_function
dragon/device
dragon/eager_mode
dragon/eager_scope
dragon/expand_dims
dragon/eye
dragon/eye_like
dragon/fill
dragon/flatten
dragon/function
dragon/gather
dragon/get_num_threads
dragon/get_workspace
dragon/gradients
dragon/graph_mode
dragon/identity
dragon/index_select
dragon/linspace
dragon/load_library
dragon/masked_assign
dragon/masked_select
dragon/name_scope
dragon/nonzero
dragon/ones
......@@ -240,7 +234,10 @@ dragon
dragon/stop_gradient
dragon/tile
dragon/transpose
dragon/tril
dragon/triu
dragon/unique
dragon/variable_scope
dragon/where
dragon/zeros
dragon/zeros_like
......
DeviceSpec
==========
.. autoclass:: dragon.DeviceSpec
__init__
--------
.. automethod:: dragon.DeviceSpec.__init__
Properties
----------
index
#####
.. autoattribute:: dragon.DeviceSpec.index
type
####
.. autoattribute:: dragon.DeviceSpec.type
Methods
-------
copy
####
.. automethod:: dragon.DeviceSpec.copy
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
EagerTensor
===========
.. autoclass:: dragon.EagerTensor
__init__
--------
.. automethod:: dragon.EagerTensor.__init__
Properties
----------
dtype
#####
.. autoattribute:: dragon.EagerTensor.dtype
id
##
.. autoattribute:: dragon.EagerTensor.id
name
####
.. autoattribute:: dragon.EagerTensor.name
ndim
####
.. autoattribute:: dragon.EagerTensor.ndim
shape
#####
.. autoattribute:: dragon.EagerTensor.shape
size
#####
.. autoattribute:: dragon.EagerTensor.size
Methods
-------
astype
######
.. automethod:: dragon.EagerTensor.astype
constant
########
.. automethod:: dragon.EagerTensor.constant
copy
####
.. automethod:: dragon.EagerTensor.copy
from_value
##########
.. automethod:: dragon.EagerTensor.from_value
get_value
#########
.. automethod:: dragon.EagerTensor.get_value
glorot_normal
#############
.. automethod:: dragon.EagerTensor.glorot_normal
glorot_uniform
##############
.. automethod:: dragon.EagerTensor.glorot_uniform
normal
######
.. automethod:: dragon.EagerTensor.normal
numpy
#####
.. automethod:: dragon.EagerTensor.numpy
reshape
#######
.. automethod:: dragon.EagerTensor.reshape
set_value
#########
.. automethod:: dragon.EagerTensor.set_value
truncated_normal
################
.. automethod:: dragon.EagerTensor.truncated_normal
uniform
#######
.. automethod:: dragon.EagerTensor.uniform
Overrides
---------
__add__
#######
.. automethod:: dragon.EagerTensor.__add__
__float__
#########
.. automethod:: dragon.EagerTensor.__float__
__ge__
######
.. automethod:: dragon.EagerTensor.__ge__
__getitem__
###########
.. automethod:: dragon.EagerTensor.__getitem__
__gt__
######
.. automethod:: dragon.EagerTensor.__gt__
__iadd__
########
.. automethod:: dragon.EagerTensor.__iadd__
__imul__
########
.. automethod:: dragon.EagerTensor.__imul__
__int__
#######
.. automethod:: dragon.EagerTensor.__int__
__isub__
########
.. automethod:: dragon.EagerTensor.__isub__
__itruediv__
############
.. automethod:: dragon.EagerTensor.__itruediv__
__le__
######
.. automethod:: dragon.EagerTensor.__le__
__lt__
######
.. automethod:: dragon.EagerTensor.__lt__
__mul__
#######
.. automethod:: dragon.EagerTensor.__mul__
__neg__
#######
.. automethod:: dragon.EagerTensor.__neg__
__radd__
########
.. automethod:: dragon.EagerTensor.__radd__
__rmul__
########
.. automethod:: dragon.EagerTensor.__rmul__
__rsub__
########
.. automethod:: dragon.EagerTensor.__rsub__
__rtruediv__
############
.. automethod:: dragon.EagerTensor.__rtruediv__
__setitem__
###########
.. automethod:: dragon.EagerTensor.__setitem__
__sub__
#######
.. automethod:: dragon.EagerTensor.__sub__
__truediv__
###########
.. automethod:: dragon.EagerTensor.__truediv__
.. _dragon.assign(...): assign.html
.. _dragon.cast(...): cast.html
.. _dragon.fill(...): fill.html
.. _dragon.identity(...): identity.html
.. _dragon.masked_assign(...): masked_assign.html
.. _dragon.masked_select(...): masked_select.html
.. _dragon.math.add(...): math/add.html
.. _dragon.math.div(...): math/div.html
.. _dragon.math.greater(...): math/greater.html
.. _dragon.math.greater_equal(...): math/greater_equal.html
.. _dragon.math.less(...): math/less.html
.. _dragon.math.less_equal(...): math/less_equal.html
.. _dragon.math.mul(...): math/mul.html
.. _dragon.math.negative(...): math/negative.html
.. _dragon.math.sub(...): math/sub.html
.. _dragon.random.glorot_normal(...): random/glorot_normal.html
.. _dragon.random.glorot_uniform(...): random/glorot_uniform.html
.. _dragon.random.normal(...): random/normal.html
.. _dragon.random.truncated_normal(...): random/truncated_normal.html
.. _dragon.random.uniform(...): random/uniform.html
.. _dragon.reshape(...): reshape.html
.. _dragon.slice(...): slice.html
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
......@@ -10,6 +10,10 @@ __init__
Properties
----------
device
######
.. autoattribute:: dragon.Tensor.device
dtype
#####
.. autoattribute:: dragon.Tensor.dtype
......@@ -26,6 +30,10 @@ ndim
####
.. autoattribute:: dragon.Tensor.ndim
requires_grad
#############
.. autoattribute:: dragon.Tensor.requires_grad
shape
#####
.. autoattribute:: dragon.Tensor.shape
......@@ -41,21 +49,13 @@ astype
######
.. automethod:: dragon.Tensor.astype
constant
########
.. automethod:: dragon.Tensor.constant
copy
####
.. automethod:: dragon.Tensor.copy
from_value
##########
.. automethod:: dragon.Tensor.from_value
get_value
##########
.. automethod:: dragon.Tensor.get_value
fill
####
.. automethod:: dragon.Tensor.fill
glorot_normal
#############
......@@ -69,14 +69,14 @@ normal
######
.. automethod:: dragon.Tensor.normal
numpy
#####
.. automethod:: dragon.Tensor.numpy
reshape
#######
.. automethod:: dragon.Tensor.reshape
set_value
#########
.. automethod:: dragon.Tensor.set_value
truncated_normal
################
.. automethod:: dragon.Tensor.truncated_normal
......@@ -92,6 +92,10 @@ __add__
#######
.. automethod:: dragon.Tensor.__add__
__and__
#######
.. automethod:: dragon.Tensor.__and__
__float__
#########
.. automethod:: dragon.Tensor.__float__
......@@ -108,10 +112,42 @@ __gt__
######
.. automethod:: dragon.Tensor.__gt__
__iadd__
########
.. automethod:: dragon.Tensor.__iadd__
__iand__
########
.. automethod:: dragon.Tensor.__iand__
__imul__
########
.. automethod:: dragon.Tensor.__imul__
__int__
#######
.. automethod:: dragon.Tensor.__int__
__invert__
###########
.. automethod:: dragon.Tensor.__invert__
__ior__
#######
.. automethod:: dragon.Tensor.__ior__
__isub__
########
.. automethod:: dragon.Tensor.__isub__
__itruediv__
############
.. automethod:: dragon.Tensor.__itruediv__
__ixor__
########
.. automethod:: dragon.Tensor.__ixor__
__le__
######
.. automethod:: dragon.Tensor.__le__
......@@ -128,18 +164,34 @@ __neg__
#######
.. automethod:: dragon.Tensor.__neg__
__or__
#######
.. automethod:: dragon.Tensor.__or__
__radd__
########
.. automethod:: dragon.Tensor.__radd__
__rand__
########
.. automethod:: dragon.Tensor.__rand__
__rmul__
########
.. automethod:: dragon.Tensor.__rmul__
__ror__
#######
.. automethod:: dragon.Tensor.__ror__
__rsub__
########
.. automethod:: dragon.Tensor.__rsub__
__rxor__
########
.. automethod:: dragon.Tensor.__rxor__
__setitem__
###########
.. automethod:: dragon.Tensor.__setitem__
......@@ -156,12 +208,18 @@ __truediv__
############
.. automethod:: dragon.Tensor.__truediv__
__xor__
#######
.. automethod:: dragon.Tensor.__xor__
.. _dragon.assign(...): assign.html
.. _dragon.bitwise.bitwise_and(...): bitwise/bitwise_and.html
.. _dragon.bitwise.bitwise_or(...): bitwise/bitwise_or.html
.. _dragon.bitwise.bitwise_xor(...): bitwise/bitwise_xor.html
.. _dragon.bitwise.invert(...): bitwise/invert.html
.. _dragon.cast(...): cast.html
.. _dragon.fill(...): fill.html
.. _dragon.identity(...): identity.html
.. _dragon.masked_assign(...): masked_assign.html
.. _dragon.masked_select(...): masked_select.html
.. _dragon.math.add(...): math/add.html
.. _dragon.math.div(...): math/div.html
.. _dragon.math.greater(...): math/greater.html
......@@ -177,7 +235,6 @@ __truediv__
.. _dragon.random.truncated_normal(...): random/truncated_normal.html
.. _dragon.random.uniform(...): random/uniform.html
.. _dragon.reshape(...): reshape.html
.. _dragon.slice(...): slice.html
.. raw:: html
......
......@@ -18,18 +18,6 @@ clear
#####
.. automethod:: dragon.Workspace.clear
feed_tensor
###########
.. automethod:: dragon.Workspace.feed_tensor
fetch_tensor
############
.. automethod:: dragon.Workspace.fetch_tensor
has_tensor
##########
.. automethod:: dragon.Workspace.has_tensor
memory_allocated
################
.. automethod:: dragon.Workspace.memory_allocated
......@@ -38,10 +26,6 @@ merge_from
##########
.. automethod:: dragon.Workspace.merge_from
reset_tensor
############
.. automethod:: dragon.Workspace.reset_tensor
.. raw:: html
<style>
......
......@@ -6,9 +6,6 @@ dragon.autograph
Functions
---------
`set_execution(...) <autograph/set_execution.html>`_
: Set the execution mode for graph ir.
`set_optimization(...) <autograph/set_optimization.html>`_
: Set the optimization for graph ir.
......@@ -21,7 +18,6 @@ dragon.autograph
.. toctree::
:hidden:
autograph/set_execution
autograph/set_optimization
autograph/set_scheduler
autograph/set_verbosity
......
index_select
boolean_mask
============
.. autofunction:: dragon.index_select
.. autofunction:: dragon.boolean_mask
.. raw:: html
......
create_function
===============
gather
======
.. autofunction:: dragon.create_function
.. autofunction:: dragon.gather
.. raw:: html
......
......@@ -10,29 +10,27 @@ dragon.losses
: Compute the ctc loss with batched labels.
`l1_loss(...) <losses/l1_loss.html>`_
: Compute the element-wise absolute value difference.
: Compute the loss of element-wise absolute value difference.
`l2_loss(...) <losses/l2_loss.html>`_
: Compute the element-wise squared error.
: Compute the loss of element-wise squared error.
`nll_loss(...) <losses/nll_loss.html>`_
: Compute the negative likelihood loss with sparse labels.
: Compute the loss of negative likelihood.
`sigmoid_cross_entropy(...) <losses/sigmoid_cross_entropy.html>`_
: Compute the sigmoid cross entropy with contiguous targets.
`sigmoid_cross_entropy_loss(...) <losses/sigmoid_cross_entropy_loss.html>`_
: Compute the loss of sigmoid cross entropy.
`sigmoid_focal_loss(...) <losses/sigmoid_focal_loss.html>`_
: Compute the sigmoid focal loss with sparse labels.
: Compute the focal loss of sigmoid cross entropy.
`[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`_.
`smooth_l1_loss(...) <losses/smooth_l1_loss.html>`_
: Compute the element-wise error transited from L1 and L2.
: Compute the loss of element-wise error transited from L1 and L2.
`[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
`softmax_cross_entropy(...) <losses/softmax_cross_entropy.html>`_
: Compute the softmax cross entropy with contiguous targets.
`sparse_softmax_cross_entropy(...) <losses/sparse_softmax_cross_entropy.html>`_
: Compute the softmax cross entropy with sparse labels.
`softmax_cross_entropy_loss(...) <losses/softmax_cross_entropy_loss.html>`_
: Compute the loss of softmax cross entropy.
.. toctree::
:hidden:
......@@ -41,11 +39,10 @@ dragon.losses
losses/l1_loss
losses/l2_loss
losses/nll_loss
losses/sigmoid_cross_entropy
losses/sigmoid_cross_entropy_loss
losses/sigmoid_focal_loss
losses/smooth_l1_loss
losses/softmax_cross_entropy
losses/sparse_softmax_cross_entropy
losses/softmax_cross_entropy_loss
.. raw:: html
......
sigmoid_cross_entropy
=====================
sigmoid_cross_entropy_loss
==========================
.. autofunction:: dragon.losses.sigmoid_cross_entropy
.. autofunction:: dragon.losses.sigmoid_cross_entropy_loss
.. raw:: html
......
softmax_cross_entropy
=====================
softmax_cross_entropy_loss
==========================
.. autofunction:: dragon.losses.softmax_cross_entropy
.. autofunction:: dragon.losses.softmax_cross_entropy_loss
.. raw:: html
......
......@@ -18,9 +18,6 @@ dragon.math
`argmin(...) <math/argmin.html>`_
: Compute the index of minimum elements along the given axis.
`axpby(...) <math/axpby.html>`_
: Compute the element-wise addition from input to output.
`ceil(...) <math/ceil.html>`_
: Compute the smallest integer not less than input.
......@@ -36,9 +33,6 @@ dragon.math
`div(...) <math/div.html>`_
: Compute the element-wise division.
`dot(...) <math/dot.html>`_
: Compute the vector dot.
`equal(...) <math/equal.html>`_
: Compute the element-wise equal comparison.
......@@ -72,6 +66,18 @@ dragon.math
`log(...) <math/log.html>`_
: Compute the logarithm of input.
`logical_and(...) <math/logical_and.html>`_
: Compute the element-wise AND logical operation.
`logical_not(...) <math/logical_not.html>`_
: Compute the element-wise NOT logical operation.
`logical_or(...) <math/logical_or.html>`_
: Compute the element-wise OR logical operation.
`logical_xor(...) <math/logical_xor.html>`_
: Compute the element-wise XOR logical operation.
`lp_normalize(...) <math/lp_normalize.html>`_
: Apply the lp normalization.
......@@ -93,9 +99,6 @@ dragon.math
`minimum(...) <math/minimum.html>`_
: Compute the minimum value of given two inputs.
`moments(...) <math/moments.html>`_
: Compute the mean and variance of input along the given axes.
`mul(...) <math/mul.html>`_
: Compute the element-wise multiplication.
......@@ -151,13 +154,11 @@ dragon.math
math/add
math/argmax
math/argmin
math/axpby
math/ceil
math/clip
math/cos
math/cumsum
math/div
math/dot
math/equal
math/exp
math/floor
......@@ -169,6 +170,10 @@ dragon.math
math/less
math/less_equal
math/log
math/logical_and
math/logical_not
math/logical_or
math/logical_xor
math/lp_normalize
math/matmul
math/max
......@@ -176,7 +181,6 @@ dragon.math
math/mean
math/min
math/minimum
math/moments
math/mul
math/negative
math/not_equal
......
eager_scope
logical_and
===========
.. autofunction:: dragon.eager_scope
.. autofunction:: dragon.math.logical_and
.. raw:: html
<style>
h1:before {
content: "dragon.";
content: "dragon.math.";
color: #103d3e;
}
</style>
axpby
=====
logical_not
===========
.. autofunction:: dragon.math.axpby
.. autofunction:: dragon.math.logical_not
.. raw:: html
......
dot
===
logical_or
===========
.. autofunction:: dragon.math.dot
.. autofunction:: dragon.math.logical_or
.. raw:: html
......
logical_xor
===========
.. autofunction:: dragon.math.logical_xor
.. raw:: html
<style>
h1:before {
content: "dragon.math.";
color: #103d3e;
}
</style>
......@@ -59,15 +59,15 @@ dragon.nn
: Rearrange depth data into spatial blocks.
`dropout(...) <nn/dropout.html>`_
: Set the elements of the input to zero randomly.
: Set the elements of input to zero randomly.
`[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
`drop_block2d(...) <nn/drop_block2d.html>`_
: Set the spatial blocks over input to zero randomly.
`drop_block(...) <nn/drop_block.html>`_
: Set the blocks over input to zero randomly.
`[Ghiasi et.al, 2018] <https://arxiv.org/abs/1810.12890>`_.
`drop_path(...) <nn/drop_path.html>`_
: Set the examples over the input to zero randomly.
: Set the examples over input to zero randomly.
`[Larsson et.al, 2016] <https://arxiv.org/abs/1605.07648>`_.
`elu(...) <nn/elu.html>`_
......@@ -103,6 +103,9 @@ dragon.nn
`log_softmax(...) <nn/log_softmax.html>`_
: Compute the composite of logarithm and softmax.
`moments(...) <nn/moments.html>`_
: Compute the mean and variance of input along the given axis.
`prelu(...) <nn/prelu.html>`_
: Apply the parametric rectified linear unit.
`[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_.
......@@ -161,7 +164,7 @@ dragon.nn
nn/depthwise_conv2d
nn/depth_to_space
nn/dropout
nn/drop_block2d
nn/drop_block
nn/drop_path
nn/elu
nn/group_norm
......@@ -172,6 +175,7 @@ dragon.nn
nn/leaky_relu
nn/local_response_norm
nn/log_softmax
nn/moments
nn/pool
nn/pool1d
nn/pool2d
......
ROIPooling
drop_block
==========
.. autoclass:: dragon.vm.caffe.core.layers.ROIPooling
.. autofunction:: dragon.nn.drop_block
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
content: "dragon.nn.";
color: #103d3e;
}
</style>
moments
=======
.. autofunction:: dragon.math.moments
.. autofunction:: dragon.nn.moments
.. raw:: html
<style>
h1:before {
content: "dragon.math.";
content: "dragon.nn.";
color: #103d3e;
}
</style>
......@@ -37,14 +37,14 @@ Name Supported Reference
`Acos`_
`Acosh`_
`Add`_ |v| :func:`dragon.math.add`
`And`_ |v| :func:`dragon.bitwise.bitwise_and`
`And`_ |v| :func:`dragon.math.logical_and`
`ArgMax`_ |v| :func:`dragon.math.argmax`
`ArgMin`_ |v| :func:`dragon.math.argmin`
`Asin`_
`Asinh`_
`Atan`_
`Atanh`_
`AveragePool`_ |v| :func:`dragon.nn.pool2d`
`AveragePool`_ |v| :func:`dragon.nn.pool`
`BatchNormalization`_ |v| :func:`dragon.nn.batch_norm`
`BitShift`_
`Cast`_ |v| :func:`dragon.cast`
......@@ -55,9 +55,9 @@ Name Supported Reference
`ConcatFromSequence`_
`Constant`_
`ConstantOfShape`_
`Conv`_ |v| :func:`dragon.nn.conv2d`
`Conv`_ |v| :func:`dragon.nn.conv`
`ConvInteger`_
`ConvTranspose`_ |v| :func:`dragon.nn.conv2d_transpose`
`ConvTranspose`_ |v| :func:`dragon.nn.conv_transpose`
`Cos`_ |v| :func:`dragon.math.cos`
`Cosh`_
`CumSum`_ |v| :func:`dragon.math.cumsum`
......@@ -76,13 +76,13 @@ Name Supported Reference
`Flatten`_ |v| :func:`dragon.flatten`
`Floor`_ |v| :func:`dragon.math.floor`
`GRU`_ |v| :func:`dragon.nn.GRU`
`Gather`_ |v| :func:`dragon.index_select`
`Gather`_ |v| :func:`dragon.gather`
`GatherElements`_
`GatherND`_
`Gemm`_ |v| :func:`dragon.math.gemm`
`GlobalAveragePool`_ |v| :func:`dragon.nn.pool2d`
`GlobalAveragePool`_ |v| :func:`dragon.nn.pool`
`GlobalLpPool`_
`GlobalMaxPool`_ |v| :func:`dragon.nn.pool2d`
`GlobalMaxPool`_ |v| :func:`dragon.nn.pool`
`Greater`_ |v| :func:`dragon.math.greater`
`HardSigmoid`_ |v| :func:`dragon.nn.hardsigmoid`
`Hardmax`_
......@@ -103,7 +103,7 @@ Name Supported Reference
`MatMul`_ |v| :func:`dragon.math.matmul`
`MatMulInteger`_
`Max`_ |v| :func:`dragon.math.maximum`
`MaxPool`_ |v| :func:`dragon.nn.pool2d`
`MaxPool`_ |v| :func:`dragon.nn.pool`
`MaxRoiPool`_ |v| :func:`dragon.vision.roi_pool`
`MaxUnpool`_
`Mean`_ |v| :func:`dragon.math.add`
......@@ -114,9 +114,9 @@ Name Supported Reference
`Neg`_ |v| :func:`dragon.math.negative`
`NonMaxSuppression`_
`NonZero`_ |v| :func:`dragon.nonzero`
`Not`_ |v| :func:`dragon.bitwise.invert`
`Not`_ |v| :func:`dragon.math.logical_not`
`OneHot`_ |v| :func:`dragon.one_hot`
`Or`_ |v| :func:`dragon.bitwise.bitwise_or`
`Or`_ |v| :func:`dragon.math.logical_or`
`PRelu`_ |v| :func:`dragon.nn.prelu`
`Pad`_ |v| :func:`dragon.pad`
`Pow`_ |v| :func:`dragon.math.pow`
......@@ -186,7 +186,7 @@ Name Supported Reference
`Unsqueeze`_ |v| :func:`dragon.expand_dims`
`Upsample`_ |v| :func:`dragon.vision.resize`
`Where`_ |v| :func:`dragon.where`
`Xor`_ |v| :func:`dragon.bitwise.bitwise_xor`
`Xor`_ |v| :func:`dragon.math.logical_xor`
======================== ========= ========================================
.. toctree::
......
......@@ -13,7 +13,7 @@ dragon.random
: Return a tensor initialized from the glorot uniform distribution.
`multinomial(...) <random/multinomial.html>`_
: Return a tensor with index sampled from multinomial distribution.
: Return an index tensor sampled from the multinomial distribution.
`normal(...) <random/normal.html>`_
: Return a tensor initialized from the normal distribution.
......
gradients
=========
tril
====
.. autofunction:: dragon.gradients
.. autofunction:: dragon.tril
.. raw:: html
......
masked_assign
=============
triu
====
.. autofunction:: dragon.masked_assign
.. autofunction:: dragon.triu
.. raw:: html
......
masked_select
=============
variable_scope
==============
.. autofunction:: dragon.masked_select
.. autofunction:: dragon.variable_scope
.. raw:: html
......
......@@ -55,14 +55,11 @@ vm.tensorflow
: Return a tensor filled with the scalar value.
`gather(...) <tensorflow/gather.html>`_
: Select the elements according to the index along the given axis.
: Gather the elements along the given axis using index.
`function(...) <tensorflow/function.html>`_
: Create a callable graph from the python function.
`gradients(...) <tensorflow/gradients.html>`_
: Compute the symbolic derivatives of ``ys`` w.r.t. ``xs`` .
`identity(...) <tensorflow/identity.html>`_
: Return a tensor copied from the input.
......@@ -105,6 +102,9 @@ vm.tensorflow
`squeeze(...) <tensorflow/squeeze.html>`_
: Remove the dimensions of input with size 1.
`tile(...) <tensorflow/tile.html>`_
: Tile input according to the given repeats.
`transpose(...) <tensorflow/transpose.html>`_
: Permute the dimensions of input.
......@@ -140,7 +140,6 @@ vm.tensorflow
tensorflow/fill
tensorflow/function
tensorflow/gather
tensorflow/gradients
tensorflow/identity
tensorflow/linspace
tensorflow/name_scope
......@@ -155,6 +154,7 @@ vm.tensorflow
tensorflow/sort
tensorflow/split
tensorflow/squeeze
tensorflow/tile
tensorflow/transpose
tensorflow/unique
tensorflow/unique_with_counts
......
......@@ -57,7 +57,7 @@ vm.tensorflow.nn
: Rearrange depth data into spatial blocks.
`dropout(...) <nn/dropout.html>`_
: Set the elements of the input to zero randomly.
: Set the elements of input to zero randomly.
`[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
`elu(...) <nn/elu.html>`_
......
gradients
=========
tile
====
.. autofunction:: dragon.vm.tensorflow.gradients
.. autofunction:: dragon.vm.tensorflow.tile
.. raw:: html
......
......@@ -51,15 +51,18 @@ vm.torch
`argsort(...) <torch/argsort.html>`_
: Return the index of sorted elements along the given dimension.
`axpby(...) <torch/axpby.html>`_
: Compute the element-wise addition from input to output.
`baddbmm(...) <torch/baddbmm.html>`_
: Add input to the result of batched matrix-matrix multiplication.
`bitwise_and(...) <torch/bitwise_and.html>`_
: Compute the element-wise AND bitwise operation.
`bitwise_not(...) <torch/bitwise_not.html>`_
: Compute the element-wise NOT bitwise operation.
`bitwise_or(...) <torch/bitwise_or.html>`_
: Compute the element-wise OR bitwise operation.
`bitwise_xor(...) <torch/bitwise_xor.html>`_
: Compute the element-wise XOR bitwise operation.
......@@ -73,13 +76,13 @@ vm.torch
: Compute the smallest integer not less than input.
`channel_affine(...) <torch/channel_affine.html>`_
: Apply affine transformation along the channels.
: Apply affine transformation to each channel of input.
`channel_normalize(...) <torch/channel_normalize.html>`_
: Normalize channels with mean and standard deviation.
: Apply normalization to each channel of input.
`channel_shuffle(...) <torch/channel_shuffle.html>`_
: Shuffle channels between a given number of groups.
: Apply group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`chunk(...) <torch/chunk.html>`_
......@@ -143,11 +146,23 @@ vm.torch
: Compute the element-wise less-equal comparison.
`linspace(...) <torch/linspace.html>`_
: Generate evenly spaced values within intervals along the given axis.
: Generate evenly spaced values within intervals along the given dimension.
`log(...) <torch/log.html>`_
: Compute the natural logarithm of input.
`logical_and(...) <torch/logical_and.html>`_
: Compute the element-wise AND logical operation.
`logical_not(...) <torch/logical_not.html>`_
: Compute the element-wise NOT logical operation.
`logical_or(...) <torch/logical_or.html>`_
: Compute the element-wise OR logical operation.
`logical_xor(...) <torch/logical_xor.html>`_
: Compute the element-wise XOR logical operation.
`logsumexp(...) <torch/logsumexp.html>`_
: Apply the composite of log, sum, and exp to input.
......@@ -155,7 +170,7 @@ vm.torch
: Compute the element-wise less comparison.
`masked_select(...) <torch/masked_select.html>`_
: Select the input elements where mask is 1.
: Select the input elements where mask is true.
`matmul(...) <torch/matmul.html>`_
: Compute the matrix multiplication.
......@@ -223,9 +238,6 @@ vm.torch
`reciprocal(...) <torch/reciprocal.html>`_
: Compute the reciprocal of input.
`repeat(...) <torch/repeat.html>`_
: Repeat elements along the specified dimensions.
`reshape(...) <torch/reshape.html>`_
: Change the shape of input.
......@@ -265,12 +277,21 @@ vm.torch
`tensor(...) <torch/tensor.html>`_
: Create a tensor initializing the content from data.
`tile(...) <torch/tile.html>`_
: Repeat elements along each dimension of input.
`topk(...) <torch/topk.html>`_
: Return the top-K largest or smallest elements along the given dimension.
`transpose(...) <torch/transpose.html>`_
: Return a new tensor with two dimensions swapped.
`tril(...) <torch/tril.html>`_
: Return the lower triangular part of input.
`triu(...) <torch/triu.html>`_
: Return the upper triangular part of input.
`unique(...) <torch/unique.html>`_
: Return the unique elements of input.
......@@ -298,9 +319,10 @@ vm.torch
torch/argmax
torch/argmin
torch/argsort
torch/axpby
torch/baddbmm
torch/bitwise_and
torch/bitwise_not
torch/bitwise_or
torch/bitwise_xor
torch/bmm
torch/cat
......@@ -333,6 +355,10 @@ vm.torch
torch/le
torch/linspace
torch/log
torch/logical_and
torch/logical_not
torch/logical_or
torch/logical_xor
torch/logsumexp
torch/lt
torch/masked_select
......@@ -359,7 +385,6 @@ vm.torch
torch/randn
torch/randperm
torch/reciprocal
torch/repeat
torch/reshape
torch/round
torch/rsqrt
......@@ -374,8 +399,11 @@ vm.torch
torch/sub
torch/sum
torch/tensor
torch/tile
torch/topk
torch/transpose
torch/tril
torch/triu
torch/unique
torch/unsqueeze
torch/where
......
......@@ -81,6 +81,14 @@ baddbmm\_
#########
.. automethod:: dragon.vm.torch.Tensor.baddbmm_
bitwise_and
###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_and
bitwise_and\_
#############
.. automethod:: dragon.vm.torch.Tensor.bitwise_and_
bitwise_not
###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_not
......@@ -89,6 +97,14 @@ bitwise_not\_
#############
.. automethod:: dragon.vm.torch.Tensor.bitwise_not_
bitwise_or
###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_or
bitwise_or\_
############
.. automethod:: dragon.vm.torch.Tensor.bitwise_or_
bitwise_xor
###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_xor
......@@ -281,6 +297,22 @@ log
###
.. automethod:: dragon.vm.torch.Tensor.log
logical_and
###########
.. automethod:: dragon.vm.torch.Tensor.logical_and
logical_not
###########
.. automethod:: dragon.vm.torch.Tensor.logical_not
logical_or
##########
.. automethod:: dragon.vm.torch.Tensor.logical_or
logical_xor
###########
.. automethod:: dragon.vm.torch.Tensor.logical_xor
logsumexp
#########
.. automethod:: dragon.vm.torch.Tensor.logsumexp
......@@ -513,6 +545,22 @@ transpose
#########
.. automethod:: dragon.vm.torch.Tensor.transpose
tril
####
.. automethod:: dragon.vm.torch.Tensor.tril
tril\_
######
.. automethod:: dragon.vm.torch.Tensor.tril_
triu
####
.. automethod:: dragon.vm.torch.Tensor.triu
triu\_
######
.. automethod:: dragon.vm.torch.Tensor.triu_
type
####
.. automethod:: dragon.vm.torch.Tensor.type
......@@ -560,7 +608,9 @@ zero\_
.. _torch.argmin(...): argmin.html
.. _torch.argsort(...): argsort.html
.. _torch.baddbmm(...): baddbmm.html
.. _torch.bitwise_and(...): bitwise_and.html
.. _torch.bitwise_not(...): bitwise_not.html
.. _torch.bitwise_or(...): bitwise_or.html
.. _torch.bitwise_xor(...): bitwise_xor.html
.. _torch.bmm(...): bmm.html
.. _torch.ceil(...): ceil.html
......@@ -579,6 +629,12 @@ zero\_
.. _torch.isinf(...): isinf.html
.. _torch.isnan(...): isnan.html
.. _torch.le(...): le.html
.. _torch.log(...): log.html
.. _torch.logical_and(...): logical_and.html
.. _torch.logical_not(...): logical_not.html
.. _torch.logical_or(...): logical_or.html
.. _torch.logical_xor(...): logical_xor.html
.. _torch.logsumexp(...): logsumexp.html
.. _torch.lt(...): lt.html
.. _torch.matmul(...): matmul.html
.. _torch.max(...): max.html
......@@ -606,6 +662,8 @@ zero\_
.. _torch.sum(...): sum.html
.. _torch.topk(...): topk.html
.. _torch.transpose(...): transpose.html
.. _torch.tril(...): tril.html
.. _torch.triu(...): triu.html
.. _torch.unique(...): unique.html
.. _torch.unsqueeze(...): unsqueeze.html
.. _torch.where(...): where.html
......
......@@ -6,9 +6,6 @@ vm.torch.autograd
Classes
-------
`class Function <autograd/Function.html>`_
: Dispatch the tensor operation.
Functions
---------
......@@ -18,7 +15,6 @@ vm.torch.autograd
.. toctree::
:hidden:
autograd/Function
autograd/backward
.. raw:: html
......
Function
========
.. autoclass:: dragon.vm.torch.autograd.Function
__init__
--------
.. automethod:: dragon.vm.torch.autograd.Function.__init__
Methods
-------
alloc
#####
.. automethod:: dragon.vm.torch.autograd.Function.alloc
apply
#####
.. automethod:: dragon.vm.torch.autograd.Function.apply
attributes
##########
.. automethod:: dragon.vm.torch.autograd.Function.attributes
dispatch
########
.. automethod:: dragon.vm.torch.autograd.Function.dispatch
instantiate
###########
.. automethod:: dragon.vm.torch.autograd.Function.instantiate
forward
#######
.. automethod:: dragon.vm.torch.autograd.Function.forward
.. raw:: html
<style>
h1:before {
content: "torch.autograd.";
color: #103d3e;
}
</style>
axpby
=====
bitwise_and
===========
.. autofunction:: dragon.vm.torch.axpby
.. autofunction:: dragon.vm.torch.bitwise_and
.. raw:: html
......
repeat
======
bitwise_or
==========
.. autofunction:: dragon.vm.torch.bitwise_or
.. autofunction:: dragon.vm.torch.repeat
.. raw:: html
......
......@@ -7,6 +7,24 @@ __init__
--------
.. automethod:: dragon.vm.torch.device.__init__
Properties
----------
index
#####
.. autoattribute:: dragon.vm.torch.device.index
type
####
.. autoattribute:: dragon.vm.torch.device.type
Methods
-------
copy
####
.. automethod:: dragon.vm.torch.device.copy
.. raw:: html
<style>
......
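A minimal usage sketch for the ``device`` properties and the ``copy`` method documented above; the constructor arguments and the ``dragon.vm`` import path are assumptions based on the PyTorch-style API, not taken from this diff.
.. code-block:: python
    from dragon.vm import torch
    # Assumed PyTorch-style constructor: device type string plus index.
    dev = torch.device('cuda', 0)
    print(dev.type)   # e.g. 'cuda'
    print(dev.index)  # e.g. 0
    # ``copy`` presumably returns an independent copy of the descriptor.
    dev2 = dev.copy()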
ROIAlign
========
logical_and
===========
.. autofunction:: dragon.vm.torch.logical_and
.. autoclass:: dragon.vm.caffe.core.layers.ROIAlign
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
content: "torch.";
color: #103d3e;
}
</style>
set_execution
=============
logical_not
===========
.. autofunction:: dragon.vm.torch.logical_not
.. autofunction:: dragon.autograph.set_execution
.. raw:: html
<style>
h1:before {
content: "dragon.autograph.";
content: "torch.";
color: #103d3e;
}
</style>
logical_or
==========
.. autofunction:: dragon.vm.torch.logical_or
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
drop_block2d
============
logical_xor
===========
.. autofunction:: dragon.vm.torch.logical_xor
.. autofunction:: dragon.nn.drop_block2d
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
content: "torch.";
color: #103d3e;
}
</style>
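A minimal usage sketch for the logical operators documented above; the signatures are assumed to mirror their PyTorch counterparts, and the import path is also an assumption.
.. code-block:: python
    from dragon.vm import torch
    a = torch.tensor([0, 1, 1, 0])
    b = torch.tensor([0, 0, 1, 1])
    torch.logical_and(a, b)  # element-wise AND
    torch.logical_or(a, b)   # element-wise OR
    torch.logical_xor(a, b)  # element-wise XOR
    torch.logical_not(a)     # element-wise NOT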
......@@ -142,6 +142,10 @@ vm.torch.nn
`class Linear <nn/Linear.html>`_
: Apply the linear transformation.
`class LayerNorm <nn/LayerNorm.html>`_
: Apply the layer normalization.
`[Ba et.al, 2016] <https://arxiv.org/abs/1607.06450>`_
`class LocalResponseNorm <nn/LocalResponseNorm.html>`_
: Apply the local response normalization.
`[Krizhevsky et.al, 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_.
......@@ -169,9 +173,16 @@ vm.torch.nn
`class Module <nn/Module.html>`_
: The base class of modules.
`class ModuleList <nn/ModuleList.html>`_
: The list module container.
`class MSELoss <nn/MSELoss.html>`_
: Compute the element-wise squared error.
`class MultiheadAttention <nn/MultiheadAttention.html>`_
: Apply the multihead attention.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class NLLLoss <nn/NLLLoss.html>`_
: Compute the negative log-likelihood loss with sparse labels.
......@@ -216,6 +227,9 @@ vm.torch.nn
: Apply the scaled exponential linear unit.
`[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`class Sequential <nn/Sequential.html>`_
: The sequential module container.
`class Sigmoid <nn/Sigmoid.html>`_
: Apply the sigmoid function.
......@@ -237,6 +251,22 @@ vm.torch.nn
`class Tanh <nn/Tanh.html>`_
: Apply the tanh function.
`class TransformerDecoder <nn/TransformerDecoder.html>`_
: Standard transformer decoder.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class TransformerDecoderLayer <nn/TransformerDecoderLayer.html>`_
: Layer for a standard transformer decoder.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class TransformerEncoder <nn/TransformerEncoder.html>`_
: Standard transformer encoder.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class TransformerEncoderLayer <nn/TransformerEncoderLayer.html>`_
: Layer for a standard transformer encoder.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`class SyncBatchNorm <nn/SyncBatchNorm.html>`_
: Apply the sync batch normalization over input.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
......@@ -295,6 +325,7 @@ vm.torch.nn
nn/Identity
nn/KLDivLoss
nn/L1Loss
nn/LayerNorm
nn/LeakyReLU
nn/Linear
nn/LocalResponseNorm
......@@ -305,7 +336,9 @@ vm.torch.nn
nn/MaxPool2d
nn/MaxPool3d
nn/Module
nn/ModuleList
nn/MSELoss
nn/MultiheadAttention
nn/NLLLoss
nn/Parameter
nn/PReLU
......@@ -319,12 +352,17 @@ vm.torch.nn
nn/ReplicationPad3d
nn/RNN
nn/SELU
nn/Sequential
nn/Sigmoid
nn/SigmoidFocalLoss
nn/SmoothL1Loss
nn/Softmax
nn/Swish
nn/Tanh
nn/TransformerDecoder
nn/TransformerDecoderLayer
nn/TransformerEncoder
nn/TransformerEncoderLayer
nn/SyncBatchNorm
nn/Upsample
nn/UpsamplingBilinear2d
......
LayerNorm
=========
.. autoclass:: dragon.vm.torch.nn.LayerNorm
__init__
--------
.. automethod:: dragon.vm.torch.nn.LayerNorm.__init__
.. _torch.nn.functional.layer_norm(...): functional/layer_norm.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
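A hedged sketch of the new ``LayerNorm`` module; the constructor is assumed to follow the PyTorch convention ``LayerNorm(normalized_shape, eps=1e-5)``.
.. code-block:: python
    from dragon.vm import torch
    from dragon.vm.torch import nn
    norm = nn.LayerNorm(64)     # normalize over the last dimension of size 64
    x = torch.randn(8, 16, 64)  # (batch, seq_len, features)
    y = norm(x)                 # output has the same shape as the input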
ModuleList
==========
.. autoclass:: dragon.vm.torch.nn.ModuleList
__init__
--------
.. automethod:: dragon.vm.torch.nn.ModuleList.__init__
Methods
-------
add_module
##########
.. automethod:: dragon.vm.torch.nn.Module.add_module
:noindex:
append
######
.. automethod:: dragon.vm.torch.nn.ModuleList.append
apply
#####
.. automethod:: dragon.vm.torch.nn.Module.apply
:noindex:
buffers
#######
.. automethod:: dragon.vm.torch.nn.Module.buffers
:noindex:
children
########
.. automethod:: dragon.vm.torch.nn.Module.children
:noindex:
cpu
###
.. automethod:: dragon.vm.torch.nn.Module.cpu
:noindex:
cuda
####
.. automethod:: dragon.vm.torch.nn.Module.cuda
:noindex:
double
######
.. automethod:: dragon.vm.torch.nn.Module.double
:noindex:
eval
####
.. automethod:: dragon.vm.torch.nn.Module.eval
:noindex:
extend
######
.. automethod:: dragon.vm.torch.nn.ModuleList.extend
float
#####
.. automethod:: dragon.vm.torch.nn.Module.float
:noindex:
half
####
.. automethod:: dragon.vm.torch.nn.Module.half
:noindex:
insert
######
.. automethod:: dragon.vm.torch.nn.ModuleList.insert
load_state_dict
###############
.. automethod:: dragon.vm.torch.nn.Module.load_state_dict
:noindex:
modules
#######
.. automethod:: dragon.vm.torch.nn.Module.modules
:noindex:
named_buffers
#############
.. automethod:: dragon.vm.torch.nn.Module.named_buffers
:noindex:
named_children
##############
.. automethod:: dragon.vm.torch.nn.Module.named_children
:noindex:
named_modules
#############
.. automethod:: dragon.vm.torch.nn.Module.named_modules
:noindex:
named_parameters
################
.. automethod:: dragon.vm.torch.nn.Module.named_parameters
:noindex:
parameters
##########
.. automethod:: dragon.vm.torch.nn.Module.parameters
:noindex:
register_buffer
###############
.. automethod:: dragon.vm.torch.nn.Module.register_buffer
:noindex:
register_forward_hook
#####################
.. automethod:: dragon.vm.torch.nn.Module.register_forward_hook
:noindex:
register_parameter
##################
.. automethod:: dragon.vm.torch.nn.Module.register_parameter
:noindex:
state_dict
##########
.. automethod:: dragon.vm.torch.nn.Module.state_dict
:noindex:
train
#####
.. automethod:: dragon.vm.torch.nn.Module.train
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.nn.Module.zero_grad
:noindex:
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
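A short sketch of the ``ModuleList`` container methods documented above (``append``, ``extend``, ``insert``); iterating over the list in a forward function is an assumption carried over from the PyTorch behavior.
.. code-block:: python
    from dragon.vm.torch import nn
    blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(2)])
    blocks.append(nn.Linear(64, 32))      # add a module at the end
    blocks.extend([nn.Linear(32, 32)])    # add several modules at once
    blocks.insert(0, nn.Linear(128, 64))  # add a module at a given position
    def forward(x):
        # ModuleList is a plain container: call the children explicitly.
        for block in blocks:
            x = block(x)
        return x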
MultiheadAttention
==================
.. autoclass:: dragon.vm.torch.nn.MultiheadAttention
__init__
--------
.. automethod:: dragon.vm.torch.nn.MultiheadAttention.__init__
.. _torch.nn.functional.multi_head_attention_forward(...): functional/multi_head_attention_forward.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
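A hedged sketch of the new ``MultiheadAttention`` module; the constructor arguments, the ``(seq_len, batch, embed_dim)`` layout, and the two return values are assumptions based on the PyTorch API.
.. code-block:: python
    from dragon.vm import torch
    from dragon.vm.torch import nn
    attn = nn.MultiheadAttention(embed_dim=64, num_heads=8)
    query = torch.randn(10, 2, 64)               # (seq_len, batch, embed_dim)
    output, weights = attn(query, query, query)  # self-attention (assumed outputs)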
Sequential
==========
.. autoclass:: dragon.vm.torch.nn.Sequential
__init__
--------
.. automethod:: dragon.vm.torch.nn.Sequential.__init__
Methods
-------
add_module
##########
.. automethod:: dragon.vm.torch.nn.Module.add_module
:noindex:
apply
#####
.. automethod:: dragon.vm.torch.nn.Module.apply
:noindex:
buffers
#######
.. automethod:: dragon.vm.torch.nn.Module.buffers
:noindex:
children
########
.. automethod:: dragon.vm.torch.nn.Module.children
:noindex:
cpu
###
.. automethod:: dragon.vm.torch.nn.Module.cpu
:noindex:
cuda
####
.. automethod:: dragon.vm.torch.nn.Module.cuda
:noindex:
double
######
.. automethod:: dragon.vm.torch.nn.Module.double
:noindex:
eval
####
.. automethod:: dragon.vm.torch.nn.Module.eval
:noindex:
float
#####
.. automethod:: dragon.vm.torch.nn.Module.float
:noindex:
forward
#######
.. automethod:: dragon.vm.torch.nn.Sequential.forward
half
####
.. automethod:: dragon.vm.torch.nn.Module.half
:noindex:
load_state_dict
###############
.. automethod:: dragon.vm.torch.nn.Module.load_state_dict
:noindex:
modules
#######
.. automethod:: dragon.vm.torch.nn.Module.modules
:noindex:
named_buffers
#############
.. automethod:: dragon.vm.torch.nn.Module.named_buffers
:noindex:
named_children
##############
.. automethod:: dragon.vm.torch.nn.Module.named_children
:noindex:
named_modules
#############
.. automethod:: dragon.vm.torch.nn.Module.named_modules
:noindex:
named_parameters
################
.. automethod:: dragon.vm.torch.nn.Module.named_parameters
:noindex:
parameters
##########
.. automethod:: dragon.vm.torch.nn.Module.parameters
:noindex:
register_buffer
###############
.. automethod:: dragon.vm.torch.nn.Module.register_buffer
:noindex:
register_forward_hook
#####################
.. automethod:: dragon.vm.torch.nn.Module.register_forward_hook
:noindex:
register_parameter
##################
.. automethod:: dragon.vm.torch.nn.Module.register_parameter
:noindex:
state_dict
##########
.. automethod:: dragon.vm.torch.nn.Module.state_dict
:noindex:
train
#####
.. automethod:: dragon.vm.torch.nn.Module.train
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.nn.Module.zero_grad
:noindex:
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
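A minimal sketch of the ``Sequential`` container; the constructor usage is assumed to mirror PyTorch, and the child modules used here are only examples.
.. code-block:: python
    from dragon.vm import torch
    from dragon.vm.torch import nn
    model = nn.Sequential(
        nn.Linear(128, 64),
        nn.Sigmoid(),
        nn.Linear(64, 10),
    )
    logits = model(torch.randn(32, 128))  # modules are applied in order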
TransformerDecoder
==================
.. autoclass:: dragon.vm.torch.nn.TransformerDecoder
__init__
--------
.. automethod:: dragon.vm.torch.nn.TransformerDecoder.__init__
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
TransformerDecoderLayer
=======================
.. autoclass:: dragon.vm.torch.nn.TransformerDecoderLayer
__init__
--------
.. automethod:: dragon.vm.torch.nn.TransformerDecoderLayer.__init__
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
TransformerEncoder
==================
.. autoclass:: dragon.vm.torch.nn.TransformerEncoder
__init__
--------
.. automethod:: dragon.vm.torch.nn.TransformerEncoder.__init__
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
TransformerEncoderLayer
=======================
.. autoclass:: dragon.vm.torch.nn.TransformerEncoderLayer
__init__
--------
.. automethod:: dragon.vm.torch.nn.TransformerEncoderLayer.__init__
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
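A hedged sketch of the new transformer modules; the constructor keywords and the stacking pattern are assumed to mirror PyTorch, not taken from this diff.
.. code-block:: python
    from dragon.vm import torch
    from dragon.vm.torch import nn
    enc_layer = nn.TransformerEncoderLayer(d_model=64, nhead=8)
    encoder = nn.TransformerEncoder(enc_layer, num_layers=2)
    memory = encoder(torch.randn(10, 2, 64))  # (seq_len, batch, d_model)
    dec_layer = nn.TransformerDecoderLayer(d_model=64, nhead=8)
    decoder = nn.TransformerDecoder(dec_layer, num_layers=2)
    out = decoder(torch.randn(6, 2, 64), memory)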
......@@ -41,22 +41,22 @@ vm.torch.nn.functional
: Compute the sigmoid cross entropy with contiguous target.
`conv1d(...) <functional/conv1d.html>`_
: Apply 1d convolution to the input.
: Apply the 1d convolution to input.
`conv2d(...) <functional/conv2d.html>`_
: Apply 2d convolution to the input.
: Apply the 2d convolution to input.
`conv3d(...) <functional/conv3d.html>`_
: Apply 3d convolution to the input.
: Apply the 3d convolution to input.
`conv_transpose1d(...) <functional/conv_transpose1d.html>`_
: Apply 1d deconvolution to the input.
: Apply the 1d deconvolution to input.
`conv_transpose2d(...) <functional/conv_transpose2d.html>`_
: Apply 2d deconvolution to the input.
: Apply the 2d deconvolution to input.
`conv_transpose3d(...) <functional/conv_transpose3d.html>`_
: Apply 3d deconvolution to the input.
: Apply the 3d deconvolution to input.
`cross_entropy(...) <functional/cross_entropy.html>`_
: Compute the softmax cross entropy with sparse labels.
......@@ -66,11 +66,11 @@ vm.torch.nn.functional
`[Graves & Gomez, 2006] <http://www.cs.utoronto.ca/~graves/icml_2006.pdf>`_.
`depthwise_conv2d(...) <functional/depthwise_conv2d.html>`_
: Apply 2d depthwise convolution to the input.
: Apply the 2d depthwise convolution to input.
`[Chollet, 2016] <https://arxiv.org/abs/1610.02357>`_.
`drop_block2d(...) <functional/drop_block2d.html>`_
: Set the spatial blocks over input to zero randomly.
: Set the blocks over input to zero randomly.
`[Ghiasi et.al, 2018] <https://arxiv.org/abs/1810.12890>`_.
`drop_path(...) <functional/drop_path.html>`_
......@@ -78,7 +78,7 @@ vm.torch.nn.functional
`[Larsson et.al, 2016] <https://arxiv.org/abs/1605.07648>`_.
`dropout(...) <functional/dropout.html>`_
: Set the elements of the input to zero randomly.
: Set the elements of input to zero randomly.
`[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
`elu(...) <functional/elu.html>`_
......@@ -102,6 +102,10 @@ vm.torch.nn.functional
`l1_loss(...) <functional/l1_loss.html>`_
: Compute the element-wise absolute value difference.
`layer_norm(...) <functional/layer_norm.html>`_
: Apply the layer normalization to input.
`[Ba et.al, 2016] <https://arxiv.org/abs/1607.06450>`_
`leaky_relu(...) <functional/leaky_relu.html>`_
: Apply the leaky rectified linear unit to input.
......@@ -130,6 +134,10 @@ vm.torch.nn.functional
`mse_loss(...) <functional/mse_loss.html>`_
: Compute the element-wise squared error.
`multi_head_attention_forward(...) <functional/multi_head_attention_forward.html>`_
: Apply the multihead attention to input.
`[Vaswani et.al, 2017] <https://arxiv.org/abs/1706.03762>`_.
`nll_loss(...) <functional/nll_loss.html>`_
: Compute the negative log-likelihood loss with sparse labels.
......@@ -216,6 +224,7 @@ vm.torch.nn.functional
functional/l1_loss
functional/leaky_relu
functional/linear
functional/layer_norm
functional/local_response_norm
functional/log_softmax
functional/interpolate
......@@ -223,6 +232,7 @@ vm.torch.nn.functional
functional/max_pool2d
functional/max_pool3d
functional/mse_loss
functional/multi_head_attention_forward
functional/nll_loss
functional/normalize
functional/pad
......
layer_norm
==========
.. autofunction:: dragon.vm.torch.nn.functional.layer_norm
.. _torch.nn.LayerNorm(...): ../LayerNorm.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
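A hedged sketch of the functional form; the ``(input, normalized_shape, weight, bias, eps)`` signature is assumed from the PyTorch API.
.. code-block:: python
    from dragon.vm import torch
    from dragon.vm.torch.nn import functional as F
    x = torch.randn(8, 64)
    y = F.layer_norm(x, (64,))  # normalize over the trailing dimension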
sparse_softmax_cross_entropy
multi_head_attention_forward
============================
.. autofunction:: dragon.losses.sparse_softmax_cross_entropy
.. autofunction:: dragon.vm.torch.nn.functional.multi_head_attention_forward
.. _torch.nn.MultiheadAttention(...): ../MultiheadAttention.html
.. raw:: html
<style>
h1:before {
content: "dragon.losses.";
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
......@@ -27,6 +27,9 @@ vm.torch.nn.init
`normal_(...) <init/normal_.html>`_
: Fill tensor from a normal distribution.
`ones_(...) <init/ones_.html>`_
: Fill tensor with ones.
`uniform_(...) <init/uniform_.html>`_
: Fill tensor from a uniform distribution.
......@@ -36,6 +39,9 @@ vm.torch.nn.init
`xavier_uniform_(...) <init/xavier_uniform_.html>`_
: Fill tensor from a xavier uniform distribution.
`zeros_(...) <init/zeros_.html>`_
: Fill tensor with zeros.
.. toctree::
:hidden:
......@@ -46,9 +52,11 @@ vm.torch.nn.init
init/kaiming_normal_
init/kaiming_uniform_
init/normal_
init/ones_
init/uniform_
init/xavier_normal_
init/xavier_uniform_
init/zeros_
.. raw:: html
......
ones\_
======
.. autofunction:: dragon.vm.torch.nn.init.ones_
.. raw:: html
<style>
h1:before {
content: "torch.nn.init.";
color: #103d3e;
}
</style>
zeros\_
=======
.. autofunction:: dragon.vm.torch.nn.init.zeros_
.. raw:: html
<style>
h1:before {
content: "torch.nn.init.";
color: #103d3e;
}
</style>
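A short sketch of the new in-place fillers documented above.
.. code-block:: python
    from dragon.vm.torch import nn
    from dragon.vm.torch.nn import init
    linear = nn.Linear(64, 10)
    init.ones_(linear.weight)   # fill the weight with ones, in place
    init.zeros_(linear.bias)    # fill the bias with zeros, in place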
......@@ -10,11 +10,6 @@ __init__
Methods
-------
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex:
add_param_group
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
......@@ -25,6 +20,11 @@ step
.. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex:
sum_grad
########
.. automethod:: dragon.vm.torch.optim.Optimizer.sum_grad
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
......
......@@ -10,10 +10,6 @@ __init__
Methods
-------
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
add_param_group
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
......@@ -22,6 +18,10 @@ step
####
.. automethod:: dragon.vm.torch.optim.Optimizer.step
sum_grad
########
.. automethod:: dragon.vm.torch.optim.Optimizer.sum_grad
zero_grad
#########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
......
......@@ -10,11 +10,6 @@ __init__
Methods
-------
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex:
add_param_group
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
......@@ -25,6 +20,11 @@ step
.. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex:
sum_grad
########
.. automethod:: dragon.vm.torch.optim.Optimizer.sum_grad
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
......
......@@ -10,11 +10,6 @@ __init__
Methods
-------
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex:
add_param_group
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
......@@ -25,6 +20,11 @@ step
.. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex:
sum_grad
########
.. automethod:: dragon.vm.torch.optim.Optimizer.sum_grad
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
......
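A hedged sketch of gradient accumulation with the ``sum_grad`` method that replaces ``accumulate`` on these pages; the call pattern and semantics shown here are assumptions, not taken from this diff.
.. code-block:: python
    from dragon.vm import torch
    from dragon.vm.torch import nn
    model = nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(4):            # a few accumulation steps (arbitrary)
        x = torch.randn(8, 16)
        loss = model(x).sum()
        loss.backward()
        optimizer.sum_grad()      # accumulate the current gradients (assumed)
    optimizer.step()              # apply the accumulated update
    optimizer.zero_grad()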
tile
====
.. autofunction:: dragon.vm.torch.tile
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
tril
====
.. autofunction:: dragon.vm.torch.tril
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
triu
====
.. autofunction:: dragon.vm.torch.triu
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
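A minimal sketch of the triangular and tiling functions documented on these pages; the argument conventions are assumed to match PyTorch.
.. code-block:: python
    from dragon.vm import torch
    x = torch.randn(4, 4)
    lower = torch.tril(x)          # keep the lower triangular part
    upper = torch.triu(x)          # keep the upper triangular part
    tiled = torch.tile(x, (2, 1))  # repeat rows twice (assumed dims argument)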
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_CONTEXT_CNML_H_
#define DRAGON_CORE_CONTEXT_CNML_H_
#include "dragon/core/common.h"
struct cnrtStream;
struct cnmlCpuTensor;
struct cnmlTensor;
struct cnmlFusionOp;
typedef struct cnrtStream* cnrtStream_t;
typedef struct cnmlCpuTensor* cnmlCpuTensor_t;
typedef struct cnmlTensor* cnmlTensor_t;
typedef struct cnmlFusionOp* cnmlFusionOp_t;
namespace dragon {
/*!
* \brief The cnml device context.
*/
class CNMLContext {
public:
/*! \brief Default constructor */
CNMLContext() : device_id_(0), random_seed_(DEFAULT_RNG_SEED) {}
/*! \brief Constructor with the device index */
explicit CNMLContext(int device)
: device_id_(device), random_seed_(DEFAULT_RNG_SEED) {}
/*! \brief Constructor with the device option */
explicit CNMLContext(const DeviceOption& option)
: device_id_(option.device_id()),
random_seed_(
option.has_random_seed() ? option.random_seed()
: DEFAULT_RNG_SEED) {
CHECK_EQ(option.device_type(), PROTO_CNML);
}
/*! \brief Allocate a block of memory */
static void* New(size_t size) {
return nullptr;
}
/*! \brief Set a memory block to the given value */
static void Memset(size_t n, void* ptr, int value) {}
/*! \brief Set a memory block to the given value asynchronously */
void MemsetAsync(size_t n, void* ptr, int value) {
Memset(n, ptr, value);
}
/*! \brief Copy a memory block to the destination */
template <class DestContext, class SrcContext>
static void Memcpy(size_t n, void* dest, const void* src) {}
/*! \brief Copy a memory block to the destination asynchronously */
template <class DestContext, class SrcContext>
void MemcpyAsync(size_t n, void* dest, const void* src) {
Memcpy<DestContext, SrcContext>(n, dest, src);
}
/*! \brief Deallocate a memory block */
static void Delete(void* ptr) {}
/*! \brief Switch to the device in current thread */
void SwitchToDevice() {
SwitchToDevice(0);
}
/*! \brief Switch to the device and select given stream in current thread */
void SwitchToDevice(int stream) {}
/*! \brief Wait for the dispatched computation to complete */
void FinishDeviceComputation() {}
/*! \brief Return the cnrt stream */
cnrtStream_t cnrt_stream() {
return cnrt_stream(device_id_, stream_id_);
}
/*! \brief Return the specified cnrt stream */
static cnrtStream_t cnrt_stream(int device_id, int stream_id) {
return (cnrtStream_t) nullptr;
}
/*! \brief Return the device index */
int device() const {
return device_id_;
}
/*! \brief Return the stream index */
int stream() const {
return stream_id_;
}
private:
int device_id_, stream_id_ = 1, random_seed_;
unique_ptr<std::mt19937> rand_generator_;
};
} // namespace dragon
#endif // DRAGON_CORE_CONTEXT_CNML_H_
#include "dragon/core/gradient.h"
namespace dragon {
void GradientTape::CreateGradientDefs(
const vector<OperatorDef*>& op_defs,
const vector<string>& targets,
const vector<string>& grad_targets) {
def_.clear_op();
Set<string> split_grads;
Map<string, string> sources_to_grads;
Map<string, string> targets_to_grads;
Map<string, int> inputs_count, splits_count;
// Function to check if grad op can be removed.
auto IsNoGradient = [&](const OperatorDef& op,
vector<pair<string, int>>& init_grads) mutable {
if (NoGradientRegistry()->Has(op.type())) {
return true;
}
bool maybe_skip = false;
for (int i = 0; i < op.output_size(); ++i) {
const auto& y = op.output(i);
if (!sources_to_grads.count(y)) {
maybe_skip = true;
if (targets_to_grads.count(y)) {
init_grads.push_back({y, i});
sources_to_grads[y] = y + "_grad";
}
}
}
return maybe_skip && init_grads.empty() && op.output_size() == 1;
};
// Set the gradient of targets.
for (int i = 0; i < targets.size(); ++i) {
targets_to_grads[targets[i]] =
i < grad_targets.size() ? grad_targets[i] : "";
}
// PLAY for the forward.
for (auto* op : op_defs) {
if (NoGradientRegistry()->Has(op->type())) continue;
for (const auto& x : op->input()) {
if (std::find(op->output().begin(), op->output().end(), x) ==
op->output().end()) {
inputs_count[x]++;
}
}
}
// PLAY for the backward.
for (auto op_iter = op_defs.rbegin(); op_iter != op_defs.rend(); op_iter++) {
const auto& op = *op_iter;
vector<pair<string, int>> init_grads;
if (IsNoGradient(*op, init_grads)) continue;
vector<string> grad_ys;
for (const auto& y : op->output()) {
const auto& iter = sources_to_grads.find(y);
grad_ys.emplace_back(iter != sources_to_grads.end() ? iter->second : "");
}
CHECK(GradientRegistry()->Has(op->type()))
<< "\nMissing gradient maker for " << op->type() << ".";
unique_ptr<GradientMakerBase> maker(
GradientRegistry()->Create(op->type(), *op, grad_ys));
maker->Make();
vector<OperatorDef> gather_defs;
for (auto& grad_def : maker->grad_defs()) {
for (int i = 0; i < grad_def.output_size(); ++i) {
const auto& grad_ys = grad_def.input();
const auto& grad_x = grad_def.output(i);
if (std::find(grad_ys.begin(), grad_ys.end(), grad_x) != grad_ys.end())
continue;
int x_index = -1;
for (int j = 0; j < maker->grad_inputs().size(); ++j) {
if (grad_x == maker->grad_inputs()[j]) x_index = j;
}
if (x_index == -1) continue;
const auto& x = op->input(x_index);
if (inputs_count[x] <= 1) continue;
auto split_prefix = grad_x + "_split_";
auto grad_x_split = split_prefix + str::to(splits_count[grad_x]++);
split_grads.insert(grad_x_split);
if (splits_count[grad_x] == inputs_count[x]) {
gather_defs.emplace_back(CreateOperatorDef(
"GradientGather",
"",
vector<string>({}),
vector<string>({grad_x}),
vector<Argument>(),
grad_def.device_option()));
for (int j = 0; j < splits_count[grad_x]; ++j) {
auto iter = split_grads.find(split_prefix + str::to(j));
if (iter != split_grads.end()) {
gather_defs.back().add_input(*iter);
}
}
}
grad_def.set_output(i, grad_x_split);
}
}
for (int i = 0; i < op->input_size(); ++i) {
sources_to_grads[op->input(i)] = maker->grad_inputs()[i];
}
if (init_grads.size() > 0) {
Argument values;
values.set_name("values");
vector<string> inputs, outputs;
auto fills = maker->grad_defaults();
for (auto& iter : init_grads) {
const auto& grad = targets_to_grads[iter.first];
inputs.emplace_back(grad.empty() ? iter.first : grad);
outputs.emplace_back(iter.first + "_grad");
values.add_floats(grad.empty() ? fills[iter.second] : -100.f);
}
def_.add_op()->CopyFrom(CreateOperatorDef(
"GradientFill",
"",
inputs,
outputs,
vector<Argument>({values}),
op->device_option()));
}
for (const auto& grad_def : maker->grad_defs()) {
def_.add_op()->CopyFrom(grad_def);
}
for (const auto& gather_def : gather_defs) {
def_.add_op()->CopyFrom(gather_def);
}
}
}
void GradientTape::Optimize(const vector<string>& sources) {
Set<int> noop_indices;
Set<string> required_grads;
Map<string, int> inputs_count;
Map<string, string> grads_to_buffers;
Map<string, pair<int, string>> splits;
for (int op_index = 0; op_index < def_.op_size(); ++op_index) {
const auto& op = def_.op(op_index);
if (!str::find(op.type(), "Gradient")) continue;
// Count noops.
if (op.type() == "GradientGather") noop_indices.insert(op_index);
// Count inputs.
for (const auto& input : op.input()) {
inputs_count[input] += 1;
}
}
// Initialize the required grads before optimization.
for (const auto& input : sources) {
required_grads.insert(input + "_grad");
}
for (auto op_index : noop_indices) {
const auto& op = def_.op(op_index);
if (op.type() == "GradientGather") {
if (inputs_count.count(op.output(0)) == 0 &&
required_grads.count(op.output(0)) == 0) {
for (const auto& input : op.input()) {
inputs_count.erase(input);
}
} else {
string first_input;
for (const auto& input : op.input()) {
if (!input.empty()) {
if (first_input.empty()) first_input = input;
splits[input] = {op_index, first_input};
}
}
}
}
}
optimized_def_ = def_;
optimized_def_.clear_op();
for (int op_index = 0; op_index < def_.op_size(); ++op_index) {
if (noop_indices.count(op_index)) continue;
const auto& op = def_.op(op_index);
optimized_def_.add_op()->CopyFrom(op);
if (!str::find(op.type(), "Gradient")) continue;
for (const auto& output : op.output()) {
// Decouple the gathering of split grads.
const auto& split_iter = splits.find(output);
if (split_iter != splits.end()) {
auto& gather_op = def_.op(split_iter->second.first);
auto* decouple_op = optimized_def_.add_op();
decouple_op->CopyFrom(gather_op);
decouple_op->clear_input();
if (output != split_iter->second.second) {
decouple_op->set_type("GradientAdd");
decouple_op->add_input(gather_op.output(0));
const auto& count_iter = inputs_count.find(gather_op.output(0));
if (count_iter != inputs_count.end()) count_iter->second++;
}
decouple_op->add_input(output);
if (!op.arg().empty()) {
const auto& arg = *(op.arg().end() - 1);
if (arg.name() == "cache_key") {
auto* new_arg = decouple_op->add_arg();
const auto& dev = decouple_op->device_option();
new_arg->set_name("cache_key");
new_arg->set_s(
decouple_op->type() + "/" + str::to(dev.device_type()) + ":" +
str::to(dev.device_id()));
}
}
}
}
}
// Prepare the pool
int buffer_index = 0;
std::deque<string> pool;
auto get_buffer = [&]() mutable {
if (pool.empty()) {
return "shared/buffer/grad:" + str::to(++buffer_index);
} else {
auto buffer = pool.back();
pool.pop_back();
return buffer;
}
};
for (int op_index = 0; op_index < optimized_def_.op_size(); ++op_index) {
auto* op = optimized_def_.mutable_op(op_index);
if (!str::find(op->type(), "Gradient")) continue;
// Check output aliases.
vec32_t output_aliases(op->output_size(), -1);
for (int i = 0; i < op->output_size(); ++i) {
for (int j = 0; j < op->input_size(); ++j) {
if (op->output(i) != op->input(j)) continue;
output_aliases[i] = j;
break;
}
}
// Rewrite inputs.
vector<string> dead_buffers;
for (int i = 0; i < op->input_size(); ++i) {
const string& input = op->input(i);
const auto& count_iter = inputs_count.find(input);
if (count_iter == inputs_count.end()) continue;
count_iter->second--;
const auto& buffer_iter = grads_to_buffers.find(input);
if (buffer_iter == grads_to_buffers.end()) continue;
if (count_iter->second == 0) {
dead_buffers.emplace_back(buffer_iter->second);
}
op->set_input(i, buffer_iter->second);
}
// Rewrite outputs.
for (int i = 0; i < op->output_size(); ++i) {
const string& output = op->output(i);
if (output.empty() || required_grads.count(output) > 0) continue;
if (inputs_count.count(output) == 0) {
op->mutable_output(i)->clear();
continue;
}
if (output_aliases[i] >= 0) {
op->set_output(i, op->input(output_aliases[i]));
} else {
*op->mutable_output(i) = grads_to_buffers[output] = get_buffer();
}
}
// Update pool.
for (auto& buffer : dead_buffers) {
pool.emplace_back(buffer);
}
}
}
DEFINE_REGISTRY(
GradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<string>&);
DEFINE_REGISTRY(
NoGradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<string>&);
} // namespace dragon
......@@ -10,32 +10,21 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_
#define DRAGON_CORE_OPERATOR_GRADIENT_H_
#ifndef DRAGON_CORE_GRADIENT_H_
#define DRAGON_CORE_GRADIENT_H_
#include "dragon/core/common.h"
#include "dragon/core/operator.h"
#include "dragon/core/registry.h"
#include "dragon/utils/proto_utils.h"
namespace dragon {
struct GradientPack {
GradientPack(
const vector<OperatorDef>& grad_defs,
const vector<string>& grad_inputs,
const vector<float>& defaults)
: grad_defs(grad_defs), grad_inputs(grad_inputs), defaults(defaults) {}
vector<OperatorDef> grad_defs;
vector<string> grad_inputs;
vector<float> defaults;
};
class GradientMakerBase {
public:
GradientMakerBase(const OperatorDef& def, const vector<string>& grad_outputs)
: def(def), grad_inputs_(def.input_size()), grad_outputs_(grad_outputs) {}
: def_(def),
grad_outputs_(grad_outputs),
grad_inputs_(def.input_size()) {}
virtual ~GradientMakerBase() {}
......@@ -49,45 +38,54 @@ class GradientMakerBase {
return true;
}
virtual GradientPack Make() {
auto new_defs = MakeDef();
if (def.has_cache_key()) {
// Attach the handle to name if having cache key
for (size_t i = 0; i < new_defs.size(); i++) {
new_defs[i].set_name(def.name());
virtual void Make() {
CreateGradientDefs();
string cache_key;
if (!def_.arg().empty()) {
const auto& arg = *(def_.arg().end() - 1);
if (arg.name() == "cache_key") cache_key = arg.s();
}
Argument new_arg;
new_arg.set_name("handle");
new_arg.set_s(def_.name());
for (auto& grad_def : grad_defs_) {
if (CopyDeviceOption() && def_.has_device_option()) {
grad_def.mutable_device_option()->CopyFrom(def_.device_option());
}
if (CopyArguments() && !def_.arg().empty()) {
grad_def.mutable_arg()->MergeFrom(def_.arg());
if (!cache_key.empty()) grad_def.mutable_arg()->RemoveLast();
}
} else {
// Otherwise, just put it into the arguments
Argument arg;
arg.set_name("handle");
arg.set_s(def.name());
for (size_t i = 0; i < new_defs.size(); i++) {
new_defs[i].add_arg()->CopyFrom(arg);
grad_def.add_arg()->CopyFrom(new_arg);
}
if (!cache_key.empty()) {
cache_key += "/grad";
new_arg.set_name("cache_key");
for (int i = 0; i < grad_defs_.size(); ++i) {
new_arg.set_s(cache_key + (i > 0 ? ("/" + str::to(i)) : ""));
grad_defs_[i].add_arg()->CopyFrom(new_arg);
}
}
return GradientPack(new_defs, grad_inputs_, defaults());
};
virtual vector<OperatorDef> MakeDef() {
return vector<OperatorDef>();
}
virtual void CreateGradientDefs() {}
template <class... Args>
static vector<OperatorDef> SingleDef(const Args&... args) {
return vector<OperatorDef>{MakeOperatorDef(args...)};
void AddGradientDef(const Args&... args) {
grad_defs_.emplace_back(CreateOperatorDef(args...));
}
const string I(const int i) const {
return i < int(def.input_size()) ? def.input(i) : "";
return i < int(def_.input_size()) ? def_.input(i) : "";
}
const string O(const int i) const {
return i < int(def.output_size()) ? def.output(i) : "";
return i < int(def_.output_size()) ? def_.output(i) : "";
}
string GI(const int i) {
if (i >= int(grad_inputs_.size())) return "";
grad_inputs_[i] = def.input(i) + "_grad";
grad_inputs_[i] = def_.input(i) + "_grad";
return grad_inputs_[i];
}
......@@ -95,82 +93,89 @@ class GradientMakerBase {
return i < int(grad_outputs_.size()) ? grad_outputs_[i] : "";
}
virtual vector<float> defaults() {
const OperatorDef& def() const {
return def_;
}
vector<OperatorDef>& grad_defs() {
return grad_defs_;
}
vector<string>& grad_inputs() {
return grad_inputs_;
}
virtual vector<float> grad_defaults() {
return vector<float>(grad_outputs_.size(), 1.f);
}
protected:
const OperatorDef& def;
vector<string> grad_inputs_;
const OperatorDef& def_;
vector<OperatorDef> grad_defs_;
const vector<string>& grad_outputs_;
vector<string> grad_inputs_;
};
DRAGON_API GradientPack
MakeGradientForOp(const OperatorDef& op_def, const vector<string>& g_outputs);
#define GRADIENT_MAKER_CTOR(name) \
name(const OperatorDef& def, const vector<string>& g_output) \
: GradientMakerBase(def, g_output) {}
name(const OperatorDef& def, const vector<string>& grad_outputs) \
: GradientMakerBase(def, grad_outputs) {}
class NoGradient : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(NoGradient);
vector<OperatorDef> MakeDef() override {
return vector<OperatorDef>();
}
};
namespace {
// Here we define some common gradient makers
// Reuse them to make the code cleaner
class GenericGradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GenericGradientMaker);
void CreateGradientDefs() override {
/*!
* Inputs: X1, X2, ..., Xn, dY1, dY2, ..., dYm
* Outputs: dX1, dX2, ..., dXn
* X1, X2, ..., Xn, dY1, dY2, ..., dYm
* dX1, dX2, ..., dXn
*/
GRADIENT_MAKER_CTOR(GenericGradientMaker);
vector<OperatorDef> MakeDef() override {
vector<string> inputs, outputs;
for (const auto& input : def.input())
inputs.push_back(input);
for (int i = 0; i < def.output_size(); ++i)
inputs.push_back(GO(i));
for (int i = 0; i < def.input_size(); ++i)
outputs.push_back(GI(i));
return SingleDef(def.type() + "Gradient", "", inputs, outputs);
vector<string> inputs({def().input().begin(), def().input().end()});
vector<string> outputs;
for (int i = 0; i < def().output_size(); ++i) {
inputs.emplace_back(GO(i));
}
for (int i = 0; i < def().input_size(); ++i) {
outputs.emplace_back(GI(i));
}
AddGradientDef(def().type() + "Gradient", "", inputs, outputs);
}
};
class SimpleGradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(SimpleGradientMaker);
void CreateGradientDefs() override {
/*!
* Inputs: dY1, dY2, ..., dYm
* Outputs: dX1, dX2, ..., dXn
* dY1, dY2, ..., dYm
* dX1, dX2, ..., dXn
*/
GRADIENT_MAKER_CTOR(SimpleGradientMaker);
vector<OperatorDef> MakeDef() override {
vector<string> inputs, outputs;
for (int i = 0; i < def.output_size(); ++i)
inputs.push_back(GO(i));
for (int i = 0; i < def.input_size(); ++i)
outputs.push_back(GI(i));
return SingleDef(def.type() + "Gradient", "", inputs, outputs);
for (int i = 0; i < def().output_size(); ++i) {
inputs.emplace_back(GO(i));
}
for (int i = 0; i < def().input_size(); ++i) {
outputs.emplace_back(GI(i));
}
AddGradientDef(def().type() + "Gradient", "", inputs, outputs);
}
};
class InplaceGradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(InplaceGradientMaker);
void CreateGradientDefs() override {
/*!
* Inputs: Y, dY
* Outputs: dX
* Y, dY
* dX
*/
GRADIENT_MAKER_CTOR(InplaceGradientMaker);
vector<OperatorDef> MakeDef() override {
return SingleDef(
def.type() + "Gradient",
AddGradientDef(
def().type() + "Gradient",
"",
vector<string>({O(0), GO(0)}),
vector<string>({GI(0)}));
......@@ -179,6 +184,39 @@ class InplaceGradientMaker final : public GradientMakerBase {
} // namespace
class DRAGON_API GradientTape {
public:
GradientTape() {}
GradientTape(const GraphDef& def) : def_(def) {}
GradientTape(
const vector<OperatorDef*>& op_defs,
const vector<string>& targets,
const vector<string>& grad_targets) {
CreateGradientDefs(op_defs, targets, grad_targets);
}
/*! \brief Create gradient defs */
void CreateGradientDefs(
const vector<OperatorDef*>& op_defs,
const vector<string>& targets,
const vector<string>& grad_targets);
/*! \brief Optimize gradient computations */
void Optimize(const vector<string>& sources = vector<string>());
/*! \brief Return gradient defs */
const GraphDef& def() {
if (optimized_def_.op_size() > 0) {
return optimized_def_;
}
return def_;
}
private:
GraphDef def_;
GraphDef optimized_def_;
};
DECLARE_REGISTRY(
GradientRegistry,
GradientMakerBase,
......@@ -191,7 +229,6 @@ DECLARE_REGISTRY(
const OperatorDef&,
const vector<string>&);
// Defined in the operator.cc
#define REGISTER_GRADIENT(name, ...) \
REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
......@@ -201,4 +238,4 @@ DECLARE_REGISTRY(
} // namespace dragon
#endif // DRAGON_CORE_OPERATOR_GRADIENT_H_
#endif // DRAGON_CORE_GRADIENT_H_
#include <regex>
#include "dragon/core/graph.h"
#include "dragon/core/graph_gradient.h"
#include "dragon/core/graph_optimizer.h"
#include "dragon/core/workspace.h"
......@@ -9,99 +8,78 @@ namespace dragon {
GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
: def_(def), ws_(ws), name_(def.name()), phase_("TEST") {
// Collect arguments
// Collect arguments.
for (auto& arg : def_.arg()) {
CHECK_GT(arg.name().size(), 0);
CHECK_EQ(args_.count(arg.name()), 0);
args_[arg.name()] = &arg;
if (arg.name() == "phase") phase_ = arg.s();
}
// Collect outputs
// Check inputs.
Set<string> outputs;
for (const auto& op : def.op()) {
for (const auto& input : op.input())
CHECK(outputs.count(input) || ws_->HasTensor(input))
<< "\nThe input <" << input << "> is not in graph.";
<< "\nInput " << input << " is not in the graph.";
for (const auto& output : op.output()) {
outputs.insert(output);
}
}
// Check targets
Set<string> targets;
for (const auto& target : def.output()) {
CHECK(outputs.count(target) || ws_->HasTensor(target))
<< "\nThe output <" << target << "> is not in graph.";
targets.insert(target);
}
// Check gradients
for (const auto& grad_info : def.grad_info()) {
const auto& y = grad_info.y();
CHECK_GT(targets.count(y), 0)
<< "\nThe derivative target <" << y << "> is not in outputs.";
for (const auto& x : grad_info.xs()) {
CHECK(outputs.count(x) || ws_->HasTensor(x))
<< "\nThe differentiated input <" << x << "> is not in graph.";
}
// Check outputs.
for (const auto& output : def.output()) {
CHECK(outputs.count(output) || ws_->HasTensor(output))
<< "\nOutput " << output << " is not in the graph.";
}
}
bool Graph::Create(const GraphDef& def) {
this->optimized_def_ = def; // Store for debugging
this->optimized_def_ = def;
bool has_device_option = def.has_device_option();
for (int i = 0; i < def.op_size(); i++) {
auto op_def(def.op(i));
LOG(DEBUG) << "Create Operator " << op_def.name() << ": " << op_def.type();
// Inherit device option if necessary
// Inherit device if not provided.
if (!op_def.has_device_option() && has_device_option) {
op_def.mutable_device_option()->CopyFrom(def.device_option());
}
Argument arg;
// For the last operator, enforce the synchronization
if (i == def.op_size() - 1) {
arg.set_name("do_sync");
arg.set_i(1);
op_def.add_arg()->CopyFrom(arg);
}
cached_ops_.push_back(NewOperator(op_def, ws_));
cached_ops_.back()->set_output_aliases(output_aliases_);
LOG(DEBUG) << "Create: " << op_def.name() << " [" << op_def.type() << "]";
ops_.push_back(OperatorBase::New(op_def, ws_));
ops_.back()->set_output_aliases(output_aliases_);
}
return true;
}
Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
// Apply the optimizations
// Apply the optimizations.
GraphDef def_v2(def);
GraphOptimizer graph_optimizer(ws);
GraphGradientMaker gradient_maker;
GraphOptimizer optimizer(ws);
Map<string, vec32_t> subgraph_indices;
int opt = 3; // default: O3
int opt = 1;
if (args().count("optimization")) opt = arg("optimization").i();
if (opt >= 1) def_v2 = graph_optimizer.EliminateUnused(def);
if (opt >= 2) graph_optimizer.PlanInplace(def_v2, output_aliases_);
if (opt >= 2) optimizer.PlanInplace(def_v2, output_aliases_);
if (opt >= 3) {
if (phase() == "TRAIN") {
def_v2 = graph_optimizer.PlanCheckpoint(def_v2, subgraph_indices);
def_v2 = gradient_maker.Optimize(def_v2);
def_v2 = optimizer.PlanCheckpoint(def_v2, subgraph_indices);
if (args().count("grad_sources")) {
GradientTape tape(def_v2);
auto& grad_sources = args_["grad_sources"]->strings();
tape.Optimize({grad_sources.begin(), grad_sources.end()});
def_v2 = tape.def();
}
} else {
def_v2 = graph_optimizer.SimulateGC(def_v2);
def_v2 = optimizer.EliminateIntermediates(def_v2);
}
}
// Create
// Create graph.
Create(def_v2);
// Recomputation and SubGraph
// Create subgraphs.
if (subgraph_indices.size() > 0) {
Map<string, vector<OperatorBase*>> subgraph;
for (const auto& it : subgraph_indices) {
subgraph[it.first] = vector<OperatorBase*>();
for (auto op_idx : subgraph_indices[it.first])
subgraph[it.first].push_back(cached_ops_[op_idx]);
subgraph[it.first].push_back(ops_[op_idx]);
}
for (auto* op : cached_ops_) {
for (auto* op : ops_) {
op->set_subgraph(subgraph);
}
}
......@@ -111,27 +89,28 @@ bool Graph::Run(int stream, const string& include, const string& exclude) {
unique_ptr<std::regex> regex_incl, regex_excl;
if (!include.empty()) regex_incl.reset(new std::regex(include));
if (!exclude.empty()) regex_excl.reset(new std::regex(exclude));
LOG(DEBUG) << "Run Graph: " << name();
for (auto* op : cached_ops_) {
LOG(DEBUG) << "Run: " << name();
for (auto* op : ops_) {
if (regex_incl && !regex_match(op->type(), *regex_incl)) continue;
if (regex_excl && regex_match(op->type(), *regex_excl)) continue;
op->SwitchToPhase(phase());
LOG(DEBUG) << "Run Op: " << op->name();
LOG(DEBUG) << "Run: " << op->name();
op->Run(stream);
LOG(DEBUG) << "Finish Op: " << op->name();
LOG(DEBUG) << "Finish: " << op->name();
}
LOG(DEBUG) << "Finish: " << name();
return true;
}
GraphBase* NewGraph(const GraphDef& def, Workspace* ws) {
if (!def.has_graph_type() || def.graph_type().empty()) {
return new Graph(def, ws); // Sequential scheduler
GraphBase* GraphBase::New(const GraphDef& def, Workspace* ws) {
if (!def.has_type() || def.type().empty()) {
// Sequential scheduler.
return new Graph(def, ws);
}
return GraphRegistry()->Create(def.graph_type(), def, ws);
return GraphRegistry()->Create(def.type(), def, ws);
}
/* Graph Registry */
DEFINE_REGISTRY(GraphRegistry, GraphBase, const GraphDef&, Workspace*);
} // namespace dragon
......@@ -18,6 +18,8 @@
namespace dragon {
class Workspace;
/*!
* \brief The base graph class.
*/
......@@ -29,6 +31,9 @@ class DRAGON_API GraphBase {
/*! \brief Destructor */
virtual ~GraphBase() {}
/*! \brief Create a new graph */
static GraphBase* New(const GraphDef& def, Workspace* ws);
/*! \brief Create graph in the workspace */
virtual bool Create(const GraphDef& def) = 0;
......@@ -102,8 +107,8 @@ class Graph : public GraphBase {
/*! \brief Destructor */
virtual ~Graph() {
for (auto* cached_op : cached_ops_) {
delete cached_op;
for (auto* op : ops_) {
delete op;
}
}
......@@ -117,16 +122,13 @@ class Graph : public GraphBase {
const string& exclude = "") override;
protected:
/*! \brief The cached operators */
vector<OperatorBase*> cached_ops_;
/*! \brief The created operators */
vector<OperatorBase*> ops_;
/*! \brief The candidate output aliases */
/*! \brief The output aliases */
Map<string, Set<string>> output_aliases_;
};
/*! \brief Create a graph from the raw def */
GraphBase* NewGraph(const GraphDef&, Workspace*);
/* Macros */
DECLARE_REGISTRY(GraphRegistry, GraphBase, const GraphDef&, Workspace*);
......
#include "dragon/core/graph_gradient.h"
#include "dragon/core/operator.h"
namespace dragon {
bool GraphGradientMaker::CheckGrad(
const OperatorDef& op,
const Set<string>& targets,
vector<pair<string, int>>& gen_grads) {
if (NoGradientRegistry()->Has(op.type())) {
return true;
}
bool maybe_skip = false;
for (int i = 0; i < op.output_size(); ++i) {
const auto& out = op.output(i);
if (!inputs_to_grads_.count(out)) {
maybe_skip = true;
if (targets.count(out)) {
gen_grads.push_back({out, i});
inputs_to_grads_[out] = out + "_grad";
}
}
}
return maybe_skip && gen_grads.empty() && op.output_size() == 1;
}
void GraphGradientMaker::Make(
const vector<OperatorDef*>& ops,
const vector<string>& targets,
const vector<string>& input_grads,
GraphDef& graph) {
Set<string> split_grads, targets_v2;
Map<string, int> inputs_count, grads_count;
// PLAY for the forward
for (auto* op : ops) {
if (NoGradientRegistry()->Has(op->type())) continue;
for (const auto& input : op->input()) {
bool input_in_outputs = false;
for (auto& output : op->output())
if (output == input) {
input_in_outputs = true;
break;
}
// Avoid counting the duplicate input (i.e. the in-place output)
if (!input_in_outputs) inputs_count[input]++;
}
}
// Set the gradient of targets
for (int i = 0; i < targets.size(); ++i) {
if (i < input_grads.size()) {
inputs_to_grads_[targets[i]] = input_grads[i];
}
targets_v2.insert(targets[i]);
}
// PLAY for the backward
for (int op_idx = (int)ops.size() - 1; op_idx >= 0; --op_idx) {
const auto& op = *ops[op_idx];
// Generate def by registered gradient maker
vector<pair<string, int>> gen_grads;
vector<string> grad_outputs;
bool is_skip = CheckGrad(op, targets_v2, gen_grads);
for (const auto& out : op.output()) {
string grad_out = "";
const auto& it = inputs_to_grads_.find(out);
if (it != inputs_to_grads_.end()) grad_out = it->second;
grad_outputs.push_back(grad_out);
}
auto pack = MakeGradientForOp(op, grad_outputs);
// Split and gather gradient for multi-used inputs
vector<OperatorDef> gather_ops;
for (auto& grad_def : pack.grad_defs) {
if (!grad_def.has_name()) {
grad_def.set_name(GetOperatorName());
}
for (int i = 0; i < grad_def.output_size(); ++i) {
const auto& grad_name = grad_def.output(i);
int original_index = -1;
for (int j = 0; j < pack.grad_inputs.size(); ++j) {
if (grad_name == pack.grad_inputs[j]) {
original_index = j;
}
}
if (original_index == -1) continue;
bool output_in_inputs = false;
for (const auto& name : grad_def.input()) {
if (grad_name == name) {
output_in_inputs = true;
break;
}
}
if (output_in_inputs) continue;
// Detect a split branch
const auto& original_name = op.input(original_index);
if (inputs_count[original_name] > 1) {
auto grad_name_v2 =
grad_name + "_autosplit_" + str::to(grads_count[grad_name]++);
if (!is_skip) split_grads.insert(grad_name_v2);
if (grads_count[grad_name] == inputs_count[original_name]) {
auto gather_op = MakeOperatorDef(
"GradientGather",
GetOperatorName(),
vector<string>({}),
vector<string>({grad_name}));
if (grad_def.has_device_option()) {
gather_op.mutable_device_option()->CopyFrom(
grad_def.device_option());
}
for (int j = 0; j < grads_count[grad_name]; j++) {
auto name = grad_name + "_autosplit_" + str::to(j);
if (split_grads.count(name)) gather_op.add_input(name);
}
gather_ops.push_back(gather_op);
}
*grad_def.mutable_output(i) = grad_name_v2;
}
}
}
// Add gradient ops
if (!is_skip) {
for (int i = 0; i < op.input_size(); ++i) {
inputs_to_grads_[op.input(i)] = pack.grad_inputs[i];
}
// Add ``GradientGenerateOp``
if (gen_grads.size() > 0) {
vector<string> inputs, outputs;
Argument arg_defaults;
arg_defaults.set_name("defaults");
for (auto& gen_grad : gen_grads) {
inputs.push_back(gen_grad.first);
outputs.emplace_back(gen_grad.first + "_grad");
arg_defaults.add_floats(pack.defaults[gen_grad.second]);
}
auto gen_op = MakeOperatorDef(
"GradientGenerate",
GetOperatorName(),
inputs,
outputs,
vector<Argument>({arg_defaults}));
if (op.has_device_option()) {
gen_op.mutable_device_option()->CopyFrom(op.device_option());
}
graph.add_op()->CopyFrom(gen_op);
}
// Add ``GradientOp``
for (const auto& grad_def : pack.grad_defs) {
graph.add_op()->CopyFrom(grad_def);
}
}
// Add ``GradientGatherOp``
for (const auto& gather_op : gather_ops) {
graph.add_op()->CopyFrom(gather_op);
}
}
}
GraphDef GraphGradientMaker::Optimize(const GraphDef& graph) {
Set<int> invalid_ops;
Map<string, int> ref_count;
Map<string, pair<int, string>> gather_map;
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = graph.op(op_idx);
if (!str::find(op.type(), "Gradient")) continue;
// Flag the gathering gradients
if (op.type() == "GradientGather") {
invalid_ops.insert(op_idx);
if (empty_grads_.count(op.output(0))) {
for (const auto& input : op.input()) {
empty_grads_.insert(input);
}
continue;
} else {
string first_input;
for (const auto& input : op.input()) {
if (!input.empty()) {
if (first_input.empty()) first_input = input;
gather_map[input] = {op_idx, first_input};
}
}
}
}
// Count the references to detect leaves
for (const auto& input : op.input()) {
if (str::endswith(input, "_grad")) {
ref_count[input] += 1;
}
}
}
// Decompose the <GradientGather> into <GradientAdd>
// This trick accumulates each split into the target right after it is computed,
// which helps to reduce the total number of buffers.
auto graph_v2(graph);
graph_v2.clear_op();
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
if (invalid_ops.count(op_idx)) continue;
const auto& op = graph.op(op_idx);
graph_v2.add_op()->CopyFrom(op);
if (!str::find(op.type(), "Gradient")) continue;
for (const auto& output : op.output()) {
const auto& find_iter = gather_map.find(output);
if (find_iter != gather_map.end()) {
const auto& gather_op = graph.op(find_iter->second.first);
auto add_op(gather_op);
add_op.clear_input();
if (output != find_iter->second.second) {
add_op.set_type("GradientAdd");
// Make an in-place to avoid a new buffer
add_op.add_input(gather_op.output(0));
const auto& ref_iter = ref_count.find(gather_op.output(0));
if (ref_iter != ref_count.end()) ref_iter->second++;
}
add_op.add_input(output);
graph_v2.add_op()->CopyFrom(add_op);
}
}
}
// Prepare the pool
int buffer_idx = 0;
std::deque<string> pool;
Map<string, string> grad_to_buffer;
auto get_buffer = [&]() mutable {
if (pool.empty()) {
return "/share/buffer/grad:" + str::to(buffer_idx++);
} else {
/*!
* LIFO is usually more memory-efficient than FIFO,
* because the larger gradients come out later.
*
* The memory distribution turns out to be more uniform
* if the early temporary tensors are selected first.
*/
auto buffer = pool.back();
pool.pop_back();
return buffer;
}
};
for (int op_idx = 0; op_idx < graph_v2.op_size(); ++op_idx) {
auto* op = graph_v2.mutable_op(op_idx);
// Ignore the non-gradient ops
if (!str::find(op->type(), "Gradient")) continue;
// Check if output is an alias of input
vec32_t inplace_flags;
for (int i = 0; i < op->output_size(); ++i) {
int flag = -1;
for (int j = 0; j < op->input_size(); ++j)
if (op->output(i) == op->input(j)) {
flag = j;
break;
}
inplace_flags.emplace_back(flag);
}
// Besides, we need to collect the dead buffers
// and reuse them once the current operator is done
vector<string> dead_buffers;
// Rewrite input gradients
for (int i = 0; i < op->input_size(); ++i) {
const string& in = op->input(i);
if (ref_count.count(in) > 0) {
ref_count[in] -= 1; // Decref
if (grad_to_buffer.count(in) == 0) continue;
string in_v2 = grad_to_buffer[in];
if (ref_count[in] == 0) {
dead_buffers.emplace_back(in_v2);
}
*op->mutable_input(i) = in_v2;
}
}
// Rewrite output gradients
for (int i = 0; i < op->output_size(); ++i) {
if (str::startswith(op->type(), "Python")) continue;
const string& out = op->output(i);
if (out.empty() || str::startswith(out, "/share/buffer")) continue;
if (empty_grads_.count(out) > 0) {
*op->mutable_output(i) = "";
continue;
}
// Protection for leaves
if (ref_count.count(out) == 0) continue;
// Protection for sources and leaves
if (retained_grads_.count(out) > 0) continue;
string out_v2 = out;
if (inplace_flags[i] >= 0) {
out_v2 = op->input(inplace_flags[i]);
} else {
grad_to_buffer[out] = out_v2 = get_buffer();
}
*op->mutable_output(i) = out_v2;
}
// Update the pool
for (auto& buffer : dead_buffers) {
pool.emplace_back(buffer);
}
}
return graph_v2;
}
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_CORE_GRAPH_GRADIENT_H_
#define DRAGON_CORE_GRAPH_GRADIENT_H_
#include "dragon/core/common.h"
namespace dragon {
class DRAGON_API GraphGradientMaker {
public:
/*! \brief Generate graph from the executed ops */
void Make(
const vector<OperatorDef*>& ops,
const vector<string>& targets,
const vector<string>& input_grads,
GraphDef& graph);
/*! \brief Eliminate the unused ops and share the output buffers */
GraphDef Optimize(const GraphDef& graph);
/*! \brief Add an empty gradient */
void add_empty_grad(const string& name) {
empty_grads_.insert(name);
}
/*! \brief Add a retained gradient */
void add_retained_grad(const string& name) {
retained_grads_.insert(name);
}
/*! \brief Set the prefix of backward op name */
void set_op_prefix(const string& prefix) {
op_prefix_ = prefix;
}
private:
/*! \brief Check the missing grads */
bool CheckGrad(
const OperatorDef& op,
const Set<string>& targets,
vector<pair<string, int>>& gen_grads);
/*! \brief Return a dummy operator name */
string GetOperatorName() {
if (op_prefix_.empty()) return "GradientOp";
return op_prefix_ + str::to(op_idx_++);
}
/*! \brief The mapping from input to grad */
Map<string, string> inputs_to_grads_;
/*! \brief The gradients should be retained */
Set<string> retained_grads_;
/*! \brief The gradients should be set to empty */
Set<string> empty_grads_;
/*! \brief The prefix of op name */
string op_prefix_;
/*! \brief The counter of op name */
int64_t op_idx_ = 0;
};
} // namespace dragon
#endif // DRAGON_CORE_GRAPH_GRADIENT_H_
#include "dragon/core/graph_optimizer.h"
#include "dragon/core/graph_gradient.h"
#include "dragon/core/operator_schema.h"
#include "dragon/core/workspace.h"
......@@ -9,11 +8,11 @@ namespace dragon {
void GraphOptimizer::BuildDAG(const GraphDef& graph) {
nodes_.clear();
reference_count_.clear();
inputs_count_.clear();
for (int i = 0; i < graph.op_size(); ++i) {
const auto& op = graph.op(i);
for (const auto& in : op.input()) {
reference_count_[in] += 1;
inputs_count_[in] += 1;
}
for (const auto& out : op.output()) {
if (out.empty()) continue;
......@@ -32,95 +31,24 @@ void GraphOptimizer::BuildDAG(const GraphDef& graph) {
}
}
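// For reference, the counting above is just a consumer histogram; names that
// end up consumed exactly once are the in-place candidates that PlanInplace
// inspects below. A standalone sketch, where "toy_single_consumer_names" is
// an illustrative helper and not part of this class:
namespace {
Set<string> toy_single_consumer_names(const GraphDef& graph) {
  Map<string, int> counts;
  for (const auto& op : graph.op()) {
    for (const auto& input : op.input()) counts[input] += 1;
  }
  Set<string> names;
  for (const auto& iter : counts) {
    if (iter.second == 1 && !iter.first.empty()) names.insert(iter.first);
  }
  return names;
}
}  // namespace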
GraphDef GraphOptimizer::EliminateUnused(const GraphDef& graph) {
// Initialization
BuildDAG(graph);
used_.clear();
// Eliminate the unused nodes
for (const auto& out : graph.output()) {
EliminateUnusedNode(out);
}
for (const auto& grad_info : graph.grad_info()) {
const auto grad_y = grad_info.y() + "_grad";
for (const auto& x : grad_info.xs()) {
visited_.clear();
EliminateUnusedNode(grad_y, x + "_grad");
}
}
// Select the used operators
set<int> selected_op_indices;
for (auto it : used_) {
if (nodes_[it.first].op_idx == -1) continue;
selected_op_indices.insert(nodes_[it.first].op_idx);
}
// Collect the tensors that are already registered in the workspace
Set<string> outputs;
for (const auto& name : ws_->tensors()) {
outputs.insert(name);
}
// Rewrite graph
GraphDef graph_v2(graph);
graph_v2.clear_op();
for (auto op_idx : selected_op_indices) {
const auto& op = graph.op(op_idx);
auto* op_v2 = graph_v2.add_op();
op_v2->CopyFrom(op);
// Rewrite inputs
for (int i = 0; i < op.input_size(); ++i) {
const auto& in = op.input(i);
if (!used_[in] || outputs.count(in) == 0) {
*op_v2->mutable_input(i) = "";
}
}
// Rewrite outputs
for (int i = 0; i < op.output_size(); ++i) {
const auto& out = op.output(i);
if (!used_[out]) {
*op_v2->mutable_output(i) = "";
} else {
outputs.insert(out);
}
}
// Rewrite the hand-crafted cases
if (op.type() == "AffineGradient") {
if (op_v2->output(1).empty()) *op_v2->mutable_input(0) = "";
} else if (op.type() == "MulGradient") {
if (op_v2->output(0).empty()) *op_v2->mutable_input(1) = "";
if (op_v2->output(1).empty()) *op_v2->mutable_input(0) = "";
} else if (op.type() == "DivGradient") {
if (op_v2->output(1).empty()) {
*op_v2->mutable_input(0) = "";
if (op_v2->output(0).empty()) *op_v2->mutable_input(1) = "";
}
}
}
return graph_v2;
}
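// A condensed sketch of the selection step above: an op survives only if one
// of its outputs was marked as used, and the survivors are copied into a
// fresh GraphDef in their original order. "toy_prune_unused" is an
// illustrative helper, not part of this class.
namespace {
GraphDef toy_prune_unused(const GraphDef& graph, const Set<string>& used) {
  GraphDef pruned(graph);
  pruned.clear_op();
  for (const auto& op : graph.op()) {
    bool keep = false;
    for (const auto& output : op.output()) keep |= (used.count(output) > 0);
    if (keep) pruned.add_op()->CopyFrom(op);
  }
  return pruned;
}
}  // namespace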
void GraphOptimizer::PlanInplace(
const GraphDef& graph,
Map<string, Set<string>>& output_aliases) {
// Initialization
// Initialization.
BuildDAG(graph);
// Generate aliases map to apply in-place
for (const auto& iter : reference_count_) {
const auto& in = iter.first;
if (iter.second == 1 && !in.empty() && nodes_[in].childs.size() > 0) {
const auto& op = nodes_[nodes_[in].childs[0]].op_def;
// Generate aliases map to apply in-place.
for (const auto& iter : inputs_count_) {
if (iter.second > 1 || iter.first.empty()) continue;
const auto& input = iter.first;
const auto& input_node = nodes_[input];
if (input_node.childs.empty() || input_node.parents.empty()) continue;
const auto& op = nodes_[input_node.childs[0]].op_def;
const auto* schema = OpSchemaRegistry::Schema(op.type());
for (int i = 0; i < op.input_size(); ++i) {
if (op.input(i) == in) {
if (op.input(i) != input) continue;
for (int j = 0; j < op.output_size(); ++j) {
if (schema->CheckInplace(i, j)) {
output_aliases[op.output(j)].insert(in);
}
}
}
if (!schema->CheckInplace(i, j)) continue;
output_aliases[op.output(j)].insert(input);
}
}
}
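// A reduced, per-op sketch of the aliasing rule above: a value consumed by
// exactly one op may be aliased by any output that the op schema declares as
// in-place capable with the consuming input slot. "toy_collect_aliases" is an
// illustrative helper, not part of this class.
namespace {
void toy_collect_aliases(
    const OperatorDef& op,
    const string& input,
    Map<string, Set<string>>& output_aliases) {
  const auto* schema = OpSchemaRegistry::Schema(op.type());
  for (int i = 0; i < op.input_size(); ++i) {
    if (op.input(i) != input) continue;
    for (int j = 0; j < op.output_size(); ++j) {
      if (schema->CheckInplace(i, j)) {
        output_aliases[op.output(j)].insert(input);
      }
    }
  }
}
}  // namespace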
......@@ -134,7 +62,7 @@ GraphDef GraphOptimizer::PlanCheckpoint(
Map<string, string> rename_map;
Map<string, int> versions;
// Check the mirror stage setting
// Check the mirror stage setting.
for (const auto& op : graph.op()) {
if (str::find(op.type(), "Gradient")) continue;
bool mirror_stage = false;
......@@ -144,12 +72,12 @@ GraphDef GraphOptimizer::PlanCheckpoint(
}
}
if (mirror_stage) {
// We only assume X(0) can be recomputed
// Assume that only X(0) can be recomputed.
rename_map[op.input(0)] = "placeholder";
}
}
// Allocate the temporal buffers
// Allocate the temporal buffers.
string v2_name, version_name;
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = graph.op(op_idx);
......@@ -173,7 +101,7 @@ GraphDef GraphOptimizer::PlanCheckpoint(
continue;
}
for (int j = 0; j < GRAPH_TEMPORAL_OUTPUT_MAX_SIZE; ++j) {
v2_name = "/share/buffer/symbol:" + str::to(j);
v2_name = "shared/buffer/output:" + str::to(j);
for (const auto& buffer : used_buffers)
if (str::find(buffer, v2_name)) {
v2_name.clear();
......@@ -221,18 +149,18 @@ GraphDef GraphOptimizer::PlanCheckpoint(
return graph_v2;
}
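// A sketch of the "mirror_stage" test driving the rename above, assuming the
// flag is stored as an integer-valued Argument on the OperatorDef (the proto
// accessors "arg", "name" and "i" are assumptions here); "toy_has_mirror_stage"
// is an illustrative helper, not part of this class.
namespace {
bool toy_has_mirror_stage(const OperatorDef& op) {
  for (const auto& arg : op.arg()) {
    if (arg.name() == "mirror_stage" && arg.i() != 0) return true;
  }
  return false;
}
}  // namespace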
GraphDef GraphOptimizer::SimulateGC(const GraphDef& graph) {
Set<string> blacklist = {""};
Map<string, int> ref_count;
Map<string, string> rename_map;
static Set<string> star_ops = {"Shape"};
GraphDef GraphOptimizer::EliminateIntermediates(const GraphDef& graph) {
Set<string> required_outputs;
Map<string, int> inputs_count;
Map<string, string> outputs_to_buffers;
static Set<string> skip_ops = {"Shape"};
// Prepare the pool
// Prepare the buffer pool.
int buffer_idx = 0;
std::deque<string> pool;
auto get_buffer = [&]() mutable {
if (pool.empty()) {
return "/share/buffer/output:" + str::to(buffer_idx++);
return "shared/buffer/output:" + str::to(++buffer_idx);
} else {
auto buffer = pool.back();
pool.pop_back();
......@@ -240,56 +168,58 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& graph) {
}
};
// Count the references
// Count inputs.
for (const auto& op : graph.op()) {
for (const auto& in : op.input()) {
ref_count[in] += 1;
for (const auto& input : op.input()) {
inputs_count[input] += 1;
}
}
// Preserve the graph outputs
for (auto& out : graph.output()) {
blacklist.insert(out);
// Initialize the required outputs before optimization.
for (const auto& output : graph.output()) {
required_outputs.insert(output);
}
// Rewrite the inputs and outputs
// Rewrite the inputs and outputs.
auto graph_v2(graph);
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = graph.op(op_idx);
auto* op_v2 = graph_v2.mutable_op(op_idx);
// Ignore the init ops
if (op.input_size() == 0) continue;
// We need to collect the dead buffers.
// Reuse them when current operator is done.
vector<string> dead_buffers;
// Rewrite inputs
for (int i = 0; i < op.input_size(); ++i) {
const auto& name = op.input(i);
if (rename_map.count(name)) {
*op_v2->mutable_input(i) = rename_map[name];
}
ref_count[name]--;
if (ref_count[name] == 0 &&
str::startswith(op_v2->input(i), "/share/buffer/output:")) {
dead_buffers.push_back(op_v2->input(i));
auto* op_v2 = graph_v2.mutable_op(op_idx);
// Check output aliases.
vec32_t output_aliases(op.output_size(), -1);
for (int i = 0; i < op.output_size(); ++i) {
for (int j = 0; j < op.input_size(); ++j) {
if (op.output(i) != op.input(j)) continue;
output_aliases[i] = j;
break;
}
}
// Rewrite outputs
if (!star_ops.count(op.type())) {
// Rewrite inputs.
vector<string> dead_buffers;
for (int i = 0; i < op.input_size(); ++i) {
const auto& input = op.input(i);
const auto& count_iter = inputs_count.find(input);
count_iter->second--;
const auto& buffer_iter = outputs_to_buffers.find(input);
if (buffer_iter == outputs_to_buffers.end()) continue;
if (count_iter->second == 0) {
dead_buffers.emplace_back(buffer_iter->second);
}
op_v2->set_input(i, buffer_iter->second);
}
if (skip_ops.count(op.type())) continue;
// Rewrite outputs.
for (int i = 0; i < op.output_size(); ++i) {
const auto& name = op.output(i);
bool inplace_flag = false;
if (blacklist.count(name)) continue;
for (const auto& input : op.input())
if (name == input) inplace_flag = true;
if (inplace_flag) {
*op_v2->mutable_output(i) = op_v2->input(i);
const auto& output = op.output(i);
if (output.empty() || required_outputs.count(output) > 0) continue;
if (output_aliases[i] >= 0) {
op_v2->set_output(i, op_v2->input(output_aliases[i]));
} else {
rename_map[name] = *op_v2->mutable_output(i) = get_buffer();
*op_v2->mutable_output(i) = outputs_to_buffers[output] = get_buffer();
}
}
}
// Update the pool
// Update pool.
for (auto& buffer : dead_buffers) {
pool.emplace_back(buffer);
}
......@@ -297,36 +227,4 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& graph) {
return graph_v2;
}
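// A small sketch of the alias detection used above: an output that carries
// the same name as one of the op's inputs is treated as in-place and is later
// forced onto that input's (possibly rewritten) buffer. "toy_alias_of" is an
// illustrative helper, not part of this class.
namespace {
int toy_alias_of(const OperatorDef& op, int output_index) {
  for (int j = 0; j < op.input_size(); ++j) {
    if (op.output(output_index) == op.input(j)) return j;
  }
  return -1;  // Not aliased with any input.
}
}  // namespace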
void GraphOptimizer::EliminateUnusedNode(
const string& source,
const string& sink) {
if (visited_.count(source)) return;
visited_[source] = false;
for (const auto& next : nodes_[source].childs) {
if (next == sink) {
visited_[next] = used_[next] = true;
visited_[source] = used_[source] = true;
return;
}
EliminateUnusedNode(next, sink);
if (visited_[next]) {
visited_[source] = used_[source] = true;
}
}
}
void GraphOptimizer::EliminateUnusedNode(const string& sink) {
std::queue<const string*> q;
q.push(&sink);
while (!q.empty()) {
const auto& source = *q.front();
q.pop();
used_[source] = true;
for (const auto& last : nodes_[source].parents) {
if (used_.count(last)) continue;
q.push(&last);
}
}
}
} // namespace dragon