Commit 02ad90d5 by Ting PAN

Remove the duplicate workspace singletons

Summary:
This commit moves the workspace API onto the current workspace instance.
Accordingly, the ``dragon.workspace`` namespace is removed for simplicity.
1 parent adb6fa64
Showing with 1163 additions and 1799 deletions
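For readers skimming the diff, here is a minimal before/after sketch of the calling convention this commit changes. The names are taken from the documentation pages updated below (``dragon.Workspace``, ``as_default``, ``feed_tensor``, ``fetch_tensor``); treat it as an illustration of the intent rather than verified example code.

```python
import numpy as np
import dragon

# Before this commit (module-level singleton API, now removed):
#   dragon.workspace.feed_tensor('x', np.ones((2, 3), 'float32'))
#   value = dragon.workspace.fetch_tensor('x')

# After this commit, the same operations live on a workspace instance:
ws = dragon.Workspace()
with ws.as_default():
    ws.feed_tensor('x', np.ones((2, 3), 'float32'))
    value = ws.fetch_tensor('x')
```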
<p align="center"> <p align="center">
<img width="40%" src="http://dragon.seetatech.com/static/images/styles-dragon.png"/> <img width="40%" src="https://dragon.seetatech.com/static/images/styles-dragon.png"/>
</p> </p>
[Dragon](http://dragon.seetatech.com) is a **C**(Computation)**G**(Graph)**V**(Virtual)**M**(Machine) based distributed deep learning framework. [Dragon](https://dragon.seetatech.com) is a **C**(Computation)**G**(Graph)**V**(Virtual)**M**(Machine) based distributed deep learning framework.
It fuses several modern frameworks and integrations together, powered by a unified engine. It fuses several modern frameworks and integrations together, powered by a unified engine.
The computation between different programming styles is deterministic and reproducible. The computation between different programming styles is deterministic and reproducible.
...@@ -11,7 +11,7 @@ promoting internal interfaces. We will always learn from the AI community to evo ...@@ -11,7 +11,7 @@ promoting internal interfaces. We will always learn from the AI community to evo
## Installation ## Installation
See the [install guide](http://dragon.seetatech.com/install) for the pip package See the [install guide](https://dragon.seetatech.com/install) for the pip package
or how to build from source. or how to build from source.
## License ## License
......
...@@ -15,9 +15,13 @@ from __future__ import absolute_import ...@@ -15,9 +15,13 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy
from dragon.core.autograph.tensor import TensorRef from dragon.core.autograph.tensor import TensorRef
from dragon.core.eager import context as eager_context from dragon.core.eager import context as eager_context
from dragon.core.framework import context from dragon.core.framework import context
from dragon.core.util import logging
from dragon.vm.caffe.proto import caffe_pb2
class Layer(object): class Layer(object):
...@@ -34,24 +38,26 @@ class Layer(object): ...@@ -34,24 +38,26 @@ class Layer(object):
""" """
self._proto = layer_param self._proto = layer_param
self._name = layer_param.name self._name = layer_param.name
self._arguments, self.arguments = {'name': self._name}, {} self._arguments, self.arguments = {'name': 'output'}, {}
# Store the inputs, outputs and trainable parameters. # Store the inputs, outputs and trainable parameters.
self._bottom, self._top, self._blobs = [], [], [] self._bottom, self._top, self._blobs = [], [], []
for blob in layer_param.bottom: for blob in layer_param.bottom:
self._bottom.append(blob) self._bottom.append(blob)
for blob in layer_param.top: for blob in layer_param.top:
self._top.append(blob) self._top.append(blob)
# Store the loss weight to apply gradients. # Store the loss weight to apply gradients.
self._loss_weight = layer_param.loss_weight \ self._loss_weight = layer_param.loss_weight \
if len(layer_param.loss_weight) > 0 else None if len(layer_param.loss_weight) > 0 else None
# Optional mirror stage argument for memory optimization. # Optional mirror stage argument for memory optimization.
if layer_param.HasField('mirror_stage'): if layer_param.HasField('mirror_stage'):
self._arguments['mirror_stage'] = layer_param.mirror_stage self._arguments['mirror_stage'] = layer_param.mirror_stage
@property @property
def blobs(self):
"""Return the blobs."""
return self._blobs
@property
def bottom(self): def bottom(self):
"""Return the bottom names.""" """Return the bottom names."""
return self._bottom return self._bottom
...@@ -62,49 +68,91 @@ class Layer(object): ...@@ -62,49 +68,91 @@ class Layer(object):
return self._loss_weight return self._loss_weight
@property @property
def name(self):
"""Return the layer name."""
return self._name
@property
def top(self): def top(self):
"""Return the top names.""" """Return the top names."""
return self._top return self._top
def add_blob(self, value=None, filler=None, no_grad=False): def add_blob(self, value=None, filler=None, no_grad=False):
"""Add a weight blob into this layer.""" """Add a blob into this layer."""
# Use a fixed name in the current workspace. # Set the name for reference explicitly.
# Note that a non-empty tensor scope will make it data_name = context.get_name_scope() + 'param:{}'.format(len(self._blobs))
# impossible to load/save models. You should use data, diff = TensorRef(data_name), TensorRef(data_name + '_grad')
# a new workspace instead of the terrible name scope.
scoped_name = context.get_name_scope() + self._name
param_name = scoped_name + '/param:{}'.format(len(self._blobs))
# Set the name explicitly.
variable = TensorRef(param_name)
variable_grad = TensorRef(param_name + '_grad')
if filler is not None: if filler is not None:
variable._register_as(**filler) data._register_as(**filler)
else: else:
# Register a constant filler by default. # Register a constant filler by default.
value = value if value else 0 value = value if value else 0
variable.constant(value=value) data.constant(value=value)
# Append to the blobs.
self._blobs.append({'data': data, 'diff': None if no_grad else diff})
# Determine whether to disable the gradients explicitly. def from_proto(self, proto):
if no_grad is True: """Deserialize from the proto.
variable_grad = None
# Append to the blobs. Parameters
self._blobs.append({'data': variable, 'diff': variable_grad}) ----------
proto : LayerParameter
The ``LayerParameter`` protocol buffer.
"""
for i in range(len(self._blobs)):
if i < len(proto.blobs):
blob_proto = proto.blobs[i]
if len(blob_proto.data) > 0:
value = numpy.array(blob_proto.data, dtype='float32')
elif len(blob_proto.double_data) > 0:
value = numpy.array(blob_proto.double_data, dtype='float64')
else:
raise ValueError('Neither <data> nor <double_data> in blob proto.')
if len(blob_proto.shape.dim) > 0:
value = value.reshape([dim for dim in blob_proto.shape.dim])
self._blobs[i]['data'].set_value(value)
logging.info('Blob({}/param:{}) loaded, shape: {}, size: {}'
.format(self._name, i, value.shape, value.size))
def setup(self, bottom): def setup(self, bottom):
# Merge the arguments, then setup up the specific layer. """Setup the layer."""
self.arguments = dict(self.arguments, **self._arguments) self.arguments = dict(self.arguments, **self._arguments)
bottom = bottom[0] if len(bottom) == 1 else bottom bottom = bottom[0] if len(bottom) == 1 else bottom
with eager_context.graph_mode(): with eager_context.graph_mode():
return self.__call__(bottom) return self.__call__(bottom)
@classmethod def to_proto(self):
def get_filler(cls, layer_param, filler_name): """Serialize to the proto.
"""Construct a filler from the parameter."""
if layer_param.HasField(filler_name): Returns
filler = getattr(layer_param, filler_name) -------
LayerParameter
The ``LayerParameter`` protocol buffer.
"""
proto = caffe_pb2.LayerParameter()
proto.CopyFrom(self._proto)
for blob in self._blobs:
value = blob['data'].get_value()
if str(value.dtype) == 'float32':
blob_proto = caffe_pb2.BlobProto(
data=value.flatten(),
shape=caffe_pb2.BlobShape(dim=value.shape))
elif str(value.dtype) == 'float64':
blob_proto = caffe_pb2.BlobProto(
double_data=value.flatten(),
shape=caffe_pb2.BlobShape(dim=value.shape))
else:
raise ValueError('Either float32 or float64 blob is required.')
proto.blobs.extend([blob_proto])
return proto
@staticmethod
def get_filler(proto, filler_name):
"""Return the filler from proto."""
if proto.HasField(filler_name):
filler = getattr(proto, filler_name)
return { return {
'type': filler.type.lower(), 'type': filler.type.lower(),
'value': filler.value, 'value': filler.value,
......
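The new ``Layer.to_proto`` above dispatches on the blob dtype when packing values into a ``BlobProto``. A standalone sketch of that mapping, lifted from the logic in the diff (the helper name ``blob_to_proto`` is illustrative and not part of the commit):

```python
import numpy
from dragon.vm.caffe.proto import caffe_pb2

def blob_to_proto(value):
    """Pack a numpy array into a BlobProto, mirroring Layer.to_proto."""
    if str(value.dtype) == 'float32':
        # Single-precision values go into the <data> field.
        return caffe_pb2.BlobProto(
            data=value.flatten(),
            shape=caffe_pb2.BlobShape(dim=value.shape))
    elif str(value.dtype) == 'float64':
        # Double-precision values go into the <double_data> field.
        return caffe_pb2.BlobProto(
            double_data=value.flatten(),
            shape=caffe_pb2.BlobShape(dim=value.shape))
    raise ValueError('Either float32 or float64 blob is required.')

proto = blob_to_proto(numpy.zeros((4, 4), dtype='float32'))
```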
...@@ -16,14 +16,10 @@ from __future__ import print_function ...@@ -16,14 +16,10 @@ from __future__ import print_function
from dragon.vm.caffe.layers.common import Accuracy from dragon.vm.caffe.layers.common import Accuracy
from dragon.vm.caffe.layers.common import ArgMax from dragon.vm.caffe.layers.common import ArgMax
from dragon.vm.caffe.layers.common import BatchNorm from dragon.vm.caffe.layers.common import BatchNorm
from dragon.vm.caffe.layers.common import Cast
from dragon.vm.caffe.layers.common import Concat from dragon.vm.caffe.layers.common import Concat
from dragon.vm.caffe.layers.common import Crop from dragon.vm.caffe.layers.common import Crop
from dragon.vm.caffe.layers.common import Eltwise from dragon.vm.caffe.layers.common import Eltwise
from dragon.vm.caffe.layers.common import Flatten from dragon.vm.caffe.layers.common import Flatten
from dragon.vm.caffe.layers.common import FusedBatchNorm
from dragon.vm.caffe.layers.common import FusedGroupNorm
from dragon.vm.caffe.layers.common import GroupNorm
from dragon.vm.caffe.layers.common import InnerProduct from dragon.vm.caffe.layers.common import InnerProduct
from dragon.vm.caffe.layers.common import Input from dragon.vm.caffe.layers.common import Input
from dragon.vm.caffe.layers.common import Normalize from dragon.vm.caffe.layers.common import Normalize
...@@ -46,12 +42,10 @@ from dragon.vm.caffe.layers.neuron import ELU ...@@ -46,12 +42,10 @@ from dragon.vm.caffe.layers.neuron import ELU
from dragon.vm.caffe.layers.neuron import Power from dragon.vm.caffe.layers.neuron import Power
from dragon.vm.caffe.layers.neuron import PReLU from dragon.vm.caffe.layers.neuron import PReLU
from dragon.vm.caffe.layers.neuron import ReLU from dragon.vm.caffe.layers.neuron import ReLU
from dragon.vm.caffe.layers.neuron import SELU
from dragon.vm.caffe.layers.neuron import Sigmoid from dragon.vm.caffe.layers.neuron import Sigmoid
from dragon.vm.caffe.layers.neuron import TanH from dragon.vm.caffe.layers.neuron import TanH
from dragon.vm.caffe.layers.vision import Convolution from dragon.vm.caffe.layers.vision import Convolution
from dragon.vm.caffe.layers.vision import Deconvolution from dragon.vm.caffe.layers.vision import Deconvolution
from dragon.vm.caffe.layers.vision import DepthwiseConv2d
from dragon.vm.caffe.layers.vision import LRN from dragon.vm.caffe.layers.vision import LRN
from dragon.vm.caffe.layers.vision import Pooling from dragon.vm.caffe.layers.vision import Pooling
from dragon.vm.caffe.layers.vision import ROIAlign from dragon.vm.caffe.layers.vision import ROIAlign
......
...@@ -33,8 +33,9 @@ class _DataPlugin(object): ...@@ -33,8 +33,9 @@ class _DataPlugin(object):
def forward(self, inputs, outputs): def forward(self, inputs, outputs):
blobs = self.iterator.next() blobs = self.iterator.next()
current_ws = workspace.get_workspace()
for i, blob in enumerate(blobs): for i, blob in enumerate(blobs):
workspace.feed_tensor(outputs[i], blob) current_ws.feed_tensor(outputs[i], blob)
class Data(Layer): class Data(Layer):
...@@ -44,42 +45,46 @@ class Data(Layer): ...@@ -44,42 +45,46 @@ class Data(Layer):
```python ```python
layer { layer {
type: "Data" type: "Data"
top: "data" top: "data"
top: "label" top: "label"
include { phase: TRAIN } include {
data_param { phase: TRAIN
source: "/data/imagenet/train" }
batch_size: 128 data_param {
shuffle: true source: "/data/train"
num_chunks: 0 batch_size: 128
prefetch: 5 shuffle: true
} num_chunks: 0
transform_param { prefetch: 5
mirror: true }
random_crop_size: 224 transform_param {
augment_color: true mirror: true
mean_value: 104.00698793 random_crop_size: 224
mean_value: 116.66876762 augment_color: true
mean_value: 122.67891434 mean_value: 104.00698793
} mean_value: 116.66876762
mean_value: 122.67891434
}
} }
layer { layer {
type: "Data" type: "Data"
top: "data" top: "data"
top: "label" top: "label"
include { phase: TEST } include {
data_param { phase: TEST
source: "/data/imagenet/val" }
batch_size: 100 data_param {
} source: "/data/val"
transform_param { batch_size: 64
resize: 256 }
crop_size: 224 transform_param {
mean_value: 104.00698793 resize: 256
mean_value: 116.66876762 crop_size: 224
mean_value: 122.67891434 mean_value: 104.00698793
} mean_value: 116.66876762
mean_value: 122.67891434
}
} }
``` ```
......
...@@ -30,13 +30,13 @@ class EuclideanLoss(Layer): ...@@ -30,13 +30,13 @@ class EuclideanLoss(Layer):
```python ```python
layer { layer {
type: "EuclideanLoss" type: "EuclideanLoss"
bottom: "bbox_pred" bottom: "bbox_pred"
bottom: "bbox_target" bottom: "bbox_target"
top: "bbox_loss" top: "bbox_loss"
loss_param { loss_param {
normalization: BATCH_SIZE normalization: BATCH_SIZE
} }
} }
``` ```
...@@ -67,13 +67,13 @@ class SigmoidCrossEntropyLoss(Layer): ...@@ -67,13 +67,13 @@ class SigmoidCrossEntropyLoss(Layer):
```python ```python
layer { layer {
type: "SigmoidCrossEntropyLoss" type: "SigmoidCrossEntropyLoss"
bottom: "rpn_cls_score" bottom: "rpn_cls_score"
bottom: "rpn_labels" bottom: "rpn_labels"
top: "rpn_loss" top: "rpn_loss"
loss_param { loss_param {
normalization: VALID normalization: VALID
} }
} }
``` ```
...@@ -106,15 +106,15 @@ class SmoothL1Loss(Layer): ...@@ -106,15 +106,15 @@ class SmoothL1Loss(Layer):
```python ```python
layer { layer {
type: "SmoothL1Loss" type: "SmoothL1Loss"
bottom: "bbox_pred" bottom: "bbox_pred"
bottom: "bbox_targets" bottom: "bbox_targets"
bottom: "bbox_inside_weights" bottom: "bbox_inside_weights"
bottom: "bbox_outside_weights" bottom: "bbox_outside_weights"
top: "bbox_loss" top: "bbox_loss"
loss_param { loss_param {
normalization: BATCH_SIZE normalization: BATCH_SIZE
} }
} }
``` ```
...@@ -155,15 +155,17 @@ class SoftmaxWithLoss(Layer): ...@@ -155,15 +155,17 @@ class SoftmaxWithLoss(Layer):
```python ```python
layer { layer {
type: "SoftmaxWithLoss" type: "SoftmaxWithLoss"
bottom: "cls_score" bottom: "cls_score"
bottom: "labels" bottom: "labels"
top: "cls_loss" top: "cls_loss"
softmax_param { axis: 1 } softmax_param {
loss_param { axis: 1
ignore_label: -1 }
normalization: VALID loss_param {
} ignore_label: -1
normalization: VALID
}
} }
``` ```
......
...@@ -32,12 +32,12 @@ class Dropout(Layer): ...@@ -32,12 +32,12 @@ class Dropout(Layer):
```python ```python
layer { layer {
type: "Dropout" type: "Dropout"
bottom: "fc6" bottom: "fc6"
top: "fc6" top: "fc6"
dropout_param { dropout_param {
dropout_ratio: 0.5 dropout_ratio: 0.5
} }
} }
``` ```
...@@ -73,12 +73,12 @@ class ELU(Layer): ...@@ -73,12 +73,12 @@ class ELU(Layer):
```python ```python
layer { layer {
type: "ELU" type: "ELU"
bottom: "conv2" bottom: "conv2"
top: "conv2" top: "conv2"
elu_param { elu_param {
alpha: 1. alpha: 1.
} }
} }
``` ```
...@@ -101,14 +101,14 @@ class Power(Layer): ...@@ -101,14 +101,14 @@ class Power(Layer):
```python ```python
layer { layer {
type: "Power" type: "Power"
bottom: "x" bottom: "x"
top: "y" top: "y"
power_param { power_param {
scale: 1. scale: 1.
shift: 0. shift: 0.
power: 2. power: 2.
} }
} }
``` ```
...@@ -148,16 +148,16 @@ class PReLU(Layer): ...@@ -148,16 +148,16 @@ class PReLU(Layer):
```python ```python
layer { layer {
type: "PReLU" type: "PReLU"
bottom: "conv2" bottom: "conv2"
top: "conv2/relu" top: "conv2/relu"
prelu_param { prelu_param {
channel_shared: false channel_shared: false
filler { filler {
type: "constant" type: "constant"
value: 0.25 value: 0.25
}
} }
}
} }
``` ```
...@@ -194,12 +194,12 @@ class ReLU(Layer): ...@@ -194,12 +194,12 @@ class ReLU(Layer):
```python ```python
layer { layer {
type: "ReLU" type: "ReLU"
bottom: "conv2" bottom: "conv2"
top: "conv2/relu" top: "conv2/relu"
relu_param { relu_param {
negative_slope: 0. negative_slope: 0.
} }
} }
``` ```
...@@ -215,38 +215,6 @@ class ReLU(Layer): ...@@ -215,38 +215,6 @@ class ReLU(Layer):
return activation_ops.relu(bottom, **self.arguments) return activation_ops.relu(bottom, **self.arguments)
class SELU(Layer):
r"""Apply the scaled exponential linear unit.
`[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
The **SELU** function is defined as:
.. math::
\text{SELU}(x) = 1.0507 *
\begin{cases}
x, & \text{ if } x \geq 0 \\
1.6733 * (e^{x} - 1), & \text{ otherwise }
\end{cases}
Examples:
```python
layer {
type: "SELU"
bottom: "conv2"
top: "conv2/relu"
}
```
"""
def __init__(self, layer_param):
super(SELU, self).__init__(layer_param)
def __call__(self, bottom):
return activation_ops.selu(bottom, **self.arguments)
class Sigmoid(Layer): class Sigmoid(Layer):
r"""Apply the sigmoid function. r"""Apply the sigmoid function.
...@@ -258,9 +226,9 @@ class Sigmoid(Layer): ...@@ -258,9 +226,9 @@ class Sigmoid(Layer):
```python ```python
layer { layer {
type: "Sigmoid" type: "Sigmoid"
bottom: "rpn_cls_score" bottom: "rpn_cls_score"
top: "rpn_cls_prob" top: "rpn_cls_prob"
} }
``` ```
...@@ -284,9 +252,9 @@ class TanH(Layer): ...@@ -284,9 +252,9 @@ class TanH(Layer):
```python ```python
layer { layer {
type: "TanH" type: "TanH"
bottom: "g/conv5" bottom: "g/conv5"
top: "g/image" top: "g/image"
} }
``` ```
......
...@@ -23,39 +23,29 @@ from dragon.vm.caffe.layer import Layer ...@@ -23,39 +23,29 @@ from dragon.vm.caffe.layer import Layer
class Convolution(Layer): class Convolution(Layer):
r"""Apply the n-dimension convolution. r"""Apply the n-dimension convolution.
The spatial output dimension is computed as:
.. math::
\begin{cases}
\text{DK}_{size} = dilation *
(\text{K}_{size} - 1) + 1 \\
\text{Dim}_{out} = (\text{Dim}_{in} +
2 * pad - \text{DK}_{size}) / stride + 1
\end{cases}
Examples: Examples:
```python ```python
layer { layer {
type: "Convolution" type: "Convolution"
bottom: "input" bottom: "input"
top: "conv1" top: "conv1"
convolution_param { convolution_param {
num_output: 32 num_output: 32
bias_term: true bias_term: true
kernel_size: 3 kernel_size: 3
pad: 1 pad: 1
stride: 1 stride: 1
dilation: 1 dilation: 1
group: 1 group: 1
weight_filler { weight_filler {
type: "xavier" type: "xavier"
}
bias_filler {
type: "constant"
value: 0
}
} }
bias_filler {
type: "constant"
value: 0
}
}
} }
``` ```
...@@ -83,7 +73,6 @@ class Convolution(Layer): ...@@ -83,7 +73,6 @@ class Convolution(Layer):
if param.HasField('pad_h'): if param.HasField('pad_h'):
assert param.HasField('pad_w') assert param.HasField('pad_w')
self.arguments['pads'] = [param.pad_h, param.pad_w] self.arguments['pads'] = [param.pad_h, param.pad_w]
self.add_blob(filler=self.get_filler(param, 'weight_filler')) self.add_blob(filler=self.get_filler(param, 'weight_filler'))
if param.bias_term: if param.bias_term:
self.add_blob(filler=self.get_filler(param, 'bias_filler')) self.add_blob(filler=self.get_filler(param, 'bias_filler'))
...@@ -96,39 +85,29 @@ class Convolution(Layer): ...@@ -96,39 +85,29 @@ class Convolution(Layer):
class Deconvolution(Convolution): class Deconvolution(Convolution):
r"""Apply the 2d deconvolution. r"""Apply the 2d deconvolution.
The spatial output dimension is computed as:
.. math::
\begin{cases}
\text{DK}_{size} = dilation *
(\text{K}_{size} - 1) + 1 \\
\text{Dim}_{out} = (\text{Dim}_{in} - 1) *
stride + \text{DK}_{size} - 2 * pad
\end{cases}
Examples: Examples:
```python ```python
layer { layer {
type: "Deconvolution" type: "Deconvolution"
bottom: "conv5" bottom: "conv5"
top: "conv5/upscale" top: "conv5/upscale"
convolution_param { convolution_param {
num_output: 256 num_output: 256
bias_term: true bias_term: true
kernel_size: 2 kernel_size: 2
pad: 0 pad: 0
stride: 2 stride: 2
dilation: 1 dilation: 1
group: 1 group: 1
weight_filler { weight_filler {
type: "xavier" type: "xavier"
} }
bias_filler { bias_filler {
type: "constant" type: "constant"
value: 0 value: 0
}
} }
}
} }
``` ```
...@@ -142,77 +121,6 @@ class Deconvolution(Convolution): ...@@ -142,77 +121,6 @@ class Deconvolution(Convolution):
return vision_ops.conv2d_transpose(inputs, **self.arguments) return vision_ops.conv2d_transpose(inputs, **self.arguments)
class DepthwiseConv2d(Layer):
r"""Apply the 2d depthwise convolution.
`[Chollet, 2016] <https://arxiv.org/abs/1610.02357>`_.
The spatial output dimension is computed as:
.. math::
\begin{cases}
\text{DK}_{size} = dilation *
(\text{K}_{size} - 1) + 1 \\
\text{Dim}_{out} = (\text{Dim}_{in} +
2 * pad - \text{DK}_{size}) / stride + 1
\end{cases}
Examples:
```python
layer {
type: "DepthwiseConv2d"
bottom: "input"
top: "conv1"
convolution_param {
num_output: 32
bias_term: true
kernel_size: 3
pad: 1
stride: 1
dilation: 1
weight_filler {
type: "xavier"
variance_norm: FAN_OUT
}
bias_filler {
type: "constant"
value: 0
}
}
}
```
"""
def __init__(self, layer_param):
super(DepthwiseConv2d, self).__init__(layer_param)
param = layer_param.convolution_param
self.arguments = {
'out_channels': param.num_output,
'kernel_shape': [int(e) for e in param.kernel_size],
'strides': [int(e) for e in param.stride] if len(param.stride) > 0 else [1],
'pads': [int(e) for e in param.pad] if len(param.pad) > 0 else [0],
'padding': 'VALID',
'data_format': 'NCHW',
}
if param.HasField('kernel_h'):
assert param.HasField('kernel_w')
self.arguments['kernel_shape'] = [param.kernel_h, param.kernel_w]
if param.HasField('stride_h'):
assert param.HasField('stride_w')
self.arguments['strides'] = [param.stride_h, param.stride_w]
if param.HasField('pad_h'):
assert param.HasField('pad_w')
self.arguments['pads'] = [param.pad_h, param.pad_w]
self.add_blob(filler=self.get_filler(param, 'weight_filler'))
if param.bias_term:
self.add_blob(filler=self.get_filler(param, 'bias_filler'))
def __call__(self, bottom):
inputs = [bottom] + [blob['data'] for blob in self._blobs]
return vision_ops.depthwise_conv2d(inputs, **self.arguments)
class LRN(Layer): class LRN(Layer):
r"""Apply the local response normalization. r"""Apply the local response normalization.
`[Krizhevsky et.al, 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_. `[Krizhevsky et.al, 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_.
...@@ -221,15 +129,15 @@ class LRN(Layer): ...@@ -221,15 +129,15 @@ class LRN(Layer):
```python ```python
layer { layer {
type: "LRN" type: "LRN"
bottom: "conv2" bottom: "conv2"
top: "conv2/norm" top: "conv2/norm"
lrn_param { lrn_param {
local_size: 5 local_size: 5
alpha: 1. alpha: 1.
beta: 0.75 beta: 0.75
k: 1. k: 1.
} }
} }
``` ```
...@@ -255,24 +163,18 @@ class LRN(Layer): ...@@ -255,24 +163,18 @@ class LRN(Layer):
class Pooling(Layer): class Pooling(Layer):
r"""Apply the n-dimension pooling. r"""Apply the n-dimension pooling.
The spatial output dimension is computed as:
.. math::
\text{Dim}_{out} = (\text{Dim}_{in} +
2 * pad - \text{K}_{size}) / stride + 1
Examples: Examples:
```python ```python
layer { layer {
type: "Pooling" type: "Pooling"
bottom: "conv2" bottom: "conv2"
top: "pool2" top: "pool2"
pooling_param { pooling_param {
kernel_size: 3 kernel_size: 3
stride: 2 stride: 2
pool: AVG pool: AVG
} }
} }
``` ```
...@@ -311,14 +213,14 @@ class ROIAlign(Layer): ...@@ -311,14 +213,14 @@ class ROIAlign(Layer):
```python ```python
layer { layer {
type: "ROIAlign" type: "ROIAlign"
bottom: "conv5_3" bottom: "conv5_3"
top: "roi_pool4" top: "roi_pool4"
roi_pooling_param { roi_pooling_param {
pooled_w: 7 pooled_w: 7
pooled_h: 7 pooled_h: 7
spatial_scale: 0.0625 spatial_scale: 0.0625
} }
} }
``` ```
...@@ -345,14 +247,14 @@ class ROIPooling(Layer): ...@@ -345,14 +247,14 @@ class ROIPooling(Layer):
```python ```python
layer { layer {
type: "ROIPooling" type: "ROIPooling"
bottom: "conv5_3" bottom: "conv5_3"
top: "roi_pool4" top: "roi_pool4"
roi_pooling_param { roi_pooling_param {
pooled_w: 7 pooled_w: 7
pooled_h: 7 pooled_h: 7
spatial_scale: 0.0625 spatial_scale: 0.0625
} }
} }
``` ```
......
...@@ -16,15 +16,14 @@ from __future__ import division ...@@ -16,15 +16,14 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import time import time
from google.protobuf import text_format from google.protobuf import text_format
from dragon.core.autograph import def_function from dragon.core.autograph import def_function
from dragon.core.framework import workspace
from dragon.core.training.adam import Adam from dragon.core.training.adam import Adam
from dragon.core.training.rmsprop import RMSprop from dragon.core.training.rmsprop import RMSprop
from dragon.core.training.sgd import SGD from dragon.core.training.sgd import SGD
from dragon.core.training.sgd import Nesterov from dragon.core.training.sgd import Nesterov
from dragon.core.util import logging
from dragon.vm.caffe.net import Net from dragon.vm.caffe.net import Net
from dragon.vm.caffe.proto import caffe_pb2 from dragon.vm.caffe.proto import caffe_pb2
...@@ -99,8 +98,9 @@ class Solver(object): ...@@ -99,8 +98,9 @@ class Solver(object):
if self._current_step < len(self._param.stepvalue) \ if self._current_step < len(self._param.stepvalue) \
and self.iter >= self._param.stepvalue[self._current_step]: and self.iter >= self._param.stepvalue[self._current_step]:
self._current_step = self._current_step + 1 self._current_step = self._current_step + 1
print('MultiStep Status: Iteration {}, step = {}' logging.info(
.format(self.iter, self._current_step)) 'MultiStep Status: Iteration {}, step = {}'
.format(self.iter, self._current_step))
new_lr = self._param.base_lr * \ new_lr = self._param.base_lr * \
pow(self._param.gamma, self._current_step) pow(self._param.gamma, self._current_step)
self.base_lr = new_lr self.base_lr = new_lr
...@@ -112,8 +112,9 @@ class Solver(object): ...@@ -112,8 +112,9 @@ class Solver(object):
else: else:
if self._current_step + 1 < len(stage_iters): if self._current_step + 1 < len(stage_iters):
self._current_step = self._current_step + 1 self._current_step = self._current_step + 1
print('MultiFixed Status: Iteration {}, stage = {}' logging.info(
.format(self.iter, self._current_step)) 'MultiFixed Status: Iteration {}, stage = {}'
.format(self.iter, self._current_step))
self.base_lr = stage_lrs[self._current_step] self.base_lr = stage_lrs[self._current_step]
elif policy == 'inv': elif policy == 'inv':
power = self._param.power power = self._param.power
...@@ -130,8 +131,7 @@ class Solver(object): ...@@ -130,8 +131,7 @@ class Solver(object):
def _apply_update(self): def _apply_update(self):
"""Apply the weights update.""" """Apply the weights update."""
for blob in self.net._layer_blobs: for blob in self.net._layer_blobs:
if blob.lr_multiplier > 0 and \ if blob.lr_multiplier > 0 and blob.diff is not None:
blob.diff is not None:
self._optimizer.apply_gradients( self._optimizer.apply_gradients(
values_and_grads=[(blob.data, blob.diff)], values_and_grads=[(blob.data, blob.diff)],
lr_mult=blob.lr_multiplier, lr_mult=blob.lr_multiplier,
...@@ -211,80 +211,18 @@ class Solver(object): ...@@ -211,80 +211,18 @@ class Solver(object):
""" """
return self._test_nets return self._test_nets
def one_step(self):
"""One step run the train net.
Returns
-------
dict
The stats.
"""
if self._param.test_interval and \
self.iter % self._param.test_interval == 0:
if (self.iter == 0 and
self._param.test_initialization) or self.iter != 0:
for test_idx in range(len(self._test_nets)):
self.test(test_idx)
# Forward, backward and compute loss.
run_time, stats = 0., {'loss': {'total': 0.}, 'iter': self.iter}
for i in range(self._param.iter_size):
tic = time.time()
self._net.forward_backward(return_outputs=False)
run_time += (time.time() - tic)
# Total loss.
for e in self.net.losses:
values = e.get_value().flatten()
if values.size == 1:
stats['loss']['total'] += values[0]
# Partial loss.
for key in self.net.outputs:
values = self.net.blobs[key].data
values = values.get_value().flatten()
if values.size != 1:
continue
if key not in stats['loss']:
stats['loss'][key] = 0.
stats['loss'][key] += values[0]
# Apply Update.
self._get_learning_rate()
tic = time.time()
self._apply_update()
run_time += (time.time() - tic)
self.iter = self.iter + 1
# Snapshot.
if self._param.snapshot:
if self.iter % self._param.snapshot == 0:
self.snapshot()
# Average loss by the iter size.
for k in stats['loss'].keys():
stats['loss'][k] /= self._param.iter_size
# Misc stats.
stats['lr'] = self.base_lr
stats['time'] = run_time
return stats
def snapshot(self): def snapshot(self):
"""Snapshot the parameters of train net.""" """Snapshot the parameters of train net."""
workspace.save( self._net.save(
tensors=[blob.data for blob in self.net._layer_blobs], '%s_iter_%d.caffemodel'
filename='_iter_%d' % self.iter, % (self._param.snapshot_prefix, self._iter))
prefix=self._param.snapshot_prefix,
suffix='.caffemodel', def step(self, num_iterations=1):
format='caffe',
)
def step(self, num_iterations):
"""Step the train net. """Step the train net.
Parameters Parameters
---------- ----------
num_iterations : int num_iterations : int, optional, default=1
The number of iterations to step. The number of iterations to step.
""" """
...@@ -293,19 +231,18 @@ class Solver(object): ...@@ -293,19 +231,18 @@ class Solver(object):
loss_vec, smoothed_loss = [], 0. loss_vec, smoothed_loss = [], 0.
tic = time.time() tic = time.time()
while self.iter < stop_step: while self.iter < stop_step:
# Test if necessary. # Test if necessary.
if self._param.test_interval and \ if self._is_root and self._param.test_interval > 0 and \
self.iter % self._param.test_interval == 0: self.iter % self._param.test_interval == 0:
if (self.iter == 0 and if (self.iter == 0 and self._param.test_initialization) or \
self._param.test_initialization) or self.iter != 0: self.iter != 0:
for test_idx in range(len(self._test_nets)): for test_idx in range(len(self._test_nets)):
self.test(test_idx) self.test(test_idx)
# Forward, backward and compute loss. # Forward, backward and compute loss.
loss = 0. loss = 0.
for i in range(self._param.iter_size): for i in range(self._param.iter_size):
self._net.forward_backward(return_outputs=False) self._net.forward_backward()
if self._is_root: if self._is_root:
for e in self.net.losses: for e in self.net.losses:
values = e.get_value().flatten() values = e.get_value().flatten()
...@@ -322,24 +259,23 @@ class Solver(object): ...@@ -322,24 +259,23 @@ class Solver(object):
idx = (self.iter - start_step) % self._param.average_loss idx = (self.iter - start_step) % self._param.average_loss
smoothed_loss += ((loss - loss_vec[idx]) / self._param.average_loss) smoothed_loss += ((loss - loss_vec[idx]) / self._param.average_loss)
loss_vec[idx] = loss loss_vec[idx] = loss
# Apply Update. # Apply Update.
self._get_learning_rate() self._get_learning_rate()
self._apply_update() self._apply_update()
# Display iteration info.
# Display.
if self._is_root and self._param.display: if self._is_root and self._param.display:
if self.iter % self._param.display == 0: if self.iter % self._param.display == 0:
print('Iteration %d, lr = %s, loss = %f, time = %.2fs' % ( logging.info(
self.iter, str(self.base_lr), smoothed_loss, time.time() - tic)) 'Iteration %d, lr = %s, loss = %f, time = %.2fs'
% (self.iter, str(self.base_lr), smoothed_loss, time.time() - tic))
tic = time.time() tic = time.time()
for idx, net_output in enumerate(self.net.outputs): for idx, net_output in enumerate(self.net.outputs):
values = self.net.blobs[net_output].data.get_value().flatten() values = self.net.blobs[net_output].data.get_value().flatten()
for v in values: for v in values:
print(' ' * 10 + 'Train net output #{}({}): {}' logging.info(
.format(idx, net_output, v)) ' ' * 10 + 'Train net output #{}({}): {}'
.format(idx, net_output, v))
self.iter = self.iter + 1 self.iter = self.iter + 1
# Snapshot if necessary. # Snapshot if necessary.
if self._param.snapshot: if self._param.snapshot:
if self.iter % self._param.snapshot == 0: if self.iter % self._param.snapshot == 0:
...@@ -359,7 +295,7 @@ class Solver(object): ...@@ -359,7 +295,7 @@ class Solver(object):
test_iter = self._param.test_iter[test_idx] test_iter = self._param.test_iter[test_idx]
for iter in range(test_iter): for iter in range(test_iter):
net.forward_backward(return_outputs=False) net.forward()
if not self._is_root: if not self._is_root:
continue continue
if iter == 0: if iter == 0:
...@@ -376,27 +312,25 @@ class Solver(object): ...@@ -376,27 +312,25 @@ class Solver(object):
test_score[i] += value test_score[i] += value
i += 1 i += 1
if not self._is_root: logging.info('Iteration {}, Test net #{}'.format(self.iter, test_idx))
return
print('Iteration {}, Test net #{}'.format(self.iter, test_idx))
for i, score in enumerate(test_score): for i, score in enumerate(test_score):
print(' ' * 10 + 'Test net output #%d(%s): %.4f' logging.info(
% (i, output_id[i], score / test_iter)) ' ' * 10 + 'Test net output #%d(%s): %.4f'
% (i, output_id[i], score / test_iter))
class AdamSolver(Solver): class AdamSolver(Solver):
r"""The Adam solver. r"""The Adam solver.
`[Kingma & Ba, 2014] <https://arxiv.org/abs/1412.6980>`_. `[Kingma & Ba, 2014] <https://arxiv.org/abs/1412.6980>`_.
Following hyper parameters will be taken: Examples:
```python ```python
caffe_pb2.SolverParameter( solver {
base_lr=0., base_lr=0.001,
momentum=0., momentum=0.9,
momentum2=0.999, momentum2=0.999,
delta=1e-8, delta=1e-8,
) )
``` ```
...@@ -425,13 +359,13 @@ class NesterovSolver(Solver): ...@@ -425,13 +359,13 @@ class NesterovSolver(Solver):
r"""The Nesterov-SGD solver. r"""The Nesterov-SGD solver.
`[Sutskever et.al, 2013] <http://www.cs.toronto.edu/~hinton/absps/momentum.pdf>`_. `[Sutskever et.al, 2013] <http://www.cs.toronto.edu/~hinton/absps/momentum.pdf>`_.
Following hyper parameters will be taken: Examples:
```python ```python
caffe_pb2.SolverParameter( solver {
base_lr=0., base_lr: 0.01
momentum=0., momentum: 0.9
) }
``` ```
""" """
...@@ -457,13 +391,13 @@ class RMSPropSolver(Solver): ...@@ -457,13 +391,13 @@ class RMSPropSolver(Solver):
r"""The RMSProp solver. r"""The RMSProp solver.
`[Hinton et.al, 2013] <http://www.cs.utoronto.ca/~bonner/courses/2016s/csc321/lectures/lec6.pdf>`_. `[Hinton et.al, 2013] <http://www.cs.utoronto.ca/~bonner/courses/2016s/csc321/lectures/lec6.pdf>`_.
Following hyper parameters will be taken: Examples:
```python ```python
caffe_pb2.SolverParameter( solver {
base_lr=0., base_lr=0.01,
rms_decay=0.99, rms_decay=0.99,
delta=1e-8, delta=1e-8,
) )
``` ```
...@@ -491,12 +425,12 @@ class SGDSolver(Solver): ...@@ -491,12 +425,12 @@ class SGDSolver(Solver):
r"""The Momentum-SGD solver. r"""The Momentum-SGD solver.
`[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_. `[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_.
Following hyper parameters will be taken: Examples:
```python ```python
caffe_pb2.SolverParameter( solver {
base_lr=0., base_lr=0.01,
momentum=0., momentum=0.9,
) )
``` ```
......
...@@ -3,9 +3,9 @@ Building Dragon Documentation ...@@ -3,9 +3,9 @@ Building Dragon Documentation
This page will help you to build the following documentations: This page will help you to build the following documentations:
Dragon C++ API: http://dragon.seetatech.com/api/cc Dragon C++ API: https://dragon.seetatech.com/api/cc
Dragon Python API: http://dragon.seetatech.com/api/python Dragon Python API: https://dragon.seetatech.com/api/python
Build Documentation of C++ API Build Documentation of C++ API
------------------------------ ------------------------------
......
...@@ -34,10 +34,6 @@ vm.caffe.layers ...@@ -34,10 +34,6 @@ vm.caffe.layers
`class Deconvolution <layers/Deconvolution.html>`_ `class Deconvolution <layers/Deconvolution.html>`_
: Apply the n-dimension deconvolution. : Apply the n-dimension deconvolution.
`class DepthwiseConv2d <layers/DepthwiseConv2d.html>`_
: Apply the 2d depthwise convolution.
`[Chollet, 2016] <https://arxiv.org/abs/1610.02357>`_.
`class Dropout <layers/Dropout.html>`_ `class Dropout <layers/Dropout.html>`_
: Set the elements of the input to zero randomly. : Set the elements of the input to zero randomly.
`[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_. `[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
...@@ -58,18 +54,6 @@ vm.caffe.layers ...@@ -58,18 +54,6 @@ vm.caffe.layers
`class Flatten <layers/Flatten.html>`_ `class Flatten <layers/Flatten.html>`_
: Flatten the input along the given axes. : Flatten the input along the given axes.
`class FusedBatchNorm <layers/FusedBatchNorm.html>`_
: Apply the fused batch normalization.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`class FusedGroupNorm <layers/FusedBatchNorm.html>`_
: Apply the fused group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
`class GroupNorm <layers/FusedBatchNorm.html>`_
: Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
`class InnerProduct <layers/InnerProduct.html>`_ `class InnerProduct <layers/InnerProduct.html>`_
: Compute the dense matrix multiplication along the given axes. : Compute the dense matrix multiplication along the given axes.
...@@ -121,10 +105,6 @@ vm.caffe.layers ...@@ -121,10 +105,6 @@ vm.caffe.layers
`class Scale <layers/Scale.html>`_ `class Scale <layers/Scale.html>`_
: Compute the affine transformation along the given axes. : Compute the affine transformation along the given axes.
`class SELU <layers/SELU.html>`_
: Apply the scaled exponential linear unit.
`[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`class Sigmoid <layers/Sigmoid.html>`_ `class Sigmoid <layers/Sigmoid.html>`_
: Apply the sigmoid function. : Apply the sigmoid function.
...@@ -145,7 +125,7 @@ vm.caffe.layers ...@@ -145,7 +125,7 @@ vm.caffe.layers
: Apply the tanh function. : Apply the tanh function.
`class Tile <layers/Tile.html>`_ `class Tile <layers/Tile.html>`_
: Tile the input according to the given multiples. : Repeat the input according to the given axis.
.. toctree:: .. toctree::
:hidden: :hidden:
...@@ -153,21 +133,16 @@ vm.caffe.layers ...@@ -153,21 +133,16 @@ vm.caffe.layers
layers/Accuracy layers/Accuracy
layers/ArgMax layers/ArgMax
layers/BatchNorm layers/BatchNorm
layers/Cast
layers/Concat layers/Concat
layers/Convolution layers/Convolution
layers/Crop layers/Crop
layers/Data layers/Data
layers/Deconvolution layers/Deconvolution
layers/DepthwiseConv2d
layers/Dropout layers/Dropout
layers/Eltwise layers/Eltwise
layers/ELU layers/ELU
layers/EuclideanLoss layers/EuclideanLoss
layers/Flatten layers/Flatten
layers/FusedBatchNorm
layers/FusedGroupNorm
layers/GroupNorm
layers/InnerProduct layers/InnerProduct
layers/Input layers/Input
layers/LRN layers/LRN
...@@ -183,7 +158,6 @@ vm.caffe.layers ...@@ -183,7 +158,6 @@ vm.caffe.layers
layers/ROIAlign layers/ROIAlign
layers/ROIPooling layers/ROIPooling
layers/Scale layers/Scale
layers/SELU
layers/Sigmoid layers/Sigmoid
layers/SigmoidCrossEntropyLoss layers/SigmoidCrossEntropyLoss
layers/SmoothL1Loss layers/SmoothL1Loss
......
DepthwiseConv2d
===============
.. autoclass:: dragon.vm.caffe.layers.DepthwiseConv2d
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
FusedBatchNorm
==============
.. autoclass:: dragon.vm.caffe.layers.FusedBatchNorm
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
FusedGroupNorm
==============
.. autoclass:: dragon.vm.caffe.layers.FusedGroupNorm
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
GroupNorm
=========
.. autoclass:: dragon.vm.caffe.layers.GroupNorm
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
SELU
====
.. autoclass:: dragon.vm.caffe.layers.SELU
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
...@@ -18,8 +18,8 @@ dragon ...@@ -18,8 +18,8 @@ dragon
`class TensorSpec <dragon/TensorSpec.html>`_ `class TensorSpec <dragon/TensorSpec.html>`_
: Spec to describe properties of a tensor. : Spec to describe properties of a tensor.
`class Workspace <dragon/Workspace_.html>`_ `class Workspace <dragon/Workspace.html>`_
: Space to isolate computations that share resources. : Sandbox to isolate the resources and computations.
Functions Functions
--------- ---------
...@@ -151,7 +151,7 @@ dragon ...@@ -151,7 +151,7 @@ dragon
: Return the identity of input with truncated gradient-flow. : Return the identity of input with truncated gradient-flow.
`tile(...) <dragon/tile.html>`_ `tile(...) <dragon/tile.html>`_
: Tile the input according to the given multiples. : Tile the input according to the given repeats.
`transpose(...) <dragon/transpose.html>`_ `transpose(...) <dragon/transpose.html>`_
: Permute the dimensions of input. : Permute the dimensions of input.
...@@ -217,7 +217,7 @@ dragon ...@@ -217,7 +217,7 @@ dragon
dragon/tile dragon/tile
dragon/transpose dragon/transpose
dragon/where dragon/where
dragon/Workspace_ dragon/Workspace
dragon/zeros dragon/zeros
dragon/zeros_like dragon/zeros_like
......
...@@ -14,10 +14,6 @@ gradient ...@@ -14,10 +14,6 @@ gradient
######## ########
.. automethod:: dragon.GradientTape.gradient .. automethod:: dragon.GradientTape.gradient
replay
######
.. automethod:: dragon.GradientTape.replay
reset reset
##### #####
.. automethod:: dragon.GradientTape.reset .. automethod:: dragon.GradientTape.reset
......
...@@ -30,6 +30,10 @@ shape ...@@ -30,6 +30,10 @@ shape
##### #####
.. autoattribute:: dragon.Tensor.shape .. autoattribute:: dragon.Tensor.shape
size
#####
.. autoattribute:: dragon.Tensor.size
Methods Methods
------- -------
......
Workspace
=========
.. autoclass:: dragon.Workspace
__init__
--------
.. automethod:: dragon.Workspace.__init__
Methods
-------
as_default
##########
.. automethod:: dragon.Workspace.as_default
clear
#####
.. automethod:: dragon.Workspace.clear
merge_from
##########
.. automethod:: dragon.Workspace.merge_from
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
dragon.workspace Workspace
================ =========
.. only:: html .. autoclass:: dragon.Workspace
Functions __init__
--------- --------
.. automethod:: dragon.Workspace.__init__
`feed_tensor(...) <workspace/feed_tensor.html>`_ Methods
: Copy the value to tensor. -------
`fetch_tensor(...) <workspace/fetch_tensor.html>`_ as_default
: Return the value of tensor. ##########
.. automethod:: dragon.Workspace.as_default
`has_tensor(...) <workspace/has_tensor.html>`_ feed_tensor
: Return a bool indicating if tensor is in current workspace. ###########
.. automethod:: dragon.Workspace.feed_tensor
`load(...) <workspace/load.html>`_ fetch_tensor
: Load tensors from a binary file. ############
.. automethod:: dragon.Workspace.fetch_tensor
`reset_tensor(...) <workspace/reset_tensor.html>`_ has_tensor
: Reset the memory of tensor. ##########
.. automethod:: dragon.Workspace.has_tensor
`run_operator(...) <workspace/run_operator.html>`_ merge_from
: Run the operators in current workspace. ##########
.. automethod:: dragon.Workspace.merge_from
`save(...) <workspace/save.html>`_ reset_tensor
: Serialize tensors into a binary file. ############
.. automethod:: dragon.Workspace.reset_tensor
.. toctree::
:hidden:
workspace/feed_tensor
workspace/fetch_tensor
workspace/has_tensor
workspace/load
workspace/reset_tensor
workspace/run_operator
workspace/save
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "Module: "; content: "dragon.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
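A short usage sketch of the instance methods documented on this page; the semantics of ``merge_from`` (making tensors from another workspace visible) are an assumption based on its name, not stated here:

```python
import numpy as np
import dragon

ws_a = dragon.Workspace()
with ws_a.as_default():
    ws_a.feed_tensor('x', np.arange(4, dtype='float32'))

ws_b = dragon.Workspace()
ws_b.merge_from(ws_a)  # assumed to share ws_a's tensors with ws_b
with ws_b.as_default():
    if ws_b.has_tensor('x'):
        print(ws_b.fetch_tensor('x'))  # expected: [0. 1. 2. 3.]
```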
feed_tensor
===========
.. autofunction:: dragon.workspace.feed_tensor
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
fetch_tensor
============
.. autofunction:: dragon.workspace.fetch_tensor
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
has_tensor
==========
.. autofunction:: dragon.workspace.has_tensor
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
load
====
.. autofunction:: dragon.workspace.load
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
reset_tensor
============
.. autofunction:: dragon.workspace.reset_tensor
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
run_operator
============
.. autofunction:: dragon.workspace.run_operator
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
save
====
.. autofunction:: dragon.workspace.save
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
...@@ -40,7 +40,6 @@ Dragon ...@@ -40,7 +40,6 @@ Dragon
* `dragon.nn <dragon/nn.html>`_ * `dragon.nn <dragon/nn.html>`_
* `dragon.optimizers <dragon/optimizers.html>`_ * `dragon.optimizers <dragon/optimizers.html>`_
* `dragon.random <dragon/random.html>`_ * `dragon.random <dragon/random.html>`_
* `dragon.workspace <dragon/workspace.html>`_
* `dragon.vision <dragon/vision.html>`_ * `dragon.vision <dragon/vision.html>`_
Caffe Caffe
...@@ -112,6 +111,7 @@ PyTorch ...@@ -112,6 +111,7 @@ PyTorch
This style involves the following components: This style involves the following components:
* `torch <torch.html>`_ * `torch <torch.html>`_
* `torch.autograd <torch/autograd.html>`_
* `torch.distributed <torch/distributed.html>`_ * `torch.distributed <torch/distributed.html>`_
* `torch.jit <torch/jit.html>`_ * `torch.jit <torch/jit.html>`_
* `torch.nn <torch/nn.html>`_ * `torch.nn <torch/nn.html>`_
...@@ -206,15 +206,9 @@ Modules ...@@ -206,15 +206,9 @@ Modules
`Module random <dragon/random.html>`_ `Module random <dragon/random.html>`_
: Native API for ``dragon.random`` namespace. : Native API for ``dragon.random`` namespace.
`Module workspace <dragon/workspace.html>`_
: Native API for ``dragon.workspace`` namespace.
`Module vision <dragon/vision.html>`_ `Module vision <dragon/vision.html>`_
: Native API for ``dragon.vision`` namespace. : Native API for ``dragon.vision`` namespace.
`Module workspace <dragon/workspace.html>`_
: Native API for ``dragon.workspace`` namespace.
`Module vm.caffe <caffe.html>`_ `Module vm.caffe <caffe.html>`_
: Virtual API for ``caffe`` namespace. : Virtual API for ``caffe`` namespace.
...@@ -278,6 +272,9 @@ Modules ...@@ -278,6 +272,9 @@ Modules
`Module vm.torch <torch.html>`_ `Module vm.torch <torch.html>`_
: Virtual API for ``torch`` namespace. : Virtual API for ``torch`` namespace.
`Module vm.torch.autograd <torch/autograd.html>`_
: Virtual API for ``torch.autograd`` namespace.
`Module vm.torch.distributed <torch/distributed.html>`_ `Module vm.torch.distributed <torch/distributed.html>`_
: Virtual API for ``torch.distributed`` namespace. : Virtual API for ``torch.distributed`` namespace.
...@@ -319,7 +316,6 @@ Modules ...@@ -319,7 +316,6 @@ Modules
dragon/nn dragon/nn
dragon/optimizers dragon/optimizers
dragon/random dragon/random
dragon/workspace
dragon/vision dragon/vision
caffe caffe
caffe/layers caffe/layers
...@@ -343,6 +339,7 @@ Modules ...@@ -343,6 +339,7 @@ Modules
tensorrt tensorrt
tensorrt/backend tensorrt/backend
torch torch
torch/autograd
torch/distributed torch/distributed
torch/jit torch/jit
torch/nn torch/nn
......
...@@ -15,11 +15,6 @@ gradient ...@@ -15,11 +15,6 @@ gradient
.. automethod:: dragon.GradientTape.gradient .. automethod:: dragon.GradientTape.gradient
:noindex: :noindex:
replay
######
.. automethod:: dragon.GradientTape.replay
:noindex:
reset reset
##### #####
.. automethod:: dragon.GradientTape.reset .. automethod:: dragon.GradientTape.reset
......
vm.torch.autograd
==================
.. only:: html
Functions
---------
`backward(...) <autograd/backward.html>`_
: Compute the derivatives of tensors w.r.t. graph leaves.
.. toctree::
:hidden:
autograd/backward
.. raw:: html
<style>
h1:before {
content: "Module: dragon.";
color: #103d3e;
}
</style>
Cast backward
==== ========
.. autoclass:: dragon.vm.caffe.layers.Cast .. autofunction:: dragon.vm.torch.autograd.backward
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "caffe.layers."; content: "torch.autograd.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -7,7 +7,7 @@ all_reduce ...@@ -7,7 +7,7 @@ all_reduce
<style> <style>
h1:before { h1:before {
content: "torch.nn.distributed."; content: "torch.distributed.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -7,7 +7,7 @@ broadcast ...@@ -7,7 +7,7 @@ broadcast
<style> <style>
h1:before { h1:before {
content: "torch.nn.distributed."; content: "torch.distributed.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -7,7 +7,7 @@ trace ...@@ -7,7 +7,7 @@ trace
<style> <style>
h1:before { h1:before {
content: "torch.nn.jit."; content: "torch.jit.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -10,25 +10,25 @@ __init__ ...@@ -10,25 +10,25 @@ __init__
Methods Methods
------- -------
accumulate_grad accumulate
############### ##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate_grad .. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex: :noindex:
add_param_group add_param_group
############### ###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group .. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
:noindex: :noindex:
step step
#### ####
.. automethod:: dragon.vm.torch.optim.Optimizer.step .. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex: :noindex:
zero_grad zero_grad
######### #########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad .. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
:noindex: :noindex:
.. raw:: html .. raw:: html
......
...@@ -10,9 +10,9 @@ __init__ ...@@ -10,9 +10,9 @@ __init__
Methods Methods
------- -------
accumulate_grad accumulate
############### ##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate_grad .. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
add_param_group add_param_group
############### ###############
......
...@@ -10,25 +10,25 @@ __init__ ...@@ -10,25 +10,25 @@ __init__
Methods Methods
------- -------
accumulate_grad accumulate
############### ##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate_grad .. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex: :noindex:
add_param_group add_param_group
############### ###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group .. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
:noindex: :noindex:
step step
#### ####
.. automethod:: dragon.vm.torch.optim.Optimizer.step .. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex: :noindex:
zero_grad zero_grad
######### #########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad .. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
:noindex: :noindex:
.. raw:: html .. raw:: html
......
...@@ -10,25 +10,25 @@ __init__ ...@@ -10,25 +10,25 @@ __init__
Methods Methods
------- -------
accumulate_grad accumulate
############### ##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate_grad .. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex: :noindex:
add_param_group add_param_group
############### ###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group .. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
:noindex: :noindex:
step step
#### ####
.. automethod:: dragon.vm.torch.optim.Optimizer.step .. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex: :noindex:
zero_grad zero_grad
######### #########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad .. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
:noindex: :noindex:
.. raw:: html .. raw:: html
......
...@@ -7,7 +7,7 @@ namespace dragon { ...@@ -7,7 +7,7 @@ namespace dragon {
GraphBase::GraphBase(const GraphDef& def, Workspace* ws) GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
: def_(def), ws_(ws), name_(def.name()), phase_("TEST") { : def_(def), ws_(ws), name_(def.name()), phase_("TEST") {
// Scan the defined arguments // Collect arguments
for (auto& arg : def_.arg()) { for (auto& arg : def_.arg()) {
CHECK_GT(arg.name().size(), 0); CHECK_GT(arg.name().size(), 0);
CHECK_EQ(args_.count(arg.name()), 0); CHECK_EQ(args_.count(arg.name()), 0);
...@@ -18,32 +18,31 @@ GraphBase::GraphBase(const GraphDef& def, Workspace* ws) ...@@ -18,32 +18,31 @@ GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
// Collect outputs // Collect outputs
Set<string> outputs; Set<string> outputs;
for (const auto& op : def.op()) { for (const auto& op : def.op()) {
for (const auto& in : op.input()) for (const auto& input : op.input())
CHECK(outputs.count(in) || ws_->HasTensor(in)) CHECK(outputs.count(input) || ws_->HasTensor(input))
<< "\nInput: " << in << " for op: " << op.name() << " is unknown."; << "\nThe input <" << input << "> is not in graph.";
for (const auto& out : op.output()) for (const auto& output : op.output()) {
outputs.insert(out); outputs.insert(output);
}
} }
// Check targets // Check targets
Set<string> targets; Set<string> targets;
for (const auto& target : def.output()) { for (const auto& target : def.output()) {
CHECK(outputs.count(target) || ws_->HasTensor(target)) CHECK(outputs.count(target) || ws_->HasTensor(target))
<< "\nTarget: " << target << " does not exist in the graph."; << "\nThe output <" << target << "> is not in graph.";
targets.insert(target); targets.insert(target);
} }
// Check gradients // Check gradients
for (const auto& gradient : def.gradient()) { for (const auto& grad_info : def.grad_info()) {
const auto& cost = gradient.cost(); const auto& y = grad_info.y();
const auto& wrt = gradient.wrt(); CHECK_GT(targets.count(y), 0)
CHECK(outputs.count(cost) || ws_->HasTensor(cost)) << "\nThe derivative target <" << y << "> is not in outputs.";
<< "\nTarget: " << cost << "does not exist in the graph."; for (const auto& x : grad_info.xs()) {
CHECK(outputs.count(wrt) || ws_->HasTensor(wrt)) CHECK(outputs.count(x) || ws_->HasTensor(x))
<< "\nTarget: " << wrt << "does not exist in the graph."; << "\nThe differentiated input <" << x << "> is not in graph.";
CHECK_GT(targets.count(cost), 0) }
<< "\nTo solve d(" << cost << ")/d(" << wrt << "),\n"
<< cost << " should be set as a target.";
} }
} }
...@@ -54,21 +53,18 @@ bool Graph::Create(const GraphDef& def, Workspace* ws) { ...@@ -54,21 +53,18 @@ bool Graph::Create(const GraphDef& def, Workspace* ws) {
auto op_def(def.op(i)); auto op_def(def.op(i));
LOG(DEBUG) << "Create Operator " << op_def.name() << ": " << op_def.type(); LOG(DEBUG) << "Create Operator " << op_def.name() << ": " << op_def.type();
// Inherit device option if necessary // Inherit device option if necessary
if (!op_def.has_device_option() && has_device_option) if (!op_def.has_device_option() && has_device_option) {
op_def.mutable_device_option()->CopyFrom(def.device_option()); op_def.mutable_device_option()->CopyFrom(def.device_option());
}
Argument arg; Argument arg;
arg.set_name("allow_recomp");
arg.set_i(1);
op_def.add_arg()->CopyFrom(arg);
// For the last operator, enforce the synchronization // For the last operator, enforce the synchronization
if (i == def.op_size() - 1) { if (i == def.op_size() - 1) {
arg.set_name("do_sync"); arg.set_name("do_sync");
arg.set_i(1); arg.set_i(1);
op_def.add_arg()->CopyFrom(arg); op_def.add_arg()->CopyFrom(arg);
} }
ops_.push_back(NewOperator(op_def, ws)); cached_ops_.push_back(NewOperator(op_def, ws));
// Attatch the output aliases info cached_ops_.back()->set_output_aliases(output_aliases_);
ops_.back()->set_output_aliases(output_aliases_);
} }
return true; return true;
} }
...@@ -80,7 +76,7 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) { ...@@ -80,7 +76,7 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
GraphGradientMaker gradient_maker; GraphGradientMaker gradient_maker;
Map<string, vec32_t> subgraph_indices; Map<string, vec32_t> subgraph_indices;
int opt = 3; // defaults: O3 int opt = 3; // defaults: O3
if (args().count("optimization_level")) opt = arg("optimization_level").i(); if (args().count("optimization")) opt = arg("optimization").i();
if (opt >= 1) opt_def = graph_optim.PruneNodes(def); if (opt >= 1) opt_def = graph_optim.PruneNodes(def);
if (opt >= 2) graph_optim.AddInplace(opt_def, output_aliases_); if (opt >= 2) graph_optim.AddInplace(opt_def, output_aliases_);
if (opt >= 3) { if (opt >= 3) {
...@@ -101,22 +97,23 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) { ...@@ -101,22 +97,23 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
for (const auto& it : subgraph_indices) { for (const auto& it : subgraph_indices) {
subgraph[it.first] = vector<OperatorBase*>(); subgraph[it.first] = vector<OperatorBase*>();
for (const auto& idx : subgraph_indices[it.first]) for (const auto& idx : subgraph_indices[it.first])
subgraph[it.first].push_back(ops_[idx]); subgraph[it.first].push_back(cached_ops_[idx]);
} }
for (const auto& op : ops_) for (auto* op : cached_ops_) {
op->set_subgraph(subgraph); op->set_subgraph(subgraph);
}
} }
} }
bool Graph::Run(const string& incl, const string& excl, int stream_id) { bool Graph::Run(const string& include, const string& exclude, int stream) {
LOG(DEBUG) << "Run Graph: " << name(); LOG(DEBUG) << "Run Graph: " << name();
for (auto op : ops_) { for (auto* op : cached_ops_) {
if (!incl.empty() && !str::find(op->type(), incl)) continue; if (!include.empty() && !str::find(op->type(), include)) continue;
if (!excl.empty() && str::find(op->type(), excl)) continue; if (!exclude.empty() && str::find(op->type(), exclude)) continue;
op->SwitchToPhase(phase()); op->SwitchToPhase(phase());
LOG(DEBUG) << "$ Before Operator: " << op->name(); LOG(DEBUG) << "Run Op: " << op->name();
op->Run(stream_id); op->Run(stream);
LOG(DEBUG) << "$ After Operator: " << op->name(); LOG(DEBUG) << "Finish Op: " << op->name();
} }
return true; return true;
} }
......
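For orientation, a minimal sketch (not part of this commit) of how a caller might use the two knobs visible in this hunk: the graph-level Argument now named "optimization" and the include/exclude type filters of Graph::Run(), reached here through Workspace::RunGraph(). The include path, the graph contents, and the assumption that GraphDef carries a repeated Argument field are mine; the names are placeholders.

    #include "dragon/core/workspace.h"

    using namespace dragon;

    void RunGradientPartOnly(Workspace* ws, const GraphDef& def) {
      GraphDef def_v2(def);
      // Lower the optimization to O1: prune nodes, skip in-place and mirror stage.
      auto* arg = def_v2.add_arg();
      arg->set_name("optimization");
      arg->set_i(1);
      auto* graph = ws->CreateGraph(def_v2);
      // Run only the operators whose type contains "Gradient".
      ws->RunGraph(graph->name(), /* include */ "Gradient", /* exclude */ "", 0);
    }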
...@@ -88,8 +88,8 @@ class Graph : public GraphBase { ...@@ -88,8 +88,8 @@ class Graph : public GraphBase {
/*! \brief Default Destructor */ /*! \brief Default Destructor */
virtual ~Graph() { virtual ~Graph() {
for (auto* op : ops_) { for (auto* cached_op : cached_ops_) {
delete op; delete cached_op;
} }
} }
...@@ -100,8 +100,8 @@ class Graph : public GraphBase { ...@@ -100,8 +100,8 @@ class Graph : public GraphBase {
bool Run(const string&, const string&, int = 0) override; bool Run(const string&, const string&, int = 0) override;
protected: protected:
/*! \brief Store the internal operators */ /*! \brief The cached operators */
vector<OperatorBase*> ops_; vector<OperatorBase*> cached_ops_;
/*! \brief Store the candidate output aliases */ /*! \brief Store the candidate output aliases */
Map<string, Set<string>> output_aliases_; Map<string, Set<string>> output_aliases_;
......
...@@ -4,23 +4,24 @@ ...@@ -4,23 +4,24 @@
namespace dragon { namespace dragon {
bool GraphGradientMaker::CheckGrad( bool GraphGradientMaker::CheckGrad(
const OperatorDef& forward_op, const OperatorDef& op_def,
const Set<string>& targets, const Set<string>& targets,
vector<pair<string, int>>& gen_grads) { vector<pair<string, int>>& gen_grads) {
if (NoGradientRegistry()->Has(forward_op.type())) { if (NoGradientRegistry()->Has(op_def.type())) {
for (auto& input : forward_op.input()) for (auto& input : op_def.input()) {
blacklist_set_.insert(input); blacklist_set_.insert(input);
}
return true; return true;
} }
for (int i = 0; i < forward_op.output_size(); ++i) { for (int i = 0; i < op_def.output_size(); ++i) {
const auto& output = forward_op.output(i); const auto& output = op_def.output(i);
if (!inputs_to_grads_.count(output)) { if (!inputs_to_grads_.count(output)) {
if (blacklist_set_.count(output)) return true; if (blacklist_set_.count(output)) return true;
if (targets.count(output)) { if (targets.count(output)) {
// Consider generating a virtual gradient for targets // Consider generating a virtual gradient for targets
gen_grads.push_back({output, i}); gen_grads.push_back({output, i});
inputs_to_grads_[output] = output + "_grad"; inputs_to_grads_[output] = output + "_grad";
} else if (forward_op.output_size() == 1) { } else if (op_def.output_size() == 1) {
return true; // We can skip this op, obviously return true; // We can skip this op, obviously
} }
} }
...@@ -30,7 +31,7 @@ bool GraphGradientMaker::CheckGrad( ...@@ -30,7 +31,7 @@ bool GraphGradientMaker::CheckGrad(
} }
void GraphGradientMaker::Make( void GraphGradientMaker::Make(
const vector<OperatorDef*>& forward_ops, const vector<OperatorDef*>& op_defs,
const vector<string>& targets, const vector<string>& targets,
const vector<string>& input_grads, const vector<string>& input_grads,
GraphDef& backward_def) { GraphDef& backward_def) {
...@@ -39,11 +40,11 @@ void GraphGradientMaker::Make( ...@@ -39,11 +40,11 @@ void GraphGradientMaker::Make(
Map<string, string> targets_to_grads; Map<string, string> targets_to_grads;
// PLAY for the forward // PLAY for the forward
for (auto* op : forward_ops) { for (auto* op_def : op_defs) {
if (NoGradientRegistry()->Has(op->type())) continue; if (NoGradientRegistry()->Has(op_def->type())) continue;
for (const auto& input : op->input()) { for (const auto& input : op_def->input()) {
bool input_in_outputs = false; bool input_in_outputs = false;
for (auto& output : op->output()) for (auto& output : op_def->output())
if (output == input) { if (output == input) {
input_in_outputs = true; input_in_outputs = true;
break; break;
...@@ -62,9 +63,9 @@ void GraphGradientMaker::Make( ...@@ -62,9 +63,9 @@ void GraphGradientMaker::Make(
targets_set.insert(targets[i]); targets_set.insert(targets[i]);
} }
for (int op_idx = (int)forward_ops.size() - 1; op_idx >= 0; --op_idx) { for (int op_idx = (int)op_defs.size() - 1; op_idx >= 0; --op_idx) {
// Collect inputs and outputs, generate raw gradient ops // Collect inputs and outputs, generate raw gradient ops
const OperatorDef& op = *forward_ops[op_idx]; const OperatorDef& op = *op_defs[op_idx];
vector<pair<string, int>> gen_grads; vector<pair<string, int>> gen_grads;
bool is_skip = CheckGrad(op, targets_set, gen_grads); bool is_skip = CheckGrad(op, targets_set, gen_grads);
vector<string> g_outputs; vector<string> g_outputs;
...@@ -183,9 +184,9 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -183,9 +184,9 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
// Flag the gathering gradients // Flag the gathering gradients
if (op.type() == "GradientGather") { if (op.type() == "GradientGather") {
invalid_ops.insert(op_idx); invalid_ops.insert(op_idx);
if (ignored_grads_.count(op.output(0))) { if (empty_grads_.count(op.output(0))) {
for (const auto& input : op.input()) { for (const auto& input : op.input()) {
ignored_grads_.insert(input); empty_grads_.insert(input);
} }
continue; continue;
} else { } else {
...@@ -200,7 +201,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -200,7 +201,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
} }
// Count the references to detect leafs // Count the references to detect leafs
for (const auto& input : op.input()) { for (const auto& input : op.input()) {
if (str::find(input, "grad")) { if (str::endswith(input, "_grad")) {
ref_count[input] += 1; ref_count[input] += 1;
} }
} }
...@@ -293,21 +294,17 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -293,21 +294,17 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
// Rewrite output gradients // Rewrite output gradients
for (int i = 0; i < op->output_size(); ++i) { for (int i = 0; i < op->output_size(); ++i) {
if (str::startswith(op->type(), "Python")) continue;
const string& output = op->output(i); const string& output = op->output(i);
if (output.empty() || str::startswith(output, "/share/")) continue; if (output.empty() || str::startswith(output, "/share/buffer")) continue;
if (ignored_grads_.count(output) > 0) { if (empty_grads_.count(output) > 0) {
// Prune for non-trainable leafs
*op->mutable_output(i) = ""; *op->mutable_output(i) = "";
continue; continue;
} }
if (hooked_grads_.empty()) { // Protection for leafs
// Protection for leafs if (ref_count.count(output) == 0) continue;
if (ref_count.count(output) == 0) continue; // Protection for sources and leafs
} else { if (retained_grads_.count(output) > 0) continue;
// Protection for sources
if (hooked_grads_.count(output) > 0) continue;
}
if (op->type() == "PythonPluginGradient") continue;
string new_output = output; string new_output = output;
if (inplace_flags[i] >= 0) { if (inplace_flags[i] >= 0) {
new_output = op->input(inplace_flags[i]); new_output = op->input(inplace_flags[i]);
......
...@@ -21,22 +21,22 @@ class DRAGON_API GraphGradientMaker { ...@@ -21,22 +21,22 @@ class DRAGON_API GraphGradientMaker {
public: public:
/*! \brief Generate a backward graph from the forward ops */ /*! \brief Generate a backward graph from the forward ops */
void Make( void Make(
const vector<OperatorDef*>& forward_ops, const vector<OperatorDef*>& op_defs,
const vector<string>& targets, const vector<string>& targets,
const vector<string>& input_grads, const vector<string>& input_grads,
GraphDef& backward_def); GraphDef& graph_def);
/*! \brief Rewrite a graph to share the intermediate grads */ /*! \brief Rewrite a graph to share the intermediate grads */
GraphDef Share(const GraphDef& input_def); GraphDef Share(const GraphDef& input_def);
/*! \brief Add a hooked gradient */ /*! \brief Add an empty gradient */
void add_hooked_grad(const string& name) { void add_empty_grad(const string& name) {
hooked_grads_.insert(name); empty_grads_.insert(name);
} }
/*! \brief Add an ignored gradient */ /*! \brief Add a retained gradient */
void add_ignored_grad(const string& name) { void add_retained_grad(const string& name) {
ignored_grads_.insert(name); retained_grads_.insert(name);
} }
/*! \brief Set the prefix of backward op name */ /*! \brief Set the prefix of backward op name */
...@@ -47,32 +47,32 @@ class DRAGON_API GraphGradientMaker { ...@@ -47,32 +47,32 @@ class DRAGON_API GraphGradientMaker {
private: private:
/*! \brief Check the missing grads of backward procedure */ /*! \brief Check the missing grads of backward procedure */
bool CheckGrad( bool CheckGrad(
const OperatorDef& forward_op, const OperatorDef& op_def,
const Set<string>& targets, const Set<string>& targets,
vector<pair<string, int>>& gen_grads); vector<pair<string, int>>& gen_grads);
/*! \brief Return a dummy operator name */ /*! \brief Return a dummy operator name */
string GetOperatorName() { string GetOperatorName() {
if (op_prefix_.empty()) return "Generic"; if (op_prefix_.empty()) return "GradientOp";
return op_prefix_ + str::to(op_index_++); return op_prefix_ + str::to(op_index_++);
} }
/*! \brief Store the mapping of intermediate grads */ /*! \brief The mapping from input to grad */
Map<string, string> inputs_to_grads_; Map<string, string> inputs_to_grads_;
/*! \brief Store the non-gradient outputs */ /*! \brief The non-gradient outputs */
Set<string> blacklist_set_; Set<string> blacklist_set_;
/*! \brief Store the non-shared gradients */ /*! \brief The gradients that should be retained */
Set<string> hooked_grads_; Set<string> retained_grads_;
/*! \brief Store the gradients that are not required */ /*! \brief The gradients that should be set to empty */
Set<string> ignored_grads_; Set<string> empty_grads_;
/*! \brief Store the prefix of dummy operator name */ /*! \brief The prefix of op name */
string op_prefix_; string op_prefix_;
/*! \brief Store the counter of dummy operator name */ /*! \brief The counter of op name */
int64_t op_index_ = 0; int64_t op_index_ = 0;
}; };
......
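A hedged sketch of how the renamed maker interface above fits together, mirroring the RunBackward binding later in this commit; the header path, tensor names and targets are assumptions.

    #include "dragon/core/graph_gradient.h"

    using namespace dragon;

    GraphDef MakeSharedBackward(const vector<OperatorDef*>& op_defs) {
      GraphGradientMaker maker;
      maker.add_retained_grad("conv1/W_grad");  // a source gradient to keep readable
      maker.add_empty_grad("data_grad");        // a gradient that is not required
      GraphDef backward_def;
      maker.Make(op_defs, /* targets */ {"loss"}, /* input_grads */ {}, backward_def);
      // Rewrite the intermediate gradients into shared buffers.
      return maker.Share(backward_def);
    }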
...@@ -39,14 +39,12 @@ GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) { ...@@ -39,14 +39,12 @@ GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) {
BackwardPrunePass(target); BackwardPrunePass(target);
} }
// Forward pass from gradients for (const auto& grad_info : input_def.grad_info()) {
for (const auto& gradient : input_def.gradient()) { const auto u = grad_info.y() + "_grad";
auto u = gradient.cost() + "_grad"; for (const auto& x : grad_info.xs()) {
auto v = gradient.wrt() + "_grad"; visited_.clear();
if (ws_->HasTensor(u)) u = ws_->GetTensor(u)->name(); ForwardPrunePass(u, x + "_grad", std::deque<string>({u}));
if (ws_->HasTensor(v)) v = ws_->GetTensor(v)->name(); }
visited_.clear();
ForwardPrunePass(u, v, vector<string>({u}));
} }
// Select all colored operators // Select all colored operators
...@@ -64,7 +62,6 @@ GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) { ...@@ -64,7 +62,6 @@ GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) {
// Generate the final op sequence // Generate the final op sequence
map<int, OperatorDef> final_sequence; map<int, OperatorDef> final_sequence;
for (auto op_idx : selected_op_indices) { for (auto op_idx : selected_op_indices) {
const auto& op = input_def.op(op_idx); const auto& op = input_def.op(op_idx);
auto new_op(input_def.op(op_idx)); auto new_op(input_def.op(op_idx));
...@@ -308,11 +305,13 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) { ...@@ -308,11 +305,13 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) {
void GraphOptimizer::ForwardPrunePass( void GraphOptimizer::ForwardPrunePass(
const string& u, const string& u,
const string& leaf, const string& leaf,
const vector<string>& path) { const std::deque<string>& path) {
if (visited_.count(u)) { if (visited_.count(u)) {
if (visited_[u]) if (visited_[u]) {
for (const auto& node : path) for (const auto& node : path) {
visited_[node] = colored_[node] = true; visited_[node] = colored_[node] = true;
}
}
return; return;
} }
visited_[u] = false; visited_[u] = false;
...@@ -321,8 +320,9 @@ void GraphOptimizer::ForwardPrunePass( ...@@ -321,8 +320,9 @@ void GraphOptimizer::ForwardPrunePass(
auto new_path(path); auto new_path(path);
new_path.push_back(v); new_path.push_back(v);
if (v == leaf) { if (v == leaf) {
for (const auto& node : new_path) for (const auto& node : new_path) {
visited_[node] = colored_[node] = true; visited_[node] = colored_[node] = true;
}
return; return;
} }
ForwardPrunePass(v, leaf, new_path); ForwardPrunePass(v, leaf, new_path);
......
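The pruning pass above now walks from the grad_info entries of the GraphDef rather than the removed gradient(cost, wrt) pairs. A small sketch of populating such an entry, assuming the standard protobuf accessors generated for a repeated grad_info message with a string y and repeated string xs; the tensor names are placeholders.

    void RequestGradient(dragon::GraphDef* graph_def) {
      // Hypothetical request: differentiate "loss" w.r.t. "conv1/W".
      auto* info = graph_def->add_grad_info();
      info->set_y("loss");      // the forward prune is seeded at "loss_grad"
      info->add_xs("conv1/W");  // ...and colors the path down to "conv1/W_grad"
    }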
...@@ -56,7 +56,7 @@ class GraphOptimizer { ...@@ -56,7 +56,7 @@ class GraphOptimizer {
void ForwardPrunePass( void ForwardPrunePass(
const string& u, const string& u,
const string& leaf, const string& leaf,
const vector<string>& path); const std::deque<string>& path);
/*! \brief Pass from targets to remove unused nodes */ /*! \brief Pass from targets to remove unused nodes */
void BackwardPrunePass(const string& v); void BackwardPrunePass(const string& v);
......
...@@ -41,14 +41,11 @@ OperatorBase::OperatorBase(const OperatorDef& def, Workspace* ws) ...@@ -41,14 +41,11 @@ OperatorBase::OperatorBase(const OperatorDef& def, Workspace* ws)
} }
} }
template <class Context> // template <class Context>
Operator<Context>::Operator(const OperatorDef& def, Workspace* ws) // Operator<Context>::Operator(const OperatorDef& def, Workspace* ws)
: OperatorBase(def, ws), // : OperatorBase(def, ws),
ctx_(def.device_option()), // ctx_(def.device_option()),
do_sync_(OpArg<bool>("do_sync", false)), // do_sync_(OpArg<bool>("do_sync", false)) {}
allow_recomp_(OpArg<bool>("allow_recomp", false)) {
allow_run_ = (!(OutputSize() == 1 && !Output(0)->has_name()));
}
Tensor& OperatorBase::Input(int i) { Tensor& OperatorBase::Input(int i) {
CHECK_LT(i, (int)inputs_.size()); CHECK_LT(i, (int)inputs_.size());
...@@ -112,32 +109,32 @@ OperatorBase* OperatorBase::UpdateFrom(const OperatorDef& def) { ...@@ -112,32 +109,32 @@ OperatorBase* OperatorBase::UpdateFrom(const OperatorDef& def) {
handle_ = def.name(); handle_ = def.name();
inputs_.resize(def.input_size()); inputs_.resize(def.input_size());
outputs_.resize(def.output_size()); outputs_.resize(def.output_size());
for (int i = 0; i < inputs_.size(); i++) for (int i = 0; i < inputs_.size(); i++) {
inputs_[i] = ws()->GetTensor(def.input(i)); inputs_[i] = ws()->GetTensor(def.input(i));
for (int i = 0; i < outputs_.size(); i++) }
for (int i = 0; i < outputs_.size(); i++) {
outputs_[i] = ws()->CreateTensor(def.output(i)); outputs_[i] = ws()->CreateTensor(def.output(i));
}
return this; return this;
} }
template <class Context> template <class Context>
void Operator<Context>::Prepare() { void Operator<Context>::Prepare() {
string tensor_name;
size_t ver_pos;
int version;
for (int i = 0; i < InputSize(); i++) { for (int i = 0; i < InputSize(); i++) {
if (Input(i).version() >= 0) { if (Input(i).version() >= 0) {
tensor_name = def().input(i); const auto& name = def().input(i);
ver_pos = tensor_name.find("/ver:"); auto ver_pos = name.find("/ver:");
version = std::atoi(tensor_name.substr(ver_pos + 5).c_str()); auto version = std::atoi(name.substr(ver_pos + 5).c_str());
if (version == Input(i).version()) continue; if (version == Input(i).version()) continue;
LOG(DEBUG) << "Excepted version of Tensor(" + Input(i).name() + ") " LOG(DEBUG) << "Excepted version of Tensor(" + Input(i).name() + ") "
<< "is " << version << ", got " << Input(i).version() << "is " << version << ", got " << Input(i).version()
<< ". Recompute."; << ". Recompute.";
Tensor* flag = ws()->GetTensor("/share/flag/recomputing"); Tensor* flag = ws()->GetTensor("/share/flag/recomputing");
flag->mutable_data<bool, CPUContext>()[0] = true; flag->mutable_data<bool, CPUContext>()[0] = true;
vector<OperatorBase*>& chain = subgraph()[tensor_name]; vector<OperatorBase*>& chain = subgraph()[name];
for (auto* op : chain) for (auto* op : chain) {
op->Run(ctx()->stream_id()); op->Run(ctx()->stream_id());
}
flag->mutable_data<bool, CPUContext>()[0] = false; flag->mutable_data<bool, CPUContext>()[0] = false;
} }
} }
...@@ -145,14 +142,11 @@ void Operator<Context>::Prepare() { ...@@ -145,14 +142,11 @@ void Operator<Context>::Prepare() {
template <class Context> template <class Context>
void Operator<Context>::Release() { void Operator<Context>::Release() {
string tensor_name;
size_t ver_pos;
int version;
for (int i = 0; i < OutputSize(); i++) { for (int i = 0; i < OutputSize(); i++) {
if (Output(i)->version() >= 0) { if (Output(i)->version() >= 0) {
tensor_name = def().output(i); const auto& name = def().output(i);
ver_pos = tensor_name.find("/ver:"); auto ver_pos = name.find("/ver:");
version = std::atoi(tensor_name.substr(ver_pos + 5).c_str()); auto version = std::atoi(name.substr(ver_pos + 5).c_str());
Output(i)->set_version(version); Output(i)->set_version(version);
} }
} }
...@@ -195,8 +189,7 @@ TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) { ...@@ -195,8 +189,7 @@ TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) {
OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws) { OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws) {
auto* schema = OpSchemaRegistry::Schema(def.type()); auto* schema = OpSchemaRegistry::Schema(def.type());
if (schema) { if (schema != nullptr) {
// Check the Inputs and Outputs if necessary
CHECK(schema->Verify(def)) CHECK(schema->Verify(def))
<< "\nOperator failed to pass the schema checking."; << "\nOperator failed to pass the schema checking.";
} }
...@@ -219,7 +212,7 @@ Gradient MakeGradientForOp( ...@@ -219,7 +212,7 @@ Gradient MakeGradientForOp(
<< "not implemented."; << "not implemented.";
Gradient grad = maker->Make(); Gradient grad = maker->Make();
OperatorDef reference_def(def); OperatorDef reference_def(def);
// Map the cache key // Set the cache key
if (reference_def.has_cache_key()) { if (reference_def.has_cache_key()) {
for (int i = 0; i < grad.ops.size(); ++i) { for (int i = 0; i < grad.ops.size(); ++i) {
grad.ops[i].set_cache_key( grad.ops[i].set_cache_key(
......
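A standalone illustration of the versioned-name convention that Prepare() and Release() above parse: the expected version follows a "/ver:" marker in the input or output name (the tensor name below is made up).

    #include <cstdlib>
    #include <iostream>
    #include <string>

    int main() {
      const std::string name = "conv1/W/ver:2";
      auto ver_pos = name.find("/ver:");
      // 5 == strlen("/ver:"), so the remainder is the version number.
      auto version = std::atoi(name.substr(ver_pos + 5).c_str());
      std::cout << version << std::endl;  // prints 2
      return 0;
    }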
...@@ -40,7 +40,7 @@ class DRAGON_API OperatorBase { ...@@ -40,7 +40,7 @@ class DRAGON_API OperatorBase {
} }
/*! \brief Run operator on the specified stream */ /*! \brief Run operator on the specified stream */
virtual void Run(int stream_id = 0) { virtual void Run(int stream = 0) {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
...@@ -154,12 +154,12 @@ class DRAGON_API OperatorBase { ...@@ -154,12 +154,12 @@ class DRAGON_API OperatorBase {
} }
/*! \brief Set the output aliases for in-place */ /*! \brief Set the output aliases for in-place */
void set_output_aliases(const Map<string, Set<string>>& aliases_map) { void set_output_aliases(const Map<string, Set<string>>& alias_map) {
output_aliases_.resize(outputs_.size()); output_aliases_.resize(outputs_.size());
for (int i = 0; i < outputs_.size(); ++i) { for (int i = 0; i < outputs_.size(); ++i) {
auto aliases_iter = aliases_map.find(outputs_[i]->name()); const auto& it = alias_map.find(outputs_[i]->name());
if (aliases_iter != aliases_map.end()) { if (it != alias_map.end()) {
output_aliases_[i] = aliases_iter->second; output_aliases_[i] = it->second;
} else { } else {
output_aliases_[i].clear(); output_aliases_[i].clear();
} }
...@@ -196,7 +196,10 @@ template <class Context> ...@@ -196,7 +196,10 @@ template <class Context>
class DRAGON_API Operator : public OperatorBase { class DRAGON_API Operator : public OperatorBase {
public: public:
/*! \brief Default constructor */ /*! \brief Default constructor */
Operator(const OperatorDef& def, Workspace* ws); Operator(const OperatorDef& def, Workspace* ws)
: OperatorBase(def, ws),
ctx_(def.device_option()),
do_sync_(OperatorBase::Arg<bool>("do_sync", false)) {}
/*! \brief Prepare the content of inputs */ /*! \brief Prepare the content of inputs */
virtual void Prepare(); virtual void Prepare();
...@@ -207,36 +210,32 @@ class DRAGON_API Operator : public OperatorBase { ...@@ -207,36 +210,32 @@ class DRAGON_API Operator : public OperatorBase {
/*! \brief Coordinate the context of inputs and outputs */ /*! \brief Coordinate the context of inputs and outputs */
virtual void SwitchToDevice(); virtual void SwitchToDevice();
/*! \brief Implement the detailed execution */ /*! \brief The detailed execution on device */
virtual void RunOnDevice() = 0; virtual void RunOnDevice() = 0;
/*! \brief Run this operator on the specified stream */ /*! \brief Run this operator */
void Run(int stream_id = 0) final { void Run(int stream = 0) final {
if (!allow_run_) return; Prepare();
if (allow_recomp_) Prepare(); ctx()->SwitchToDevice(stream);
ctx()->SwitchToDevice(stream_id);
SwitchToDevice(); SwitchToDevice();
RunOnDevice(); RunOnDevice();
if (do_sync_ || stream_id > 0) { if (do_sync_ || stream > 0) {
ctx()->FinishDeviceComputation(); ctx()->FinishDeviceComputation();
} }
if (allow_recomp_) Release(); Release();
} }
/*! \brief Return a bool indicating the run is available */ /*! \brief Return the context */
bool allow_run() const {
return allow_run_;
}
/*! \brief Return the internal context */
Context* ctx() { Context* ctx() {
return &ctx_; return &ctx_;
} }
protected: protected:
/*! \brief Store the internal context */ /*! \brief The context */
Context ctx_; Context ctx_;
bool do_sync_, allow_run_, allow_recomp_;
/*! \brief The executing flags */
bool do_sync_;
}; };
/*! \brief Create a new operator from the raw def */ /*! \brief Create a new operator from the raw def */
...@@ -266,9 +265,8 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*); ...@@ -266,9 +265,8 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
using OperatorBase::def; \ using OperatorBase::def; \
using OperatorBase::ws using OperatorBase::ws
#define USE_OPERATOR_FUNCTIONS \ #define USE_OPERATOR_FUNCTIONS \
USE_OPERATOR_BASE_FUNCTIONS; \ USE_OPERATOR_BASE_FUNCTIONS; \
using Operator<Context>::allow_run; \
using Operator<Context>::ctx using Operator<Context>::ctx
#define STORE_INPUT_SPEC(i) \ #define STORE_INPUT_SPEC(i) \
...@@ -342,46 +340,46 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType); ...@@ -342,46 +340,46 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType);
/* Fillers */ /* Fillers */
#define TENSOR_FILL_WITH_TYPE(tensor, shape, type) \ #define TENSOR_FILL_WITH_TYPE(tensor, shape, type) \
if (tensor.count() == 0) { \ if (tensor.count() == 0) { \
CHECK(ws()->GetFiller(tensor.name())) \ auto* filler_info = ws()->GetFillerInfo(tensor.name()); \
<< "\nTensor(" << tensor.name() << ") is empty. \n" \ CHECK(filler_info) << "\nTensor(" << tensor.name() << ") is empty.\n" \
<< "may be specify a filler for it?"; \ << "May be specify a filler for it?"; \
tensor.Reshape(shape); \ tensor.Reshape(shape); \
unique_ptr<Filler<type, Context>> filler( \ unique_ptr<Filler<type, Context>> filler( \
CreateFiller<type, Context>(*ws()->GetFiller(tensor.name()))); \ CreateFiller<type, Context>(*filler_info)); \
filler->Fill(&tensor, ctx()); \ filler->Fill(&tensor, ctx()); \
} else { \ } else { \
int64_t count = 1; \ int64_t count = 1; \
for (int i = 0; i < shape.size(); i++) \ for (int i = 0; i < shape.size(); i++) \
count *= shape[i]; \ count *= shape[i]; \
CHECK_EQ(count, tensor.count()) \ CHECK_EQ(count, tensor.count()) \
<< "\nExcepted Tensor(" << tensor.name() << ")'s " \ << "\nExcepted Tensor(" << tensor.name() << ")'s " \
<< "size is " << count << ", \n" \ << "size is " << count << ", \n" \
<< "but now is " << tensor.count() << ", " \ << "but now is " << tensor.count() << ", " \
<< "did you feed the incorrect data before?"; \ << "did you feed the incorrect data before?"; \
tensor.Reshape(shape); \ tensor.Reshape(shape); \
} }
#define TENSOR_FILL(tensor, shape) \ #define TENSOR_FILL(tensor, shape) \
if (tensor.count() == 0) { \ if (tensor.count() == 0) { \
CHECK(ws()->GetFiller(tensor.name())) \ auto* filler_info = ws()->GetFillerInfo(tensor.name()); \
<< "\nTensor(" << tensor.name() << ") is empty. \n" \ CHECK(filler_info) << "\nTensor(" << tensor.name() << ") is empty.\n" \
<< "Maybe specify a filler for it?"; \ << "May be specify a filler for it?"; \
tensor.Reshape(shape); \ tensor.Reshape(shape); \
unique_ptr<Filler<T, Context>> filler( \ unique_ptr<Filler<T, Context>> filler( \
CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \ CreateFiller<T, Context>(*filler_info)); \
filler->Fill(&tensor, ctx()); \ filler->Fill(&tensor, ctx()); \
} else { \ } else { \
int64_t count = 1; \ int64_t count = 1; \
for (int i = 0; i < shape.size(); i++) \ for (int i = 0; i < shape.size(); i++) \
count *= shape[i]; \ count *= shape[i]; \
CHECK_EQ(count, tensor.count()) \ CHECK_EQ(count, tensor.count()) \
<< "\nExcepted Tensor(" << tensor.name() << ")'s " \ << "\nExcepted Tensor(" << tensor.name() << ")'s " \
<< "size is " << count << ", \n" \ << "size is " << count << ", \n" \
<< "but now is " << tensor.count() << ", " \ << "but now is " << tensor.count() << ", " \
<< "did you feed the incorrect data before?"; \ << "did you feed the incorrect data before?"; \
tensor.Reshape(shape); \ tensor.Reshape(shape); \
} }
/* Arguments */ /* Arguments */
......
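A hedged fragment of where a macro such as TENSOR_FILL is normally expanded: inside the typed compute method of an Operator<Context> subclass, so that ws(), ctx() and the element type T are in scope. The operator "MyDenseOp" and its members in_channels_/num_output_ are hypothetical, not part of this commit.

    template <class Context>
    template <typename T>
    void MyDenseOp<Context>::DoRunWithType() {
      auto& W = Input(1);
      // Reshape and fill W from its registered FillerInfo on the first run;
      // later runs only verify that the element count still matches.
      vector<int64_t> W_shape({in_channels_, num_output_});
      TENSOR_FILL(W, W_shape);
      // (the layer's actual computation of Output(0) would follow here)
    }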
...@@ -4,176 +4,150 @@ ...@@ -4,176 +4,150 @@
namespace dragon { namespace dragon {
vector<string> Workspace::tensors() const { Workspace::Workspace(const string& name) : name_(name) {
vector<string> locals;
// Search the local workspace
for (const auto& it : tensor_map_)
locals.push_back(it.first);
// Search the remote workspaces
for (const auto& it : external_tensor_map_) {
locals.push_back(it.first);
}
return locals;
}
vector<string> Workspace::graphs() const {
vector<string> names;
for (const auto& it : graph_map_) {
names.push_back(it.first);
}
return names;
}
void Workspace::Initialize() {
CreateTensor(""); // Empty placeholder CreateTensor(""); // Empty placeholder
CreateTensor("/share/flag/recomputing") CreateTensor("/share/flag/recomputing")
->Reshape({1}) ->Reshape({})
->mutable_data<bool, CPUContext>()[0] = false; ->mutable_data<bool, CPUContext>()[0] = false;
} }
void Workspace::Clear() { void Workspace::MergeFrom(Workspace* other) {
// Remove and Initialize again if (other != nullptr) {
tensor_map_.clear(); // Add the external tensors
Initialize(); for (const auto& it : other->tensor_map_) {
} if (!it.first.empty() && !str::startswith(it.first, "/")) {
external_tensor_map_[it.first] = it.second.get();
void Workspace::MergeFrom(Workspace* ws) { }
CHECK(ws) << "\nThe given Workspace is invalid."; }
for (const auto& it : ws->tensor_map_) { // Recount the unique index to avoid duplicate names
if (!it.first.empty() && !str::startswith(it.first, "/")) { for (const auto& i : other->unique_index_map_) {
external_tensor_map_[it.first] = it.second.get(); auto& index_map = unique_index_map_[i.first];
for (const auto& j : i.second) {
index_map[j.first] = std::max(index_map[j.first], j.second);
}
} }
} }
} }
string Workspace::GetTensorName(const string& name) const { Tensor* Workspace::TryGetTensor(const string& name, bool external) const {
const auto& it = alias_active_map_.find(name); // Check the alias first
if (it != alias_active_map_.end()) return it->second; const auto& alias_it = alias_map_.find(name);
return name; auto name_v2 = alias_it != alias_map_.end() ? alias_it->second : name;
} // Search this workspace
const auto& it = tensor_map_.find(name_v2);
Tensor* Workspace::TryGetTensor(const string& name, bool use_remote) const {
// Check the proxy of this tensor firstly
string query = GetTensorName(name);
// Search the local workspace
const auto& it = tensor_map_.find(query);
if (it != tensor_map_.end()) return it->second.get(); if (it != tensor_map_.end()) return it->second.get();
if (external) {
if (use_remote) { // Search external workspaces
// Search the remote workspaces const auto& it = external_tensor_map_.find(name_v2);
const auto& it = external_tensor_map_.find(query);
if (it != external_tensor_map_.end()) return it->second; if (it != external_tensor_map_.end()) return it->second;
} }
return nullptr; return nullptr;
} }
Tensor* Workspace::CreateTensor(const string& name) { Tensor* Workspace::CreateTensor(const string& name, FillerInfo* filler) {
Tensor* tensor = TryGetTensor(name); auto* tensor = TryGetTensor(name);
if (!tensor) { // Create only if the name does not exist
tensor_map_[name] = unique_ptr<Tensor>(new Tensor(name)); if (tensor == nullptr) {
return tensor_map_[name].get(); tensor = new Tensor(name);
tensor_map_[name] = unique_ptr<Tensor>(tensor);
}
// Maybe bind it with a filler
if (filler != nullptr) {
filler_map_[tensor->name()] = std::move(FillerInfo(*filler));
} }
return tensor; return tensor;
} }
Tensor* Workspace::GetTensor(const string& name, bool use_remote) const { Tensor* Workspace::GetTensor(const string& name, bool external) const {
Tensor* tensor = TryGetTensor(name, use_remote); auto* tensor = TryGetTensor(name, external);
CHECK(tensor) << "\nTensor(" << name << ") does not " CHECK(tensor) << "\nTensor(" << name << ") is not in current workspace.";
<< "exist in current workspace.";
return tensor; return tensor;
} }
void Workspace::ResetTensor(const string& name) { void Workspace::ResetTensor(const string& name) {
Tensor* tensor = TryGetTensor(name, false); auto* tensor = TryGetTensor(name, false);
CHECK(tensor) << "\nTensor(" << name << ") does not " CHECK(tensor) << "\nTensor(" << name << ") is not in current workspace.";
<< "belong to current workspace.";
tensor->Reset(); tensor->Reset();
} }
bool Workspace::HasFiller(const string& name) const { FillerInfo* Workspace::GetFillerInfo(const string& name) {
return tensor_filler_map_.count(name) > 0; const auto& it = filler_map_.find(name);
} if (it != filler_map_.end()) return &it->second;
void Workspace::CreateFiller(const TensorFillerProto& filler) {
CHECK_GT(filler.tensor().size(), 0)
<< "\nTensor with an empty name can not be filled.";
if (HasFiller(filler.tensor())) return;
tensor_filler_map_[filler.tensor()] = filler;
}
TensorFillerProto* Workspace::GetFiller(const string& name) {
const auto& it = tensor_filler_map_.find(name);
if (it != tensor_filler_map_.end()) return &it->second;
return nullptr; return nullptr;
} }
OperatorBase* Workspace::CreateOperator(const OperatorDef& def) {
const auto& it = operator_map_.find(def.cache_key());
if (it == operator_map_.end()) {
auto* new_op = NewOperator(def, this);
operator_map_[def.cache_key()] = unique_ptr<OperatorBase>(new_op);
return new_op;
}
return it->second.get();
}
void Workspace::RunOperator(const OperatorDef& def) { void Workspace::RunOperator(const OperatorDef& def) {
if (def.has_cache_key()) { if (def.has_cache_key()) {
CreateOperator(def)->UpdateFrom(def)->Run(0); OperatorBase* cached_op = nullptr;
const auto& it = operator_map_.find(def.cache_key());
if (it == operator_map_.end()) {
cached_op = NewOperator(def, this);
operator_map_[def.cache_key()] = unique_ptr<OperatorBase>(cached_op);
} else {
cached_op = it->second.get();
}
cached_op->UpdateFrom(def)->Run();
} else { } else {
unique_ptr<OperatorBase> op(NewOperator(def, this)); OperatorBase* temporal_op = NewOperator(def, this);
op->Run(0); temporal_op->Run();
delete temporal_op;
} }
} }
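To make the caching rule above concrete, a small sketch: a def without a cache_key takes the temporary path, while a def carrying a cache_key is created once and then reused through UpdateFrom(). The op type and tensor names are placeholders.

    using namespace dragon;

    void RunReluTwice(Workspace* ws) {
      OperatorDef op_def;
      op_def.set_type("Relu");            // placeholder op type
      op_def.add_input("x");
      op_def.add_output("y");
      ws->RunOperator(op_def);            // no cache_key: a temporary op is created and deleted
      op_def.set_cache_key("relu/x->y");  // any stable key
      ws->RunOperator(op_def);            // created once and cached under the key
      ws->RunOperator(op_def);            // reused via UpdateFrom()
    }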
GraphBase* Workspace::CreateGraph(const GraphDef& def) { GraphBase* Workspace::CreateGraph(const GraphDef& def) {
CHECK(def.has_name()) << "\nGraph name is missing."; CHECK(def.has_name()) << "\nExpected a non-empty graph name.";
auto name = GetDummyName(def.name(), "", "Graph", false); GraphDef def_v2(def); // Copy to set a unique name
LOG(DEBUG) << "Create Graph: " << name << "(" << def.name() << ")"; def_v2.set_name(UniqueName(def.name(), "", "Graph", false));
GraphDef _def(def); LOG(DEBUG) << "Create Graph: " << def_v2.name() << "(" << def.name() << ")";
_def.set_name(name); auto* cached_graph = NewGraph(def_v2, this);
graph_map_[name] = unique_ptr<GraphBase>(NewGraph(_def, this)); graph_map_[def_v2.name()] = unique_ptr<GraphBase>(cached_graph);
return graph_map_[name].get(); return cached_graph;
} }
void Workspace::RunGraph( void Workspace::RunGraph(
const string& graph_name, const string& name,
const string& incl, const string& include,
const string& excl, const string& exclude,
int stream_id) { const int stream) {
if (!graph_map_.count(graph_name)) { CHECK(graph_map_.count(name))
LOG(FATAL) << "Graph(" << graph_name << ") does not exist."; << "\nGraph(" << name << ") is not in current workspace.";
} graph_map_[name]->Run(include, exclude, stream);
graph_map_[graph_name]->Run(incl, excl, stream_id);
} }
bool Workspace::ActivateAlias(const string& name, const string& alias) { void Workspace::RegisterAlias(const string& target, const string& alias) {
bool status = alias_active_map_.count(alias) > 0; alias_map_[alias] = target;
alias_active_map_[alias] = name;
return status; // True if activated otherwise false
} }
string Workspace::GetDummyName( string Workspace::UniqueName(
const string& base_name, const string& name,
const string& suffix, const string& suffix,
const string& domain, const string& scope,
bool zero_based) { bool zero_based) {
string accepted_name; auto& index_map = unique_index_map_[scope];
int64_t index; auto required_name = name + suffix;
const auto required_name = base_name + suffix; auto index = index_map[required_name]++;
auto& dmap = dummy_name_map_[domain]; if (index > 0) return name + "_" + str::to(index) + suffix;
while (1) { if (zero_based) return required_name;
index = dmap[required_name]++; return name + "_" + str::to(index_map[required_name]++) + suffix;
accepted_name = index ? base_name + "_" + str::to(index) + suffix }
: zero_based
? required_name vector<string> Workspace::tensors() const {
: base_name + "_" + str::to(dmap[required_name]++) + suffix; vector<string> names;
if (external_tensor_map_.empty()) break; for (const auto& it : tensor_map_) {
if (!HasTensor(accepted_name)) break; names.push_back(it.first);
} }
return accepted_name; for (const auto& it : external_tensor_map_) {
names.push_back(it.first);
}
return names;
}
vector<string> Workspace::graphs() const {
vector<string> names;
for (const auto& it : graph_map_) {
names.push_back(it.first);
}
return names;
} }
} // namespace dragon } // namespace dragon
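For clarity, the sequence produced by the UniqueName() implementation above on a fresh workspace; zero_based controls whether the very first request keeps its bare form. The names, suffix and scope strings below are arbitrary.

    #include "dragon/core/workspace.h"

    void UniqueNameDemo() {
      dragon::Workspace ws("demo");
      ws.UniqueName("graph", "", "Graph", false);   // -> "graph_1"
      ws.UniqueName("graph", "", "Graph", false);   // -> "graph_2"
      ws.UniqueName("data", ":0", "Tensor", true);  // -> "data:0"
      ws.UniqueName("data", ":0", "Tensor", true);  // -> "data_1:0"
    }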
...@@ -20,83 +20,63 @@ namespace dragon { ...@@ -20,83 +20,63 @@ namespace dragon {
class Workspace { class Workspace {
public: public:
/*! \brief Constructor */ /*! \brief Constructor */
explicit Workspace(const string& name) : name_(name) { DRAGON_API explicit Workspace(const string& name);
Initialize();
}
/*! \brief Create some internal tensors */
DRAGON_API void Initialize();
/*! \brief Merge tensors from an external workspace */ /*! \brief Merge resources from another workspace */
DRAGON_API void MergeFrom(Workspace*); DRAGON_API void MergeFrom(Workspace*);
/*! \brief Destroy all the tensors */ /* \brief Return a unique name */
DRAGON_API void Clear(); DRAGON_API string UniqueName(
const string& name,
/* \brief Return a unique dummy name within this workspace */
DRAGON_API string GetDummyName(
const string& base_name,
const string& suffix, const string& suffix,
const string& domain = "", const string& scope = "",
bool zero_based = true); const bool zero_based = false);
/*! \brief Whether the specified tensor is in this workspace */ /* \brief Register an alias for the target */
DRAGON_API bool HasTensor(const string& name, bool use_remote = true) const { DRAGON_API void RegisterAlias(const string& target, const string& alias);
return TryGetTensor(name, use_remote) ? true : false;
}
/*! \brief Query the real name of the specified tensor */ /*! \brief Return whether the tensor exists */
DRAGON_API string GetTensorName(const string&) const; DRAGON_API bool HasTensor(const string& name, bool external = true) const {
return TryGetTensor(name, external) == nullptr ? false : true;
/* \brief Activate an alias for the target */ }
DRAGON_API bool ActivateAlias(const string& name, const string& alias);
/*! \brief Create a tensor in this workspace */ /*! \brief Create the tensor */
DRAGON_API Tensor* CreateTensor(const string&); DRAGON_API Tensor* CreateTensor(const string&, FillerInfo* = nullptr);
/*! \brief Try to search the specified tensor in this workspace */ /*! \brief Try to return the tensor */
DRAGON_API Tensor* TryGetTensor(const string&, bool = true) const; DRAGON_API Tensor* TryGetTensor(const string&, bool = true) const;
/*! \brief Return the specified tensor */ /*! \brief Return the tensor */
DRAGON_API Tensor* GetTensor(const string&, bool = true) const; DRAGON_API Tensor* GetTensor(const string&, bool = true) const;
/*! \brief Reset the specified tensor */ /*! \brief Reset the tensor */
DRAGON_API void ResetTensor(const string&); DRAGON_API void ResetTensor(const string&);
/* \brief Whether the specified filler is existing */ /*! \brief Return the filler info */
DRAGON_API bool HasFiller(const string&) const; DRAGON_API FillerInfo* GetFillerInfo(const string&);
/*! \brief Create a filler in this workspace */
DRAGON_API void CreateFiller(const TensorFillerProto&);
/*! \brief Return the specified filler */ /*! \brief Run the operator */
DRAGON_API TensorFillerProto* GetFiller(const string&);
/*! \brief Create an operator in this workspace */
DRAGON_API OperatorBase* CreateOperator(const OperatorDef&);
/*! \brief Run an operator in this workspace */
DRAGON_API void RunOperator(const OperatorDef&); DRAGON_API void RunOperator(const OperatorDef&);
/*! \brief Create a graph in this workspace */ /*! \brief Create the graph */
DRAGON_API GraphBase* CreateGraph(const GraphDef&); DRAGON_API GraphBase* CreateGraph(const GraphDef&);
/*! \brief Run the specified graph by name and rules */ /*! \brief Run the graph */
DRAGON_API void RunGraph( DRAGON_API void RunGraph(
const string& graph_name, const string& graph_name,
const string& incl = "", const string& include = "",
const string& excl = "", const string& exclude = "",
int stream_id = 0); const int stream = 0);
/*! \brief Return the name of this workspace */ /*! \brief Return the workspace name */
const string& name() { const string& name() {
return name_; return name_;
} }
/*! \brief Return the name of stored tensors */ /*! \brief Return the name of cached tensors */
DRAGON_API vector<string> tensors() const; DRAGON_API vector<string> tensors() const;
/*! \brief Return the name of stored graphs */ /*! \brief Return the name of cached graphs */
DRAGON_API vector<string> graphs() const; DRAGON_API vector<string> graphs() const;
/*! \brief Provide a group of the shared byte data */ /*! \brief Provide a group of the shared byte data */
...@@ -127,28 +107,28 @@ class Workspace { ...@@ -127,28 +107,28 @@ class Workspace {
} }
private: private:
/*! \brief The unique workspace name */ /*! \brief The workspace name */
string name_; string name_;
/*! \brief The dummy name indices */ /*! \brief The external tensors */
Map<string, Map<string, int64_t>> dummy_name_map_; Map<string, Tensor*> external_tensor_map_;
/*! \brief Store the created tensors */ /*! \brief The unique indices */
Map<string, unique_ptr<Tensor>> tensor_map_; Map<string, Map<string, int64_t>> unique_index_map_;
/*! \brief Store the external tensors */ /*! \brief The registered fillers */
Map<string, Tensor*> external_tensor_map_; Map<string, FillerInfo> filler_map_;
/*! \brief Store the registered tensor fillers */ /*! \brief The registered aliases */
Map<string, TensorFillerProto> tensor_filler_map_; Map<string, string> alias_map_;
/*! \brief Store the active aliases */ /*! \brief The cached tensors */
Map<string, string> alias_active_map_; Map<string, unique_ptr<Tensor>> tensor_map_;
/*! \brief Store the registered operators for dynamic graph */ /*! \brief The cached operators */
Map<string, unique_ptr<OperatorBase>> operator_map_; Map<string, unique_ptr<OperatorBase>> operator_map_;
/*! \brief Store the registered graphs for static graph */ /*! \brief The cached graphs */
Map<string, unique_ptr<GraphBase>> graph_map_; Map<string, unique_ptr<GraphBase>> graph_map_;
}; };
......
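A short sketch of the alias mechanism declared above: RegisterAlias() records alias -> target, and TryGetTensor()/GetTensor() resolve the alias before searching the tensor maps. The names are placeholders.

    void AliasDemo() {
      dragon::Workspace ws("demo");
      ws.CreateTensor("conv1/W");
      ws.RegisterAlias("conv1/W", "weights");
      auto* w1 = ws.GetTensor("conv1/W");
      auto* w2 = ws.GetTensor("weights");  // resolved through the alias map: w1 == w2
    }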
...@@ -425,7 +425,7 @@ void PReluWGrad<float16, CUDAContext>( ...@@ -425,7 +425,7 @@ void PReluWGrad<float16, CUDAContext>(
CUDA_THREADS, CUDA_THREADS,
0, 0,
ctx->cuda_stream()>>>( ctx->cuda_stream()>>>(
N * C * S, N * S,
C, C,
S, S,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
...@@ -437,7 +437,7 @@ void PReluWGrad<float16, CUDAContext>( ...@@ -437,7 +437,7 @@ void PReluWGrad<float16, CUDAContext>(
CUDA_THREADS, CUDA_THREADS,
0, 0,
ctx->cuda_stream()>>>( ctx->cuda_stream()>>>(
N * C * S, N * S,
C, C,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
...@@ -536,13 +536,13 @@ void PReluWGrad<float16, CUDAContext>( ...@@ -536,13 +536,13 @@ void PReluWGrad<float16, CUDAContext>(
CUDA_2D_BLOCKS(C), \ CUDA_2D_BLOCKS(C), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>(N * C * S, C, S, dy, x, dw); \ ctx->cuda_stream()>>>(N * S, C, S, dy, x, dw); \
} else if (data_format == "NHWC") { \ } else if (data_format == "NHWC") { \
_PReluWGradNHWC<<< \ _PReluWGradNHWC<<< \
CUDA_2D_BLOCKS(C), \ CUDA_2D_BLOCKS(C), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>(N * C * S, C, dy, x, dw); \ ctx->cuda_stream()>>>(N * S, C, dy, x, dw); \
} else { \ } else { \
LOG(FATAL) << "Unknown data format: " << data_format; \ LOG(FATAL) << "Unknown data format: " << data_format; \
} \ } \
......
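For reference, the quantity the NCHW weight-gradient kernel reduces per channel, written as a plain CPU loop under the standard PReLU definition (y = x for x > 0, y = w[c] * x otherwise); the first launch argument above is exactly that per-channel reduction size, N * S. This is an illustrative sketch, not code from this file.

    void PReluWGradNCHWRef(int N, int C, int S,
                           const float* dy, const float* x, float* dw) {
      for (int c = 0; c < C; ++c) {
        float acc = 0.f;
        for (int n = 0; n < N; ++n)
          for (int s = 0; s < S; ++s) {
            const int i = (n * C + c) * S + s;
            acc += x[i] > 0.f ? 0.f : dy[i] * x[i];
          }
        dw[c] = acc;
      }
    }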
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/modules/python/types.h" #include "dragon/modules/python/types.h"
#include "dragon/onnx/onnx_backend.h" #include "dragon/onnx/onnx_backend.h"
#include "dragon/utils/caffemodel.h"
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
......
...@@ -74,7 +74,7 @@ class DLPackWrapper { ...@@ -74,7 +74,7 @@ class DLPackWrapper {
} }
Tensor* From(py::object obj) { Tensor* From(py::object obj) {
CHECK(PyCapsule_CheckExact(obj.ptr())) << "\nExpected DLPack capsule"; CHECK(PyCapsule_CheckExact(obj.ptr())) << "\nExpected DLPack capsule.";
auto* managed_tensor = auto* managed_tensor =
(DLManagedTensor*)PyCapsule_GetPointer(obj.ptr(), "dltensor"); (DLManagedTensor*)PyCapsule_GetPointer(obj.ptr(), "dltensor");
CHECK(managed_tensor) << "\nInvalid DLPack capsule"; CHECK(managed_tensor) << "\nInvalid DLPack capsule";
......
...@@ -44,48 +44,38 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -44,48 +44,38 @@ PYBIND11_MODULE(libdragon_python, m) {
/*! \brief Return the name of stored graphs */ /*! \brief Return the name of stored graphs */
.def_property_readonly("graphs", &Workspace::graphs) .def_property_readonly("graphs", &Workspace::graphs)
/*! \brief Destroy all the tensors */ /*! \brief Merge resources from another workspace */
.def("Clear", &Workspace::Clear)
/*! \brief Merge an external workspace into self */
.def("MergeFrom", &Workspace::MergeFrom) .def("MergeFrom", &Workspace::MergeFrom)
/*! \brief Return a unique dummy name */ /*! \brief Return a unique name */
.def("GetDummyName", &Workspace::GetDummyName) .def("UniqueName", &Workspace::UniqueName)
/*! \brief Return the unique name of given tensor */
.def("GetTensorName", &Workspace::GetTensorName)
/*! \brief Reset a tensor with the given name */ /*! \brief Reset the tensor */
.def("ResetTensor", &Workspace::ResetTensor) .def("ResetTensor", &Workspace::ResetTensor)
/*! \brief Indicate whether the given tensor exists */ /*! \brief Return whether the tensor exists */
.def( .def(
"HasTensor", "HasTensor",
[](Workspace* self, const string& name) { [](Workspace* self, const string& name) {
return self->HasTensor(name); return self->HasTensor(name);
}) })
/*! \brief Create a tensor with the given name */ /*! \brief Create the tensor */
.def( .def(
"CreateTensor", "CreateTensor",
[](Workspace* self, const string& name) { [](Workspace* self, const string& name, const string& filler_str) {
if (!filler_str.empty()) {
FillerInfo filler_info;
if (!filler_info.ParseFromString(filler_str)) {
LOG(FATAL) << "Failed to parse the FillerInfo.";
}
return self->CreateTensor(name, &filler_info);
}
return self->CreateTensor(name); return self->CreateTensor(name);
}, },
py::return_value_policy::reference_internal) py::return_value_policy::reference_internal)
/*! \brief Create a tensor from the specified filler */ /*! \brief Return the tensor */
.def(
"CreateFiller",
[](Workspace* self, const string& serialized) {
TensorFillerProto filler_proto;
if (!filler_proto.ParseFromString(serialized))
LOG(FATAL) << "Failed to parse the TensorFiller.";
self->CreateFiller(filler_proto);
self->CreateTensor(filler_proto.tensor());
})
/*! \brief Return the CXX Tensor reference */
.def( .def(
"GetTensor", "GetTensor",
[](Workspace* self, const string& name) { [](Workspace* self, const string& name) {
...@@ -93,11 +83,11 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -93,11 +83,11 @@ PYBIND11_MODULE(libdragon_python, m) {
}, },
py::return_value_policy::reference_internal) py::return_value_policy::reference_internal)
/* \brief Set an alias for the tensor */ /* \brief Register an alias for the name */
.def( .def(
"SetTensorAlias", "RegisterAlias",
[](Workspace* self, const string& name, const string& alias) { [](Workspace* self, const string& name, const string& alias) {
return self->ActivateAlias(name, alias); return self->RegisterAlias(name, alias);
}) })
/*! \brief Copy the array data to tensor */ /*! \brief Copy the array data to tensor */
...@@ -118,7 +108,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -118,7 +108,7 @@ PYBIND11_MODULE(libdragon_python, m) {
dev, reinterpret_cast<PyArrayObject*>(value.ptr()), tensor); dev, reinterpret_cast<PyArrayObject*>(value.ptr()), tensor);
}) })
/*! \brief Copy the tensor data to the array */ /*! \brief Copy the tensor data to array */
.def( .def(
"FetchTensor", "FetchTensor",
[](Workspace* self, const string& name) { [](Workspace* self, const string& name) {
...@@ -142,7 +132,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -142,7 +132,7 @@ PYBIND11_MODULE(libdragon_python, m) {
} }
}) })
/*! \brief Run an operator from the def reference */ /*! \brief Run the operator */
.def( .def(
"RunOperator", "RunOperator",
[](Workspace* self, OperatorDef* def, const bool verbose) { [](Workspace* self, OperatorDef* def, const bool verbose) {
...@@ -156,7 +146,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -156,7 +146,7 @@ PYBIND11_MODULE(libdragon_python, m) {
self->RunOperator(*def); self->RunOperator(*def);
}) })
/*! \brief Run operators from the def reference */ /*! \brief Run the operators */
.def( .def(
"RunOperator", "RunOperator",
[](Workspace* self, vector<OperatorDef*>& defs, const bool verbose) { [](Workspace* self, vector<OperatorDef*>& defs, const bool verbose) {
...@@ -172,7 +162,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -172,7 +162,7 @@ PYBIND11_MODULE(libdragon_python, m) {
} }
}) })
/*! \brief Run an operator from the serialized def */ /*! \brief Run the operator from the serialized def */
.def( .def(
"RunOperator", "RunOperator",
[](Workspace* self, const string& serialized, const bool verbose) { [](Workspace* self, const string& serialized, const bool verbose) {
...@@ -188,7 +178,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -188,7 +178,7 @@ PYBIND11_MODULE(libdragon_python, m) {
self->RunOperator(def); self->RunOperator(def);
}) })
/*! \brief Create a graph from the serialized def */ /*! \brief Create the graph */
.def( .def(
"CreateGraph", "CreateGraph",
[](Workspace* self, const string& serialized, const bool verbose) { [](Workspace* self, const string& serialized, const bool verbose) {
...@@ -213,89 +203,49 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -213,89 +203,49 @@ PYBIND11_MODULE(libdragon_python, m) {
return graph->name(); return graph->name();
}) })
/*! \brief Run an existing graph */ /*! \brief Run the graph */
.def( .def(
"RunGraph", "RunGraph",
[](Workspace* self, [](Workspace* self,
const string& name, const string& name,
const string& incl, const string& include,
const string& excl) { const string& exclude) {
py::gil_scoped_release g; py::gil_scoped_release g;
self->RunGraph(name, incl, excl); self->RunGraph(name, include, exclude);
}) })
/*! \brief Run the backward */
.def( .def(
"RunBackward", "RunBackward",
[](Workspace* self, [](Workspace* self,
const vector<OperatorDef*>& forward_ops, const vector<OperatorDef*>& op_defs,
const vector<string>& targets, const vector<string>& targets,
const vector<string>& sources, const vector<string>& sources,
const vector<string>& input_grads, const vector<string>& input_grads,
const vector<string>& ignored_grads, const vector<string>& empty_grads,
const bool is_sharing, const bool retain_grads,
const bool verbose) { const bool verbose) {
GraphDef backward_ops; GraphDef graph_def;
GraphGradientMaker maker; GraphGradientMaker maker;
for (const auto& name : ignored_grads) { for (const auto& name : empty_grads) {
maker.add_ignored_grad(name); maker.add_empty_grad(name);
} }
for (const auto& name : sources) { for (const auto& name : sources) {
maker.add_hooked_grad(name + "_grad"); maker.add_retained_grad(name + "_grad");
} }
maker.Make(forward_ops, targets, input_grads, backward_ops); maker.Make(op_defs, targets, input_grads, graph_def);
py::gil_scoped_release g; py::gil_scoped_release g;
if (is_sharing) { if (!retain_grads) {
backward_ops = maker.Share(backward_ops); graph_def = maker.Share(graph_def);
} }
for (const auto& def : backward_ops.op()) { for (const auto& op_def : graph_def.op()) {
if (verbose) { if (verbose) {
auto msg = string("\n") + def.DebugString(); auto msg = string("\n") + op_def.DebugString();
msg.pop_back(); msg.pop_back();
PRINT(INFO) PRINT(INFO)
<< "op {" << str::replace_all(msg, "\n", "\n ") << "\n}\n"; << "op {" << str::replace_all(msg, "\n", "\n ") << "\n}\n";
} }
self->RunOperator(def); self->RunOperator(op_def);
}
})
/*! \brief Serialize tensors into a binary file */
.def(
"Save",
[](Workspace* self,
const string& filename,
const vector<string>& tensors,
const int format) {
vector<Tensor*> refs;
switch (format) {
case 0: // Pickle
LOG(FATAL) << "Format depends on Pickle. "
<< "Can't be used in C++.";
break;
case 1: // CaffeModel
for (const auto& name : tensors) {
refs.emplace_back(self->GetTensor(name));
}
SavaCaffeModel(filename, refs);
break;
default:
LOG(FATAL) << "Unknown format, code: " << format;
}
})
/*! \brief Load tensors from a binary file */
.def(
"Load",
[](Workspace* self, const string& filename, const int format) {
switch (format) {
case 0: // Pickle
LOG(FATAL) << "Format depends on Pickle. "
<< "Can't be used in C++.";
break;
case 1: // CaffeModel
LoadCaffeModel(filename, self);
break;
default:
LOG(FATAL) << "Unknown format, code: " << format;
} }
}) })
......
...@@ -20,7 +20,6 @@ PythonPluginInferOp<Context>::PythonPluginInferOp( ...@@ -20,7 +20,6 @@ PythonPluginInferOp<Context>::PythonPluginInferOp(
class_name_(OpArg<string>("class_name", "")), class_name_(OpArg<string>("class_name", "")),
kwargs_str_((OpArg<string>("kwargs_str", ""))) { kwargs_str_((OpArg<string>("kwargs_str", ""))) {
// Optimization for all python ops // Optimization for all python ops
if (!allow_run()) return;
this->do_sync_ = false; this->do_sync_ = false;
// Initialize interpreter and load module // Initialize interpreter and load module
......
...@@ -24,6 +24,9 @@ namespace tensor { ...@@ -24,6 +24,9 @@ namespace tensor {
void RegisterModule(py::module& m) { void RegisterModule(py::module& m) {
/*! \brief Export the Tensor class */ /*! \brief Export the Tensor class */
py::class_<Tensor>(m, "Tensor") py::class_<Tensor>(m, "Tensor")
/*! \brief Return the tensor name */
.def_property_readonly("name", &Tensor::name)
/*! \brief Return the number of dimensions */ /*! \brief Return the number of dimensions */
.def_property_readonly("ndim", &Tensor::ndim) .def_property_readonly("ndim", &Tensor::ndim)
...@@ -46,9 +49,9 @@ void RegisterModule(py::module& m) { ...@@ -46,9 +49,9 @@ void RegisterModule(py::module& m) {
"device", "device",
[](Tensor* self) { [](Tensor* self) {
if (self->has_memory()) { if (self->has_memory()) {
auto mem_info = self->memory()->info(); auto info = self->memory()->info();
return std::tuple<string, int>( return std::tuple<string, int>(
mem_info["device_type"], atoi(mem_info["device_id"].c_str())); info["device_type"], atoi(info["device_id"].c_str()));
} else { } else {
return std::tuple<string, int>("Unknown", 0); return std::tuple<string, int>("Unknown", 0);
} }
......
...@@ -119,8 +119,6 @@ DRAGON_API void DestroyGraphDef(GraphDef_t graph_def); ...@@ -119,8 +119,6 @@ DRAGON_API void DestroyGraphDef(GraphDef_t graph_def);
* Model API * Model API
*/ */
DRAGON_API void LoadCaffeModel(const std::string& model_file, Workspace_t ws);
DRAGON_API void LoadONNXModel( DRAGON_API void LoadONNXModel(
const std::string& model_file, const std::string& model_file,
GraphDef_t init_graph, GraphDef_t init_graph,
......
#include "dragon/core/common.h" #include "dragon/core/common.h"
#include "dragon/modules/runtime/dragon_runtime.h" #include "dragon/modules/runtime/dragon_runtime.h"
#include "dragon/onnx/onnx_backend.h" #include "dragon/onnx/onnx_backend.h"
#include "dragon/utils/caffemodel.h"
#include "dragon/utils/proto_utils.h" #include "dragon/utils/proto_utils.h"
namespace dragon { namespace dragon {
...@@ -161,46 +160,6 @@ DRAGON_API void DestroyGraphDef(GraphDef_t graph_def) { ...@@ -161,46 +160,6 @@ DRAGON_API void DestroyGraphDef(GraphDef_t graph_def) {
if (graph_def) delete graph_def; if (graph_def) delete graph_def;
} }
void LoadCaffeModel(const string& model_file, Workspace_t ws) {
NetParameter net_param;
ReadProtoFromBinaryFile(model_file.c_str(), &net_param);
std::string scope = "";
LOG(INFO) << "Load Model: " << model_file << "......";
LOG(INFO) << "Format: Caffe";
for (int i = 0; i < net_param.layer_size(); i++) {
const LayerParameter& layer = net_param.layer(i);
const string& layer_name = layer.name();
string prefix = scope + layer_name + "/param:";
for (int j = 0; j < layer.blobs_size(); j++) {
string tensor_name = prefix + std::to_string(j);
if (!ws->HasTensor(tensor_name)) ws->CreateTensor(tensor_name);
BlobProto blob = layer.blobs(j);
vector<int64_t> dims;
for (auto dim : blob.shape().dim())
dims.push_back(dim);
Tensor* tensor = ws->GetTensor(tensor_name);
std::stringstream DimString;
if (dims.size() > 0) {
tensor->Reshape(dims);
CHECK_EQ(tensor->count(), blob.data_size())
<< "Tensor(" << tensor_name << ") "
<< "failed to load, except size: " << tensor->count()
<< ", loaded " << blob.data_size();
DimString << tensor->DimString();
} else {
tensor->Reshape(vector<int64_t>(1, blob.data_size()));
DimString << "(missing)";
}
float* Xdata = tensor->mutable_data<float, CPUContext>();
for (int idx = 0; idx < blob.data_size(); idx++)
Xdata[idx] = blob.data(idx);
LOG(INFO) << "Tensor(" << tensor_name << ") "
<< "loaded, shape: " << DimString.str()
<< ", size: " << blob.data_size();
}
}
}
void LoadONNXModel( void LoadONNXModel(
const string& model_file, const string& model_file,
GraphDef_t init_graph, GraphDef_t init_graph,
......
...@@ -19,7 +19,6 @@ ONNXImporterReturns ONNXBackend::ArgReduceImporter( ...@@ -19,7 +19,6 @@ ONNXImporterReturns ONNXBackend::ArgReduceImporter(
auto node = NodeProto(onnx_node->node); auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node); auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes; auto& attributes = onnx_node_v2.attributes;
// Determine the operation // Determine the operation
auto* operation = attributes.AddRewrittenAttribute("operation"); auto* operation = attributes.AddRewrittenAttribute("operation");
if (onnx_node->node.op_type() == "ArgMax") { if (onnx_node->node.op_type() == "ArgMax") {
...@@ -27,7 +26,6 @@ ONNXImporterReturns ONNXBackend::ArgReduceImporter( ...@@ -27,7 +26,6 @@ ONNXImporterReturns ONNXBackend::ArgReduceImporter(
} else if (onnx_node->node.op_type() == "ArgMin") { } else if (onnx_node->node.op_type() == "ArgMin") {
operation->set_s("MIN"); operation->set_s("MIN");
} }
return GenericImporter(&onnx_node_v2, ctx); return GenericImporter(&onnx_node_v2, ctx);
} }
...@@ -37,17 +35,13 @@ ONNXImporterReturns ONNXBackend::ATenImporter( ...@@ -37,17 +35,13 @@ ONNXImporterReturns ONNXBackend::ATenImporter(
auto node = NodeProto(onnx_node->node); auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node); auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes; auto& attributes = onnx_node_v2.attributes;
auto op_type = attributes.get<string>("op_type", ""); auto op_type = attributes.get<string>("op_type", "");
if (op_type.empty()) { if (op_type.empty()) {
LOG(FATAL) << "op_type is required to evolve " LOG(FATAL) << "op_type is required to evolve "
<< "to the specific operator."; << "to the specific operator.";
} }
node.set_op_type(op_type); node.set_op_type(op_type);
attributes.remove("op_type"); attributes.remove("op_type");
return GenericImporter(&onnx_node_v2, ctx); return GenericImporter(&onnx_node_v2, ctx);
} }
...@@ -56,17 +50,13 @@ ONNXImporterReturns ONNXBackend::BatchNormImporter( ...@@ -56,17 +50,13 @@ ONNXImporterReturns ONNXBackend::BatchNormImporter(
const ConversionContext& ctx) { const ConversionContext& ctx) {
auto node = NodeProto(onnx_node->node); auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node); auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes; auto& attributes = onnx_node_v2.attributes;
// Enforce to NCHW format // Enforce to NCHW format
attributes.AddRewrittenAttribute("axis")->set_i(1); attributes.AddRewrittenAttribute("axis")->set_i(1);
// Remove dummy attributes // Remove dummy attributes
attributes.remove("consumed_inputs"); attributes.remove("consumed_inputs");
attributes.remove("is_test"); attributes.remove("is_test");
attributes.remove("spatial"); attributes.remove("spatial");
return GenericImporter(&onnx_node_v2, ctx); return GenericImporter(&onnx_node_v2, ctx);
} }
...@@ -74,12 +64,10 @@ ONNXImporterReturns ONNXBackend::CastImporter( ...@@ -74,12 +64,10 @@ ONNXImporterReturns ONNXBackend::CastImporter(
ONNXNode* onnx_node, ONNXNode* onnx_node,
const ConversionContext& ctx) { const ConversionContext& ctx) {
auto& attributes = onnx_node->attributes; auto& attributes = onnx_node->attributes;
// Determine the dtype // Determine the dtype
auto* dtype = attributes.AddRewrittenAttribute("dtype"); auto* dtype = attributes.AddRewrittenAttribute("dtype");
auto onnx_dtype = attributes.get<int64_t>("to", TensorProto::UNDEFINED); auto onnx_dtype = attributes.get<int64_t>("to", TensorProto::UNDEFINED);
auto supported_dtype = true; auto supported_dtype = true;
switch (onnx_dtype) { switch (onnx_dtype) {
case ONNX_NAMESPACE::TensorProto::BOOL: case ONNX_NAMESPACE::TensorProto::BOOL:
dtype->set_s("bool"); dtype->set_s("bool");
...@@ -138,11 +126,9 @@ ONNXImporterReturns ONNXBackend::CastImporter( ...@@ -138,11 +126,9 @@ ONNXImporterReturns ONNXBackend::CastImporter(
supported_dtype = false; supported_dtype = false;
break; break;
}; };
CHECK(supported_dtype) << "\nCasting to " << dtype->s() CHECK(supported_dtype) << "\nCasting to " << dtype->s()
<< " is not supported."; << " is not supported.";
attributes.remove("to"); attributes.remove("to");
return GenericImporter(onnx_node, ctx); return GenericImporter(onnx_node, ctx);
} }
...@@ -151,17 +137,16 @@ ONNXImporterReturns ONNXBackend::ConvPoolImporter( ...@@ -151,17 +137,16 @@ ONNXImporterReturns ONNXBackend::ConvPoolImporter(
const ConversionContext& ctx) { const ConversionContext& ctx) {
auto& attributes = onnx_node->attributes; auto& attributes = onnx_node->attributes;
const auto onnx_op_type = onnx_node->node.op_type(); const auto onnx_op_type = onnx_node->node.op_type();
// Determine the padding // Determine the padding
auto mode = attributes.get<string>("auto_pad"); auto mode = attributes.get<string>("auto_pad");
auto* padding = attributes.AddRewrittenAttribute("padding"); auto* padding = attributes.AddRewrittenAttribute("padding");
// SAME, SAME_LOWER, or SAME_UPPER // SAME, SAME_LOWER, or SAME_UPPER
if (str::find(mode, "SAME")) if (str::find(mode, "SAME")) {
padding->set_s(mode); padding->set_s(mode);
else } else {
padding->set_s("VALID"); // Use explicit pads padding->set_s("VALID"); // Use explicit pads
}
attributes.remove("auto_pad"); attributes.remove("auto_pad");
// Determine the pooling mode // Determine the pooling mode
if (onnx_op_type == "MaxPool") { if (onnx_op_type == "MaxPool") {
attributes.AddRewrittenAttribute("mode")->set_s("MAX"); attributes.AddRewrittenAttribute("mode")->set_s("MAX");
...@@ -174,14 +159,11 @@ ONNXImporterReturns ONNXBackend::ConvPoolImporter( ...@@ -174,14 +159,11 @@ ONNXImporterReturns ONNXBackend::ConvPoolImporter(
attributes.AddRewrittenAttribute("mode")->set_s("AVG"); attributes.AddRewrittenAttribute("mode")->set_s("AVG");
attributes.AddRewrittenAttribute("global_pooling")->set_i(1); attributes.AddRewrittenAttribute("global_pooling")->set_i(1);
} }
auto returns = GenericImporter(onnx_node, ctx); auto returns = GenericImporter(onnx_node, ctx);
// Determine the op type // Determine the op type
OperatorDef* op_def = returns.GetOp(0); OperatorDef* op_def = returns.GetOp(0);
auto ks = attributes.get<ONNX_INTS>("kernel_shape"); auto ks = attributes.get<ONNX_INTS>("kernel_shape");
*(op_def->mutable_type()) += (str::to(ks.size() > 0 ? ks.size() : 2) + "d"); *(op_def->mutable_type()) += (str::to(ks.size() > 0 ? ks.size() : 2) + "d");
return returns; return returns;
} }
...@@ -194,11 +176,9 @@ ONNXImporterReturns ONNXBackend::GenericImporter( ...@@ -194,11 +176,9 @@ ONNXImporterReturns ONNXBackend::GenericImporter(
op_def->mutable_input()->MergeFrom(node.input()); op_def->mutable_input()->MergeFrom(node.input());
op_def->mutable_output()->MergeFrom(node.output()); op_def->mutable_output()->MergeFrom(node.output());
op_def->set_name(node.name()); op_def->set_name(node.name());
const auto onnx_op_type = node.op_type(); const auto onnx_op_type = node.op_type();
op_def->set_type( op_def->set_type(
get_default(get_renamed_nodes(), onnx_op_type, onnx_op_type)); get_default(get_renamed_nodes(), onnx_op_type, onnx_op_type));
auto mapper = [&, this](const std::string& k) { auto mapper = [&, this](const std::string& k) {
const auto it = get_node_renamed_attrs().find(onnx_op_type); const auto it = get_node_renamed_attrs().find(onnx_op_type);
if (it != get_node_renamed_attrs().end()) { if (it != get_node_renamed_attrs().end()) {
...@@ -224,18 +204,16 @@ ONNXImporterReturns ONNXBackend::GemmImporter( ...@@ -224,18 +204,16 @@ ONNXImporterReturns ONNXBackend::GemmImporter(
auto alpha = attributes.get<float>("alpha", 1.f); auto alpha = attributes.get<float>("alpha", 1.f);
auto beta = attributes.get<float>("beta", 1.f); auto beta = attributes.get<float>("beta", 1.f);
auto trans_a = attributes.get<int64_t>("transA", 0L); auto trans_a = attributes.get<int64_t>("transA", 0L);
// Remove the unsupported attributes
if (alpha != 1.f || beta != 1.f) { if (alpha != 1.f || beta != 1.f) {
LOG(FATAL) << "alpha/beta can not be set currently."; LOG(FATAL) << "alpha/beta can not be set currently.";
} }
if (trans_a) { if (trans_a) {
LOG(FATAL) << "Tranposed A is not supported currently."; LOG(FATAL) << "Tranposed A is not supported currently.";
} }
attributes.remove("alpha"); attributes.remove("alpha");
attributes.remove("beta"); attributes.remove("beta");
attributes.remove("transA"); attributes.remove("transA");
return GenericImporter(onnx_node, ctx); return GenericImporter(onnx_node, ctx);
} }
...@@ -244,11 +222,9 @@ ONNXImporterReturns ONNXBackend::MaxRoiPoolImporter( ...@@ -244,11 +222,9 @@ ONNXImporterReturns ONNXBackend::MaxRoiPoolImporter(
const ConversionContext& ctx) { const ConversionContext& ctx) {
auto& attributes = onnx_node->attributes; auto& attributes = onnx_node->attributes;
auto pooled_shape = attributes.get<ONNX_INTS>("pooled_shape"); auto pooled_shape = attributes.get<ONNX_INTS>("pooled_shape");
attributes.AddRewrittenAttribute("pool_h")->set_i(pooled_shape.Get(0)); attributes.AddRewrittenAttribute("pool_h")->set_i(pooled_shape.Get(0));
attributes.AddRewrittenAttribute("pool_w")->set_i(pooled_shape.Get(1)); attributes.AddRewrittenAttribute("pool_w")->set_i(pooled_shape.Get(1));
attributes.remove("pooled_shape"); attributes.remove("pooled_shape");
return GenericImporter(onnx_node, ctx); return GenericImporter(onnx_node, ctx);
} }
...@@ -258,18 +234,16 @@ ONNXImporterReturns ONNXBackend::ReshapeImporter( ...@@ -258,18 +234,16 @@ ONNXImporterReturns ONNXBackend::ReshapeImporter(
auto node = NodeProto(onnx_node->node); auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node); auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes; auto& attributes = onnx_node_v2.attributes;
attributes.remove("consumed_inputs"); attributes.remove("consumed_inputs");
// Determine the dims // Determine the dims
auto* dims = attributes.AddRewrittenAttribute("dims"); auto* dims = attributes.AddRewrittenAttribute("dims");
if (ctx.opset_version() < 5) { if (ctx.opset_version() < 5) {
const auto& shape = attributes.get<ONNX_INTS>("shape"); const auto& shape = attributes.get<ONNX_INTS>("shape");
CHECK_GT(shape.size(), 0) << "\nExcepted the shape value"; CHECK_GT(shape.size(), 0) << "\nExcepted the shape value";
attributes.remove("shape"); attributes.remove("shape");
for (auto d : shape) for (auto d : shape) {
dims->add_ints(d); dims->add_ints(d);
}
} else { } else {
CHECK_EQ(node.input_size(), 2) CHECK_EQ(node.input_size(), 2)
<< "\nExpectd 2 input in upsample after onnx version 5"; << "\nExpectd 2 input in upsample after onnx version 5";
...@@ -280,10 +254,10 @@ ONNXImporterReturns ONNXBackend::ReshapeImporter( ...@@ -280,10 +254,10 @@ ONNXImporterReturns ONNXBackend::ReshapeImporter(
Argument shape_dtype, shape_values; Argument shape_dtype, shape_values;
ONNXTensorToArgument(*shape_tensor, &shape_dtype, &shape_values); ONNXTensorToArgument(*shape_tensor, &shape_dtype, &shape_values);
CHECK_GT(shape_values.ints_size(), 0) << "\nExpected the shape value"; CHECK_GT(shape_values.ints_size(), 0) << "\nExpected the shape value";
for (auto d : shape_values.ints()) for (auto d : shape_values.ints()) {
dims->add_ints(d); dims->add_ints(d);
}
} }
return GenericImporter(&onnx_node_v2, ctx); return GenericImporter(&onnx_node_v2, ctx);
} }
...@@ -293,9 +267,7 @@ ONNXImporterReturns ONNXBackend::ResizeImporter( ...@@ -293,9 +267,7 @@ ONNXImporterReturns ONNXBackend::ResizeImporter(
auto node = NodeProto(onnx_node->node); auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node); auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes; auto& attributes = onnx_node_v2.attributes;
attributes.remove("coordinate_transformation_mode"); attributes.remove("coordinate_transformation_mode");
if (ctx.opset_version() >= 9) { if (ctx.opset_version() >= 9) {
node.mutable_input()->Clear(); node.mutable_input()->Clear();
node.add_input(onnx_node->node.input(0)); node.add_input(onnx_node->node.input(0));
...@@ -307,21 +279,22 @@ ONNXImporterReturns ONNXBackend::ResizeImporter( ...@@ -307,21 +279,22 @@ ONNXImporterReturns ONNXBackend::ResizeImporter(
const auto* scales_tensor = ctx.initializer().at(scales_name); const auto* scales_tensor = ctx.initializer().at(scales_name);
ONNXTensorToArgument(*scales_tensor, &scales_dtype, &scale_values); ONNXTensorToArgument(*scales_tensor, &scales_dtype, &scale_values);
auto* scales = attributes.AddRewrittenAttribute("scales"); auto* scales = attributes.AddRewrittenAttribute("scales");
for (auto d : scale_values.floats()) for (auto d : scale_values.floats()) {
scales->add_floats(d); scales->add_floats(d);
}
if (sizes_idx > 0) { if (sizes_idx > 0) {
Argument sizes_dtype, sizes_values; Argument sizes_dtype, sizes_values;
const auto& sizes_name = onnx_node->node.input(sizes_idx); const auto& sizes_name = onnx_node->node.input(sizes_idx);
const auto* sizes_tensor = ctx.initializer().at(sizes_name); const auto* sizes_tensor = ctx.initializer().at(sizes_name);
ONNXTensorToArgument(*sizes_tensor, &sizes_dtype, &sizes_values); ONNXTensorToArgument(*sizes_tensor, &sizes_dtype, &sizes_values);
auto* sizes = attributes.AddRewrittenAttribute("sizes"); auto* sizes = attributes.AddRewrittenAttribute("sizes");
for (auto d : sizes_values.floats()) for (auto d : sizes_values.floats()) {
sizes->add_ints(d); sizes->add_ints(d);
}
} }
} else { } else {
LOG(FATAL) << "Required opset >= 7"; LOG(FATAL) << "Required opset >= 7";
} }
return GenericImporter(&onnx_node_v2, ctx); return GenericImporter(&onnx_node_v2, ctx);
} }
...@@ -330,12 +303,10 @@ ONNXImporterReturns ONNXBackend::RoiAlignImporter( ...@@ -330,12 +303,10 @@ ONNXImporterReturns ONNXBackend::RoiAlignImporter(
const ConversionContext& ctx) { const ConversionContext& ctx) {
auto node = NodeProto(onnx_node->node); auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node); auto onnx_node_v2 = ONNXNode(node);
// Remove the batch indices // Remove the batch indices
node.mutable_input()->Clear(); node.mutable_input()->Clear();
node.add_input(onnx_node->node.input(0)); node.add_input(onnx_node->node.input(0));
node.add_input(onnx_node->node.input(1)); node.add_input(onnx_node->node.input(1));
return GenericImporter(&onnx_node_v2, ctx); return GenericImporter(&onnx_node_v2, ctx);
} }
...@@ -345,19 +316,22 @@ ONNXImporterReturns ONNXBackend::TileImporter( ...@@ -345,19 +316,22 @@ ONNXImporterReturns ONNXBackend::TileImporter(
auto node = NodeProto(onnx_node->node); auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node); auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes; auto& attributes = onnx_node_v2.attributes;
if (ctx.opset_version() >= 6) {
// Determine the multiples from repeats // Determine repeats from repeats
auto* multiples = attributes.AddRewrittenAttribute("multiples"); auto* repeats = attributes.AddRewrittenAttribute("repeats");
node.mutable_input()->Clear(); node.mutable_input()->Clear();
node.add_input(onnx_node->node.input(0)); node.add_input(onnx_node->node.input(0));
const auto& repeats_name = onnx_node->node.input(1); const auto& repeats_name = onnx_node->node.input(1);
const auto* repeats_tensor = ctx.initializer().at(repeats_name); const auto* repeats_tensor = ctx.initializer().at(repeats_name);
Argument multiples_dtype, multiples_values; Argument repeats_dtype, repeats_values;
ONNXTensorToArgument(*repeats_tensor, &multiples_dtype, &multiples_values); ONNXTensorToArgument(*repeats_tensor, &repeats_dtype, &repeats_values);
CHECK_GT(multiples_values.ints_size(), 0) << "\nExpected the repeats value"; CHECK_GT(repeats_values.ints_size(), 0) << "\nExpected the repeats value";
for (auto d : multiples_values.ints()) for (auto repeat : repeats_values.ints()) {
multiples->add_ints(d); repeats->add_ints(repeat);
}
} else {
LOG(FATAL) << "Required opset >= 6";
}
return GenericImporter(&onnx_node_v2, ctx); return GenericImporter(&onnx_node_v2, ctx);
} }
......
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
namespace dragon { namespace dragon {
#define DEFINE_FILLER_OP_IMPL(name) \ #define DEFINE_FILLER_OP_IMPL(name) \
template <class Context> \ template <class Context> \
template <typename T> \ template <typename T> \
void name##Op<Context>::DoRunWithType() { \ void name##Op<Context>::DoRunWithType() { \
unique_ptr<Filler<T, Context>> f; \ unique_ptr<Filler<T, Context>> f; \
f.reset(CreateFiller<T, Context>(this->proto_)); \ f.reset(CreateFiller<T, Context>(this->filler_info_)); \
f->Fill(Output(0), ctx()); \ f->Fill(Output(0), ctx()); \
} }
#define DISPATCH_WITH_TYPES(name, ...) \ #define DISPATCH_WITH_TYPES(name, ...) \
......
...@@ -30,7 +30,7 @@ class InitializeOp : public Operator<Context> { ...@@ -30,7 +30,7 @@ class InitializeOp : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
protected: protected:
TensorFillerProto proto_; FillerInfo filler_info_;
DECLARE_ARGS_WITH_DESC(int64_t, dims); DECLARE_ARGS_WITH_DESC(int64_t, dims);
}; };
...@@ -142,9 +142,9 @@ class RandomNormalOp final : public InitializeOp<Context> { ...@@ -142,9 +142,9 @@ class RandomNormalOp final : public InitializeOp<Context> {
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
auto mu = OpArg<float>("mean", 0.f); auto mu = OpArg<float>("mean", 0.f);
auto sigma = OpArg<float>("std", 1.f); auto sigma = OpArg<float>("std", 1.f);
this->proto_.set_mean(mu); this->filler_info_.set_mean(mu);
this->proto_.set_std(sigma); this->filler_info_.set_std(sigma);
this->proto_.set_type("normal"); this->filler_info_.set_type("normal");
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -161,9 +161,9 @@ class RandomUniformOp final : public InitializeOp<Context> { ...@@ -161,9 +161,9 @@ class RandomUniformOp final : public InitializeOp<Context> {
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
auto low = OpArg<float>("low", -1.f); auto low = OpArg<float>("low", -1.f);
auto high = OpArg<float>("high", 1.f); auto high = OpArg<float>("high", 1.f);
this->proto_.set_low(low); this->filler_info_.set_low(low);
this->proto_.set_high(high); this->filler_info_.set_high(high);
this->proto_.set_type("uniform"); this->filler_info_.set_type("uniform");
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -180,11 +180,11 @@ class TruncatedNormalOp final : public InitializeOp<Context> { ...@@ -180,11 +180,11 @@ class TruncatedNormalOp final : public InitializeOp<Context> {
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
auto mu = OpArg<float>("mean", 0.f); auto mu = OpArg<float>("mean", 0.f);
auto sigma = OpArg<float>("std", 1.f); auto sigma = OpArg<float>("std", 1.f);
this->proto_.set_mean(mu); this->filler_info_.set_mean(mu);
this->proto_.set_std(sigma); this->filler_info_.set_std(sigma);
this->proto_.set_low(mu - 2 * sigma); this->filler_info_.set_low(mu - 2 * sigma);
this->proto_.set_high(mu + 2 * sigma); this->filler_info_.set_high(mu + 2 * sigma);
this->proto_.set_type("truncated_normal"); this->filler_info_.set_type("truncated_normal");
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -201,15 +201,15 @@ class GlorotNormalOp final : public InitializeOp<Context> { ...@@ -201,15 +201,15 @@ class GlorotNormalOp final : public InitializeOp<Context> {
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
auto scale = OpArg<float>("scale", 2.f); auto scale = OpArg<float>("scale", 2.f);
auto mode = OpArg<string>("mode", "fan_in"); auto mode = OpArg<string>("mode", "fan_in");
this->proto_.set_type("msra"); this->filler_info_.set_type("glorot_normal");
if (mode == "fan_avg") { if (mode == "fan_avg") {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_AVG); this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_AVG);
} else if (mode == "fan_out") { } else if (mode == "fan_out") {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_OUT); this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_OUT);
} else { } else {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_IN); this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_IN);
} }
this->proto_.set_scale(scale); this->filler_info_.set_scale(scale);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -226,15 +226,15 @@ class GlorotUniformOp final : public InitializeOp<Context> { ...@@ -226,15 +226,15 @@ class GlorotUniformOp final : public InitializeOp<Context> {
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
auto scale = OpArg<float>("scale", 3.f); auto scale = OpArg<float>("scale", 3.f);
auto mode = OpArg<string>("mode", "fan_in"); auto mode = OpArg<string>("mode", "fan_in");
this->proto_.set_type("xavier"); this->filler_info_.set_type("glorot_uniform");
if (mode == "fan_avg") { if (mode == "fan_avg") {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_AVG); this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_AVG);
} else if (mode == "fan_out") { } else if (mode == "fan_out") {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_OUT); this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_OUT);
} else { } else {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_IN); this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_IN);
} }
this->proto_.set_scale(scale); this->filler_info_.set_scale(scale);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
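
For reference, the renamed glorot fillers above follow the usual fan-based scheme (default scale 2 for the normal/MSRA variant, 3 for the uniform/Xavier variant). A hedged NumPy sketch of the conventional formulas these fillers are assumed to implement, for a weight layout of (out_channels, in_channels, ...):

```python
# Hedged sketch of the conventional glorot-style formulas assumed here:
#   glorot_normal  -> std   = sqrt(scale / fan), default scale = 2 (He/MSRA)
#   glorot_uniform -> limit = sqrt(scale / fan), default scale = 3 (Xavier)
# `mode` mirrors <variance_norm>: fan_in, fan_out or fan_avg.
import numpy as np


def glorot_fill(shape, scale, mode='fan_in', uniform=False):
    receptive = int(np.prod(shape[2:]))
    fan_in, fan_out = shape[1] * receptive, shape[0] * receptive
    fan = {'fan_in': fan_in, 'fan_out': fan_out,
           'fan_avg': (fan_in + fan_out) / 2.}[mode]
    if uniform:
        limit = np.sqrt(scale / fan)
        return np.random.uniform(-limit, limit, shape)
    return np.random.normal(0., np.sqrt(scale / fan), shape)
```
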
...@@ -9,9 +9,12 @@ template <typename T> ...@@ -9,9 +9,12 @@ template <typename T>
void TileOp<Context>::DoRunWithType() { void TileOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
int num_repeats;
repeats(0, &num_repeats);
auto Y_dims = X.dims(); auto Y_dims = X.dims();
for (int i = 0; i < Y_dims.size(); ++i) for (int i = 0; i < num_repeats; ++i) {
Y_dims[i] *= multiples(i); Y_dims[i] *= repeats(i);
}
if (X.dims() == Y_dims) { if (X.dims() == Y_dims) {
Y->Reshape(Y_dims)->CopyFrom(X, ctx()); Y->Reshape(Y_dims)->CopyFrom(X, ctx());
...@@ -49,7 +52,7 @@ void TileGradientOp<Context>::DoRunWithType() { ...@@ -49,7 +52,7 @@ void TileGradientOp<Context>::DoRunWithType() {
dx = dest_->template mutable_data<T, Context>(); dx = dest_->template mutable_data<T, Context>();
} }
kernel::TileGrad( kernel::TileGrad(
dest_->count(0, axis_), dest_->count(axis_), multiple_, dy, dx, ctx()); dest_->count(0, axis_), dest_->count(axis_), repeat_, dy, dx, ctx());
} }
template <class Context> template <class Context>
...@@ -57,10 +60,14 @@ void TileGradientOp<Context>::RunOnDevice() { ...@@ -57,10 +60,14 @@ void TileGradientOp<Context>::RunOnDevice() {
auto &dY = Input(0), *dX = Output(0); auto &dY = Input(0), *dX = Output(0);
// Add the axes // Add the axes
int num_repeats;
repeats(0, &num_repeats);
vector<pair<int, int>> dispatch_axes; vector<pair<int, int>> dispatch_axes;
for (int i = 0; i < dY.ndim(); i++) { for (int i = 0; i < dY.ndim() && i < num_repeats; i++) {
auto m = multiples(i); auto repeat = repeats(i);
if (m > 1) dispatch_axes.push_back({m, i}); if (repeat > 1) {
dispatch_axes.push_back({repeat, i});
}
} }
std::sort(dispatch_axes.begin(), dispatch_axes.end()); std::sort(dispatch_axes.begin(), dispatch_axes.end());
std::reverse(dispatch_axes.begin(), dispatch_axes.end()); std::reverse(dispatch_axes.begin(), dispatch_axes.end());
...@@ -76,10 +83,10 @@ void TileGradientOp<Context>::RunOnDevice() { ...@@ -76,10 +83,10 @@ void TileGradientOp<Context>::RunOnDevice() {
// Reduce N times along each tiled axis // Reduce N times along each tiled axis
for (const auto& task : dispatch_axes) { for (const auto& task : dispatch_axes) {
axis_ = task.second, multiple_ = task.first; axis_ = task.second, repeat_ = task.first;
vec64_t X_dims(src_->dims()); vec64_t X_dims(src_->dims());
X_dims[axis_] /= multiple_; X_dims[axis_] /= repeat_;
dest_->Reshape(X_dims); dest_->Reshape(X_dims);
DispatchHelper<FloatingTensorTypes>::Call(this, dY); DispatchHelper<FloatingTensorTypes>::Call(this, dY);
......
...@@ -21,7 +21,7 @@ template <class Context> ...@@ -21,7 +21,7 @@ template <class Context>
class TileOp final : public Operator<Context> { class TileOp final : public Operator<Context> {
public: public:
TileOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) { TileOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, multiples); GET_ARGS_WITH_DESC(int64_t, repeats);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -31,7 +31,7 @@ class TileOp final : public Operator<Context> { ...@@ -31,7 +31,7 @@ class TileOp final : public Operator<Context> {
void DoRunWithType(); void DoRunWithType();
protected: protected:
DECLARE_ARGS_WITH_DESC(int64_t, multiples); DECLARE_ARGS_WITH_DESC(int64_t, repeats);
}; };
template <class Context> template <class Context>
...@@ -39,7 +39,7 @@ class TileGradientOp final : public Operator<Context> { ...@@ -39,7 +39,7 @@ class TileGradientOp final : public Operator<Context> {
public: public:
TileGradientOp(const OperatorDef& def, Workspace* ws) TileGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) { : Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, multiples); GET_ARGS_WITH_DESC(int64_t, repeats);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -50,12 +50,12 @@ class TileGradientOp final : public Operator<Context> { ...@@ -50,12 +50,12 @@ class TileGradientOp final : public Operator<Context> {
protected: protected:
Tensor *dest_, *src_, nav_; Tensor *dest_, *src_, nav_;
int64_t axis_, multiple_; int64_t axis_, repeat_;
DECLARE_ARGS_WITH_DESC(int64_t, multiples); DECLARE_ARGS_WITH_DESC(int64_t, repeats);
}; };
DEFINE_ARGS_WITH_DESC(int64_t, TileOp, multiples); DEFINE_ARGS_WITH_DESC(int64_t, TileOp, repeats);
DEFINE_ARGS_WITH_DESC(int64_t, TileGradientOp, multiples); DEFINE_ARGS_WITH_DESC(int64_t, TileGradientOp, repeats);
} // namespace dragon } // namespace dragon
......
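
The renamed argument keeps Tile's usual semantics: output dimension i equals input dimension i times repeats[i], matching the `Y_dims[i] *= repeats(i)` loop in tile_op.cc above. A quick NumPy check of that invariant:

```python
# repeats[i] multiplies dimension i, exactly like numpy.tile.
import numpy as np

x = np.arange(6).reshape(2, 3)
repeats = (2, 1)
y = np.tile(x, repeats)
assert y.shape == (x.shape[0] * repeats[0], x.shape[1] * repeats[1])  # (4, 3)
```
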
...@@ -9,7 +9,6 @@ void AdamUpdateOp<Context>::ComputeUpdate(Tensor* dX) { ...@@ -9,7 +9,6 @@ void AdamUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
t_++; t_++;
auto beta1 = Parameter("beta1"), beta2 = Parameter("beta2"); auto beta1 = Parameter("beta1"), beta2 = Parameter("beta2");
auto coef = sqrt(1.f - pow(beta2, t_)) / (1.f - pow(beta1, t_)); auto coef = sqrt(1.f - pow(beta2, t_)) / (1.f - pow(beta1, t_));
kernel::AdamUpdate( kernel::AdamUpdate(
dX->count(), dX->count(),
Parameter("base_lr") * coef * this->lr_mult_, Parameter("base_lr") * coef * this->lr_mult_,
......
...@@ -10,7 +10,6 @@ void SGDUpdateOp<Context>::ComputeUpdate(Tensor* dX) { ...@@ -10,7 +10,6 @@ void SGDUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
auto lr = Parameter("base_lr") * this->lr_mult_; auto lr = Parameter("base_lr") * this->lr_mult_;
if (last_lr_ > 0) correction_ = lr / last_lr_; if (last_lr_ > 0) correction_ = lr / last_lr_;
last_lr_ = lr; // Record the last value last_lr_ = lr; // Record the last value
kernel::SGDUpdate( kernel::SGDUpdate(
dX->count(), dX->count(),
lr, lr,
......
...@@ -20,9 +20,7 @@ void BiasAddOp<Context>::DoRunWithType() { ...@@ -20,9 +20,7 @@ void BiasAddOp<Context>::DoRunWithType() {
LOG(FATAL) << "Unknown DataFormat: " << data_format(); LOG(FATAL) << "Unknown DataFormat: " << data_format();
} }
// Maybe fill the bias at the first time
TENSOR_FILL(B, vec64_t({C})); TENSOR_FILL(B, vec64_t({C}));
kernel::BiasAdd( kernel::BiasAdd(
N, N,
C, C,
......
...@@ -4,73 +4,69 @@ ...@@ -4,73 +4,69 @@
namespace dragon { namespace dragon {
#define SAME_PADDING(A, B) \ #define DETERMINE_SAME_PADDING(l, r) \
A[i] = padding_needed / 2; \ if (padding_ != "SAME_UPPER") { \
B[i] = padding_needed - A[i] l[i] = pad_size / 2; \
r[i] = pad_size - l[i]; \
} else { \
r[i] = pad_size / 2; \
l[i] = pad_size - r[i]; \
}
template <class Context> template <class Context>
void ConvOpBase<Context>::ComputeOutShape() { void ConvOpBase<Context>::ComputeOutShape() {
auto X_dims = Input(0).dims();
out_shape_.clear(); out_shape_.clear();
for (int i = 0; i < num_axes_; i++) { vec64_t X_dims = Input(0).dims();
if (!Transposed()) { int64_t in_size, out_size, k_size, pad_size;
auto idm = X_dims[axis_ + i]; if (!Transposed()) {
auto dk = dilation_[i] * (kshape_[i] - 1) + 1; for (int i = 0; i < num_axes_; i++) {
if (!str::find(padding_, "SAME")) { in_size = X_dims[axis_ + i];
// Explicit pads k_size = dilation_[i] * (kshape_[i] - 1) + 1;
auto odm = (idm + pad_l_[i] + pad_r_[i] - dk) / stride_[i] + 1; if (!str::find(padding_, "SAME")) { // Explicit pads
out_shape_.push_back(odm); pad_size = pad_l_[i] + pad_r_[i];
} else { out_size = (in_size + pad_size - k_size) / stride_[i] + 1;
// Auto pads } else { // Auto pads
int64_t odm = (idm + stride_[i] - 1) / (float)stride_[i]; out_size = (in_size + stride_[i] - 1) / stride_[i];
auto padding_needed = pad_size = (out_size - 1) * stride_[i] + k_size - in_size;
std::max(int64_t(0), (odm - 1) * stride_[i] + dk - idm); pad_size = std::max(pad_size, int64_t(0));
out_shape_.push_back(odm); DETERMINE_SAME_PADDING(pad_l_, pad_r_);
if (padding_ == "SAME_UPPER") { }
SAME_PADDING(pad_l_, pad_r_); out_shape_.push_back(out_size);
} else { }
SAME_PADDING(pad_r_, pad_l_); } else {
} // SAME_LOWER or SAME int num_output_padding;
output_padding(0, &num_output_padding);
CHECK(num_output_padding == 0 || num_output_padding == num_axes_)
<< "\nExcepted 0 or " << num_axes_ << " ints for <output_padding>.";
if (!str::find(padding_, "SAME")) { // Explicit pads
for (int i = 0; i < num_axes_; i++) {
in_size = X_dims[axis_ + i];
k_size = dilation_[i] * (kshape_[i] - 1) + 1;
pad_size = pad_l_[i] + pad_r_[i];
out_size = stride_[i] * (in_size - 1) + k_size - pad_size;
if (num_output_padding > 0) out_size += output_padding(i);
out_shape_.push_back(out_size);
} }
} else { } else {
auto idm = X_dims[axis_ + i]; // Auto pads
auto dk = dilation_[i] * (kshape_[i] - 1) + 1; int num_output_shape;
if (!str::find(padding_, "SAME")) { output_shape(0, &num_output_shape);
// Explicit pads CHECK(num_output_shape == num_axes_)
auto odm = stride_[i] * (idm - 1) + dk - pad_l_[i] - pad_r_[i]; << "\nExpected " << num_axes_ << " ints for <output_shape>.";
out_shape_.push_back(odm); for (int i = 0; i < num_axes_; i++) {
} else { in_size = X_dims[axis_ + i];
// Auto pads k_size = dilation_[i] * (kshape_[i] - 1) + 1;
int output_shape_size; out_size = output_shape(i);
int output_padding_size; pad_size = stride_[i] * (in_size - 1) + k_size;
output_shape(0, &output_shape_size); if (num_output_padding > 0) pad_size += output_padding(i);
output_padding(0, &output_padding_size); CHECK_GE(pad_size, out_size)
CHECK(output_shape_size == 0 || output_shape_size == num_axes_) << "\nThe output shape is incorrect."
<< "Excepted 0 or " << num_axes_ << " ints for output shape."; << "\nDimension of spatial axis " << i << " should be at most "
CHECK(output_padding_size == 0 || output_padding_size == num_axes_) << pad_size << ".";
<< "Excepted 0 or " << num_axes_ << " ints for output padding."; pad_size = stride_[i] * (in_size - 1) + k_size - out_size;
int64_t padding_needed, odm; pad_size = std::max(pad_size, int64_t(0));
if (output_padding_size) { DETERMINE_SAME_PADDING(pad_l_, pad_r_);
padding_needed = output_padding(i); out_shape_.push_back(out_size);
odm = stride_[i] * (idm - 1) + dk + padding_needed;
} else if (output_shape_size) {
odm = output_shape(i);
padding_needed = odm - (stride_[i] * (idm - 1) + dk);
CHECK_GE(padding_needed, 0)
<< "\nThe output shape is incorrect."
<< "\nWith the given stride and kernel, "
<< "dimension of spatial axis " << i << " should be at least "
<< odm - padding_needed << ".";
} else {
LOG(FATAL) << "Excepted the output padding or output shape "
<< "for \"SAME\" padding algorithm.";
}
out_shape_.push_back(odm);
if (padding_ == "SAME_UPPER") {
SAME_PADDING(pad_l_, pad_r_);
} else {
SAME_PADDING(pad_r_, pad_l_);
} // SAME_LOWER or SAME
} }
} }
} }
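
The rewritten forward branch computes auto ("SAME") padding the usual way: the output size is ceil(in / stride), and the deficit between the covered extent and the input is split between the two pads. A Python sketch mirroring the new code path and the DETERMINE_SAME_PADDING macro (not the exact implementation):

```python
# Sketch of the auto ("SAME") padding used above for the forward conv:
#   out = ceil(in / stride)
#   pad = max((out - 1) * stride + dilated_kernel - in, 0)
# The split follows the new macro: SAME/SAME_LOWER put the floor half on the
# left, SAME_UPPER puts it on the right.
def same_padding(in_size, kernel, stride, dilation, same_upper=False):
    k = dilation * (kernel - 1) + 1
    out = (in_size + stride - 1) // stride            # ceil(in / stride)
    pad = max((out - 1) * stride + k - in_size, 0)
    if not same_upper:                                # SAME / SAME_LOWER
        pad_l = pad // 2
        pad_r = pad - pad_l
    else:                                             # SAME_UPPER
        pad_r = pad // 2
        pad_l = pad - pad_r
    return out, pad_l, pad_r


print(same_padding(in_size=7, kernel=3, stride=2, dilation=1))  # (4, 1, 1)
```
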
...@@ -373,7 +369,7 @@ INSTANTIATE_API(CUDAContext, float); ...@@ -373,7 +369,7 @@ INSTANTIATE_API(CUDAContext, float);
INSTANTIATE_API(CUDAContext, double); INSTANTIATE_API(CUDAContext, double);
#endif #endif
#undef SAME_PADDING
#undef INSTANTIATE_API #undef INSTANTIATE_API
#undef DETERMINE_SAME_PADDING
} // namespace dragon } // namespace dragon
...@@ -5,9 +5,14 @@ ...@@ -5,9 +5,14 @@
namespace dragon { namespace dragon {
#define SAME_PADDING(A, B) \ #define DETERMINE_SAME_PADDING(l, r) \
A[i] = padding_needed / 2; \ if (padding_ != "SAME_UPPER") { \
B[i] = padding_needed - A[i] l[i] = pad_size / 2; \
r[i] = pad_size - l[i]; \
} else { \
r[i] = pad_size / 2; \
l[i] = pad_size - r[i]; \
}
template <class Context> template <class Context>
void PoolOpBase<Context>::Setup(int num_axes) { void PoolOpBase<Context>::Setup(int num_axes) {
...@@ -52,41 +57,27 @@ void PoolOpBase<Context>::ComputeOutShape() { ...@@ -52,41 +57,27 @@ void PoolOpBase<Context>::ComputeOutShape() {
kshape_[i] = in_dims_[i + 2]; kshape_[i] = in_dims_[i + 2];
} }
// Adjust the pads for SAME padding algorithm
if (str::find(padding_, "SAME")) {
for (int i = 0; i < num_axes_; i++) {
auto idm = in_dims_[i + 2];
int64_t odm = (idm + stride_[i] - 1) / (float)stride_[i];
auto padding_needed =
std::max((int64_t)0, (odm - 1) * stride_[i] + kshape_[i] - idm);
if (padding_ == "SAME_UPPER") {
SAME_PADDING(pad_l_, pad_r_);
} else {
SAME_PADDING(pad_r_, pad_l_);
} /*! SAME_LOWER or SAME */
}
}
// Compute the output dimensions // Compute the output dimensions
auto floor_or_ceil = ceil_mode_ > 0 auto floor_or_ceil = ceil_mode_ > 0
? static_cast<float (*)(float)>(&std::ceil) ? static_cast<float (*)(float)>(&std::ceil)
: static_cast<float (*)(float)>(&std::floor); : static_cast<float (*)(float)>(&std::floor);
out_dims_ = in_dims_; out_dims_ = in_dims_;
out_shape_ = Input(0).dims(); out_shape_ = Input(0).dims();
int64_t in_size, k_size, pad_size;
for (int i = 0; i < num_axes_; i++) { for (int i = 0; i < num_axes_; i++) {
auto in_dim = in_dims_[i + 2]; float out_size;
if (!str::find(padding_, "SAME")) { in_size = in_dims_[i + 2], k_size = kshape_[i];
// Explicit pads if (!str::find(padding_, "SAME")) { // Explicit pads
in_dim += pad_l_[i] + pad_r_[i]; pad_size = pad_l_[i] + pad_r_[i];
out_shape_[i + axis_] = out_dims_[i + 2] = out_size = float(in_size + pad_size - k_size) / float(stride_[i]) + 1.f;
floor_or_ceil((in_dim - kshape_[i]) / (float)stride_[i]) + 1; out_size = floor_or_ceil(out_size);
} else { } else { // Auto pads
// Auto pads out_size = std::ceil(float(in_size) / float(stride_[i]));
out_shape_[i + axis_] = out_dims_[i + 2] = pad_size = ((int64_t)out_size - 1) * stride_[i] + k_size - in_size;
floor_or_ceil(in_dim / (float)stride_[i]); pad_size = std::max(pad_size, int64_t(0));
DETERMINE_SAME_PADDING(pad_l_, pad_r_);
} }
out_shape_[i + axis_] = out_dims_[i + 2] = out_size;
} }
} }
...@@ -95,6 +86,6 @@ template class PoolOpBase<CPUContext>; ...@@ -95,6 +86,6 @@ template class PoolOpBase<CPUContext>;
template class PoolOpBase<CUDAContext>; template class PoolOpBase<CUDAContext>;
#endif #endif
#undef SAME_PADDING #undef DETERMINE_SAME_PADDING
} // namespace dragon } // namespace dragon
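
The pooling output size follows the same pattern; for explicit pads, `ceil_mode` selects floor or ceil of the usual formula. A small sketch mirroring the rewritten loop:

```python
# Explicit pads: out = floor_or_ceil((in + pads - kernel) / stride + 1),
# matching the floor_or_ceil selection above.
import math


def pool_out_size(in_size, kernel, stride, pad_l, pad_r, ceil_mode=False):
    round_fn = math.ceil if ceil_mode else math.floor
    return int(round_fn(float(in_size + pad_l + pad_r - kernel) / stride + 1.))


print(pool_out_size(6, kernel=3, stride=2, pad_l=0, pad_r=0))                  # 2
print(pool_out_size(6, kernel=3, stride=2, pad_l=0, pad_r=0, ceil_mode=True))  # 3
```
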
syntax = "proto2";
package dragon;
message BlobShape {
repeated int64 dim = 1 [packed = true];
}
message BlobProto {
optional BlobShape shape = 7;
repeated float data = 5 [packed = true];
optional int32 num = 1 [default = 0];
optional int32 channels = 2 [default = 0];
optional int32 height = 3 [default = 0];
optional int32 width = 4 [default = 0];
}
message NetParameter {
optional string name = 1;
repeated LayerParameter layer = 100;
}
message LayerParameter {
optional string name = 1;
repeated BlobProto blobs = 7;
}
...@@ -51,26 +51,6 @@ message TensorProto { ...@@ -51,26 +51,6 @@ message TensorProto {
optional string name = 7; optional string name = 7;
} }
// Record the filler of Tensor.
// This structure is kept for backward compatibility
// with caffe1, which relies on implicit initializers.
message TensorFillerProto {
optional string tensor = 1;
optional string type = 2 [default = 'constant'];
optional float value = 3 [default = 0];
optional float low = 4 [default = 0];
optional float high = 5 [default = 1];
optional float mean = 6 [default = 0];
optional float std = 7 [default = 1];
optional float scale = 8 [default = 3];
enum VarianceNorm {
FAN_IN = 0;
FAN_OUT = 1;
FAN_AVG = 2;
}
optional VarianceNorm variance_norm = 9 [default = FAN_IN];
}
// Store multiple TensorProto objects in one single proto. // Store multiple TensorProto objects in one single proto.
message TensorProtos { message TensorProtos {
repeated TensorProto protos = 1; repeated TensorProto protos = 1;
...@@ -139,16 +119,6 @@ message OperatorDef { ...@@ -139,16 +119,6 @@ message OperatorDef {
optional string cache_key = 7; optional string cache_key = 7;
} }
// Record the gradient information
message GradientProto {
// The derivative target.
optional string cost = 1;
// The target with respect to?
optional string wrt = 2;
// The external gradient
optional string external = 3;
}
// Graph Definition // Graph Definition
message GraphDef { message GraphDef {
// The graph name. // The graph name.
...@@ -171,6 +141,33 @@ message GraphDef { ...@@ -171,6 +141,33 @@ message GraphDef {
// The name of outputs. // The name of outputs.
repeated string output = 8; repeated string output = 8;
// The gradients information. // The info of gradients.
repeated GradientProto gradient = 9; repeated GradientInfo grad_info = 9;
}
// Record the filler information.
// This structure is kept for backward compatibility
// with caffe, which relies on the implicit initializer. // with caffe, which relies on the implicit initializer.
message FillerInfo {
enum VarianceNorm {
FAN_IN = 0;
FAN_OUT = 1;
FAN_AVG = 2;
}
optional string type = 1 [default = 'constant'];
optional float value = 2 [default = 0];
optional float low = 3 [default = 0];
optional float high = 4 [default = 1];
optional float mean = 5 [default = 0];
optional float std = 6 [default = 1];
optional float scale = 7 [default = 3];
optional VarianceNorm variance_norm = 8 [default = FAN_IN];
}
// Record the gradient information.
message GradientInfo {
// The derivative target.
optional string y = 1;
// The differentiated inputs.
repeated string xs = 2;
} }
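
A hedged sketch of building the two new messages from Python, assuming they are generated into `dragon.core.proto.dragon_pb2` alongside the other messages edited in this commit (GradientInfo is constructed the same way in the graph-building diff further below):

```python
# Assumes FillerInfo/GradientInfo are exposed via the generated dragon_pb2 module.
from dragon.core.proto import dragon_pb2

filler = dragon_pb2.FillerInfo(
    type='glorot_uniform',
    scale=3.0,
    variance_norm=dragon_pb2.FillerInfo.FAN_AVG,
)
grad_info = dragon_pb2.GradientInfo(y='loss', xs=['conv1/param:0', 'conv1/param:1'])
```
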
...@@ -30,7 +30,6 @@ from dragon._api import metrics ...@@ -30,7 +30,6 @@ from dragon._api import metrics
from dragon._api import nn from dragon._api import nn
from dragon._api import optimizers from dragon._api import optimizers
from dragon._api import random from dragon._api import random
from dragon._api import workspace
from dragon._api import vision from dragon._api import vision
# Virtual API # Virtual API
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.core.training.adam import Adam
from dragon.core.training.rmsprop import RMSProp
from dragon.core.training.sgd import Nesterov
from dragon.core.training.sgd import SGD
from dragon.core.training.updater import Updater
__all__ = [_s for _s in dir() if not _s.startswith('_')]
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.core.framework.workspace import feed_tensor
from dragon.core.framework.workspace import fetch_tensor
from dragon.core.framework.workspace import has_tensor
from dragon.core.framework.workspace import load
from dragon.core.framework.workspace import reset_tensor
from dragon.core.framework.workspace import run_operator
from dragon.core.framework.workspace import save
__all__ = [_s for _s in dir() if not _s.startswith('_')]
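
With the module-level helpers above removed, the replacement pattern implied by this commit is to call the same functionality on the current workspace instance. A hedged sketch using only the calls that appear in the Python diffs below (`get_workspace()`, `has_tensor()`, `GetTensor()`); the tensor name is illustrative:

```python
# Hedged sketch: the workspace API now lives on the current workspace instance.
from dragon.core.framework import workspace

current_ws = workspace.get_workspace()
if current_ws.has_tensor('conv1/param:0'):        # instance method, as used in function.py below
    impl = current_ws.GetTensor('conv1/param:0')  # backend tensor handle, as used in def_function.py below
```
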
...@@ -28,6 +28,7 @@ from dragon.core.autograph.tensor import Tensor ...@@ -28,6 +28,7 @@ from dragon.core.autograph.tensor import Tensor
from dragon.core.eager import context as eager_context from dragon.core.eager import context as eager_context
from dragon.core.eager.tensor import EagerTensor from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import context from dragon.core.framework import context
from dragon.core.framework import device_spec
from dragon.core.framework import workspace from dragon.core.framework import workspace
from dragon.core.training import optimizer from dragon.core.training import optimizer
from dragon.core.util import decorator from dragon.core.util import decorator
...@@ -276,13 +277,16 @@ class FunctionGuard(object): ...@@ -276,13 +277,16 @@ class FunctionGuard(object):
executables = self.executables executables = self.executables
inputs, kwargs = self.canonicalize_inputs(*args, **kwargs) inputs, kwargs = self.canonicalize_inputs(*args, **kwargs)
executables[0](*inputs, return_outputs=False, **kwargs) executables[0](*inputs, return_outputs=False, **kwargs)
_ = [func(return_outputs=False) for func in executables[1:]] [func(return_outputs=False) for func in executables[1:]]
outputs = [] outputs = []
for obj in self.outputs: current_ws = workspace.get_workspace()
if isinstance(obj, Tensor): for output in self.outputs:
outputs.append(EagerTensor(id=obj.id, own_storage=False)) if isinstance(output, Tensor):
impl = current_ws.GetTensor(output.id)
device = device_spec.DeviceSpec(*impl.device)
outputs.append(EagerTensor(impl=impl, device=device))
else: else:
outputs.append(obj) outputs.append(output)
return outputs return outputs
def __get__(self, instance, owner): def __get__(self, instance, owner):
......
...@@ -23,7 +23,6 @@ from dragon.core.autograph.op_def import OpDef ...@@ -23,7 +23,6 @@ from dragon.core.autograph.op_def import OpDef
from dragon.core.autograph.op_def import OpInfo from dragon.core.autograph.op_def import OpInfo
from dragon.core.autograph.tensor import Tensor from dragon.core.autograph.tensor import Tensor
from dragon.core.framework import config from dragon.core.framework import config
from dragon.core.framework import context
from dragon.core.framework import proto_util from dragon.core.framework import proto_util
from dragon.core.framework import workspace from dragon.core.framework import workspace
from dragon.core.proto import dragon_pb2 from dragon.core.proto import dragon_pb2
...@@ -32,7 +31,7 @@ from dragon.core.util import nest ...@@ -32,7 +31,7 @@ from dragon.core.util import nest
def add_device_option(graph_def): def add_device_option(graph_def):
"""Add the device option for graph.""" """Add the device option."""
cfg = config.config() cfg = config.config()
str2idx = {'cpu': 0, 'cuda': 1, 'cnml': 2} str2idx = {'cpu': 0, 'cuda': 1, 'cnml': 2}
dev_opt = dragon_pb2.DeviceOption() dev_opt = dragon_pb2.DeviceOption()
...@@ -42,69 +41,66 @@ def add_device_option(graph_def): ...@@ -42,69 +41,66 @@ def add_device_option(graph_def):
graph_def.device_option.CopyFrom(dev_opt) graph_def.device_option.CopyFrom(dev_opt)
def add_gradient_info(graph_def, targets): def add_grad_info(graph_def, targets):
"""Add the gradient info for graph.""" """Add the gradient info."""
gradients = set()
for target in targets: for target in targets:
if target._grad is not None: info = target._grad
gradients.update(target._grad.make_pairs()) if info is not None:
for (cost, wrt) in gradients: graph_def.grad_info.extend([
gradient = dragon_pb2.GradientProto() dragon_pb2.GradientInfo(
gradient.cost, gradient.wrt = str(cost), str(wrt) y=info.y.id,
graph_def.gradient.extend([gradient]) xs=[x.id for x in info.xs])])
def add_optimization(graph_def, level=None): def add_optimization(graph_def, level=None):
"""Add the optimization attribute for graph.""" """Add the optimization argument."""
cfg = config.config() cfg = config.config()
if level is None: if level is None:
level = cfg.graph_optimization level = cfg.graph_optimization
graph_def.arg.add().CopyFrom( graph_def.arg.add().CopyFrom(
proto_util.make_argument( proto_util.make_argument('optimization', level))
'optimization_level', level))
graph_def.graph_type = cfg.graph_type graph_def.graph_type = cfg.graph_type
def add_phase(graph_def, targets): def add_phase(graph_def, targets):
"""Add the phase attribute for graph.""" """Add the phase argument."""
phase = context.get_graph_phase() phase = 'TEST'
if phase is None: for target in targets:
phase = 'TEST' try:
for target in targets: if target._grad and target._grad.required():
if target._grad is not None and \
target._grad.required():
phase = 'TRAIN' phase = 'TRAIN'
break break
except AttributeError:
pass
graph_def.arg.extend([proto_util.make_argument('phase', phase)]) graph_def.arg.extend([proto_util.make_argument('phase', phase)])
def add_update_ops(graph_def, optimizer): def add_update_defs(graph_def, optimizer):
"""Add the update operators for graph.""" """Add the update defs."""
if optimizer is None: if optimizer is None:
return return
grads, update_ops = [], [] grads, update_defs = [], []
extra_arguments = optimizer._extra_kwargs extra_arguments = optimizer._extra_kwargs
extra_arguments['handle'] = optimizer._op_handle extra_arguments['handle'] = optimizer._op_handle
# Generate update operators according to the updater. # Generate op defs according to the collected updates
for e in optimizer._param_group: current_ws = workspace.get_workspace()
(param, grad), arguments = e for (param, grad), arguments in optimizer._param_group:
if workspace.has_tensor(grad): if current_ws.has_tensor(grad):
grads.append(grad) grads.append(grad)
arguments = dict(arguments, **extra_arguments) arguments = dict(arguments, **extra_arguments)
update_ops.append( update_defs.append(
proto_util.make_operator_def( proto_util.make_operator_def(
op_type=optimizer._op_type, op_type=optimizer._op_type,
inputs=[grad], inputs=[grad],
outputs=[param], outputs=[param],
name=OpDef.get_name(), name=OpDef.get_name(),
**arguments **arguments))
))
else: else:
logging.info('Skip updating Tensor({}).'.format(param)) logging.info('Skip updating Tensor({}).'.format(param))
# Insert a reduce op if the process group is found. # Insert a reduce def if the process group is found.
process_group = optimizer._process_group process_group = optimizer._process_group
if process_group is not None: if process_group is not None:
update_ops.insert( update_defs.insert(
0, proto_util.make_operator_def( 0, proto_util.make_operator_def(
op_type='Collective', op_type='Collective',
inputs=grads, inputs=grads,
...@@ -115,7 +111,7 @@ def add_update_ops(graph_def, optimizer): ...@@ -115,7 +111,7 @@ def add_update_ops(graph_def, optimizer):
**process_group.arguments **process_group.arguments
) )
) )
graph_def.op.extend(update_ops) graph_def.op.extend(update_defs)
class Function(object): class Function(object):
...@@ -128,16 +124,15 @@ class Function(object): ...@@ -128,16 +124,15 @@ class Function(object):
self.graph_name = None # Determined after creating self.graph_name = None # Determined after creating
self.inputs, self.outputs = None, None self.inputs, self.outputs = None, None
def create(self, inputs=None, outputs=None, givens=None, updater=None): def create(self, inputs=None, outputs=None, givens=None, optimizer=None):
self.inputs = inputs = [] if inputs is None else nest.flatten(inputs) self.inputs = inputs = [] if inputs is None else nest.flatten(inputs)
self.outputs = outputs = [] if outputs is None else nest.flatten(outputs) self.outputs = outputs = [] if outputs is None else nest.flatten(outputs)
if len(outputs) > 0 and updater is not None: if len(outputs) > 0 and optimizer is not None:
raise ValueError('Specify either <outputs> or <updater>, not both.') raise ValueError('Specify either <outputs> or <optimizer>, not both.')
# Collect the forward defs.
op_info = OpInfo() op_info = OpInfo()
# Collect the forward operators.
requires_grad = False requires_grad = False
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
op_info.merge_from(output) op_info.merge_from(output)
...@@ -149,7 +144,7 @@ class Function(object): ...@@ -149,7 +144,7 @@ class Function(object):
except AttributeError: except AttributeError:
raise ValueError('Output[%d] is not a symbolic tensor.' % i) raise ValueError('Output[%d] is not a symbolic tensor.' % i)
# Handle givens. # Handle the replacements.
if givens is not None: if givens is not None:
name_dict = {} name_dict = {}
for k, v in givens.items(): for k, v in givens.items():
...@@ -161,62 +156,61 @@ class Function(object): ...@@ -161,62 +156,61 @@ class Function(object):
'Expected a Tensor, ' 'Expected a Tensor, '
'got {}.'.format(type(v).__name__) 'got {}.'.format(type(v).__name__)
) )
# Update original operators. # Update the original defs.
op_info = copy.deepcopy(op_info) op_info = copy.deepcopy(op_info)
for k in op_info._defs.keys(): for k in op_info._defs.keys():
op_def = op_info._defs[k] op_def = op_info._defs[k]
op_def.input.extend([ op_def.input.extend([
name_dict[input] name_dict[input]
if input in name_dict else input if input in name_dict else input
for input in op_def.input for input in op_def.input])
])
del op_def.input[:len(op_def.input) // 2] del op_def.input[:len(op_def.input) // 2]
# Sort out the states. # Sort out the forward defs.
op_defs = sorted(op_info._defs.items(), key=lambda d: d[0]) op_defs = sorted(op_info._defs.items(), key=lambda d: d[0])
forward_ops = copy.deepcopy([v for k, v in op_defs]) forward_defs = copy.deepcopy([v for k, v in op_defs])
# Generate the backward operators. # Generate the backward defs.
if requires_grad: if requires_grad:
input_grads, grad_targets = {}, [] input_grads, grad_targets = {}, []
for output in outputs: for output in outputs:
grad_info = output._grad info = output._grad
if grad_info is not None: if info is not None:
if grad_info.input is not None: if info.grad_y is not None:
input_grads[output.id] = output._grad.input.id input_grads[output.id] = info.grad_y.id
grad_targets.append(output.id) grad_targets.append(output.id)
forward_ops, gradient_ops, _ = \ backward_defs = grad_maker.GradientMaker.make(
grad_maker.GradientMaker.make( op_defs=forward_defs,
forward_ops=forward_ops, targets=grad_targets,
targets=grad_targets, input_grads=input_grads,
input_grads=input_grads, )
)
else: else:
gradient_ops = [] backward_defs = []
# Fill with all known graph elements. # Fill graph elements.
self.graph_def.op.extend(forward_ops + gradient_ops) self.graph_def.op.extend(forward_defs + backward_defs)
self.graph_def.input.extend([input.name for input in inputs]) self.graph_def.input.extend([input.name for input in inputs])
self.graph_def.output.extend(list(op_info._targets)) self.graph_def.output.extend(list(op_info._targets))
if len(outputs) > 0: if len(outputs) > 0:
add_device_option(self.graph_def) add_device_option(self.graph_def)
add_optimization(self.graph_def) add_optimization(self.graph_def)
add_gradient_info(self.graph_def, outputs) add_grad_info(self.graph_def, outputs)
add_phase(self.graph_def, outputs) add_phase(self.graph_def, outputs)
elif updater is not None: elif optimizer is not None:
add_device_option(self.graph_def) add_device_option(self.graph_def)
add_optimization(self.graph_def, level=0) add_optimization(self.graph_def, level=0)
add_update_ops(self.graph_def, updater) add_update_defs(self.graph_def, optimizer)
# Notify the backend to create and optimize. # Notify the backend to create and optimize.
self.graph_name = workspace.create_graph(self.graph_def) current_ws = workspace.get_workspace()
self.graph_name = current_ws.create_graph(self.graph_def)
# Bind a callback to run this graph. # Bind a callback to run this graph.
self.callback = lambda *args, **kwargs: \ self.callback = lambda *args, **kwargs: \
workspace.run_graph( current_ws.run_graph(
graph=self.graph_name, name=self.graph_name,
inputs=(inputs, args), inputs_and_values=(inputs, args),
outputs=outputs, outputs=outputs,
**kwargs **kwargs
) )
...@@ -273,15 +267,15 @@ class Function(object): ...@@ -273,15 +267,15 @@ class Function(object):
add_phase(graph_def, self.outputs) add_phase(graph_def, self.outputs)
# Notify the backend to create and optimize. # Notify the backend to create and optimize.
current_ws = workspace.get_workspace()
self.graph_def = graph_def self.graph_def = graph_def
self.graph_name = workspace.create_graph(graph_def) self.graph_name = current_ws.create_graph(graph_def)
# Bind a callback to run this graph. # Bind a callback to run this graph.
callback_inputs = self.inputs if explicit_inputs else []
self.callback = lambda *args, **kwargs: \ self.callback = lambda *args, **kwargs: \
workspace.run_graph( current_ws.run_graph(
graph=self.graph_name, name=self.graph_name,
inputs=(callback_inputs, args), inputs_and_values=(self.inputs if explicit_inputs else [], args),
outputs=self.outputs, outputs=self.outputs,
**kwargs **kwargs
) )
......
...@@ -21,37 +21,26 @@ from dragon.core.util import nest ...@@ -21,37 +21,26 @@ from dragon.core.util import nest
class GradientInfo(object): class GradientInfo(object):
"""A class to store the known gradient relations.""" """A class to store the known gradient relations."""
def __init__(self, parent): def __init__(self, y, grad_y=None):
self._parent = parent self._y, self._grad_y, self._xs = y, grad_y, []
self._cost, self._wrt = [], []
self._input = None
@property @property
def cost(self): def grad_y(self):
return self._cost return self._grad_y
@property @property
def input(self): def xs(self):
return self._input return self._xs
@property @property
def wrt(self): def y(self):
return self._wrt return self._y
def add_cost(self, cost): def add_x(self, x):
self._cost.append(cost) self._xs.append(x)
def add_wrt(self, wrt):
self._wrt.append(wrt)
def make_pairs(self):
return [(self._parent.id, wrt) for wrt in self._wrt]
def required(self): def required(self):
return len(self._wrt) > 0 return len(self._xs) > 0
def set_input(self, input):
self._input = input
def gradients(ys, xs, grad_ys=None): def gradients(ys, xs, grad_ys=None):
...@@ -112,18 +101,14 @@ def gradients(ys, xs, grad_ys=None): ...@@ -112,18 +101,14 @@ def gradients(ys, xs, grad_ys=None):
if grad_ys is not None: if grad_ys is not None:
grad_ys = nest.flatten(grad_ys) grad_ys = nest.flatten(grad_ys)
# Record the gradient info (cost, wrt, input), # Record the gradient info (y, grad_y, xs),
# then, generate the gradient references once. # then, generate the gradient references once.
for i, y in enumerate(ys): for i, y in enumerate(ys):
if y._grad is None: if y._grad is None:
y._grad = GradientInfo(y) grad_y = grad_ys[i] if grad_ys is not None else None
if grad_ys is not None: y._grad = GradientInfo(y, grad_y)
y._grad.set_input(grad_ys[i])
for x in xs: for x in xs:
if not hasattr(x, '_grad') or x._grad is None: y._grad.add_x(x)
x._grad = GradientInfo(x)
y._grad.add_wrt(x.id)
x._grad.add_cost(y)
if i == 0: if i == 0:
dxs.append(TensorRef(x.id + '_grad', x.shape, x.dtype)) dxs.append(TensorRef(x.id + '_grad', x.shape, x.dtype))
......
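
The rewritten `gradients()` records one GradientInfo per differentiated output instead of the removed per-x cost/wrt lists. A minimal sketch of the new bookkeeping, using the class exactly as defined above (plain strings stand in for symbolic tensors):

```python
# Only the bookkeeping of the rewritten GradientInfo class is shown.
info = GradientInfo('y', grad_y=None)
for x in ('x1', 'x2'):
    info.add_x(x)
assert info.required()           # at least one x was registered
assert info.xs == ['x1', 'x2']
```
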
...@@ -13,16 +13,7 @@ ...@@ -13,16 +13,7 @@
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Gradient maker implemented in python. """Python-implemented gradient maker."""
The basic idea of ``GradientMaker`` comes from ``caffe2``,
Jia provided a simple way to bridge the Generator(Python) with OpScheme(C++).
For the efficient C++ implementation, see,
<https://github.com/seetaresearch/Dragon/blob/master/Dragon/src/core/graph_gradient.cc>
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -40,25 +31,25 @@ class GradientMaker(object): ...@@ -40,25 +31,25 @@ class GradientMaker(object):
"""Make def for the gradient based on rules.""" """Make def for the gradient based on rules."""
@classmethod @classmethod
def gen_def(cls, forward_op, g_outputs): def gen_def(cls, op_def, g_outputs):
"""Generate the OperatorDef from forward op.""" """Generate the OperatorDef from forward op."""
g_ops, g_inputs, defaults = backend.CreateGradientDefs( grad_defs, g_inputs, defaults = backend.CreateGradientDefs(
forward_op.SerializeToString(), g_outputs) op_def.SerializeToString(), g_outputs)
for idx, g_op in enumerate(g_ops): for i, grad_def in enumerate(grad_defs):
new_def = dragon_pb2.OperatorDef() new_def = dragon_pb2.OperatorDef()
new_def.ParseFromString(g_op) new_def.ParseFromString(grad_def)
g_ops[idx] = new_def grad_defs[i] = new_def
return g_ops, g_inputs, defaults return grad_defs, g_inputs, defaults
@classmethod @classmethod
def check(cls, forward_op, inputs_to_grads, blacklist, targets): def check(cls, op_def, inputs_to_grads, blacklist, targets):
"""Check if missing gradients. If missing, skip.""" """Check if missing gradients. If missing, skip."""
if forward_op.type in backend.NO_GRADIENT_OPERATORS: if op_def.type in backend.NO_GRADIENT_OPERATORS:
for input in forward_op.input: for input in op_def.input:
blacklist.add(input) blacklist.add(input)
return True, None return True, None
gen_grads = [] gen_grads = []
for idx, output in enumerate(forward_op.output): for idx, output in enumerate(op_def.output):
if output not in inputs_to_grads: if output not in inputs_to_grads:
if output in blacklist: if output in blacklist:
return True, gen_grads return True, gen_grads
...@@ -66,50 +57,43 @@ class GradientMaker(object): ...@@ -66,50 +57,43 @@ class GradientMaker(object):
# Consider generating virtual gradients for targets. # Consider generating virtual gradients for targets.
gen_grads.append((output, idx)) gen_grads.append((output, idx))
inputs_to_grads[output] = output + '_grad' inputs_to_grads[output] = output + '_grad'
elif len(forward_op.output) == 1: elif len(op_def.output) == 1:
# We can skip this op, obviously. # We can skip this op, obviously.
return True, gen_grads return True, gen_grads
# Pass, even if missing some grads. # Pass, even if missing some grads.
return False, gen_grads return False, gen_grads
@classmethod @classmethod
def make(cls, forward_ops, targets, input_grads=None): def make(cls, op_defs, targets, input_grads=None):
"""The making procedure.""" """Make the backward op defs."""
inputs_to_grads = {} if input_grads is None else input_grads inputs_to_grads = {} if input_grads is None else input_grads
inputs_count, grads_count = defaultdict(int), defaultdict(int) inputs_count, grads_count = defaultdict(int), defaultdict(int)
all_split_grads, blacklist = set(), set() all_split_grads, blacklist = set(), set()
backward_ops = []
# A DAG may not have any in-place operators.
is_dag = True
# PLAY for the forward. # PLAY for the forward.
for forward_op in forward_ops: for op_def in op_defs:
if forward_op.type in backend.NO_GRADIENT_OPERATORS: if op_def.type in backend.NO_GRADIENT_OPERATORS:
continue continue
outputs = [o for o in forward_op.output] outputs = [output for output in op_def.output]
for input in forward_op.input: for input in op_def.input:
if input not in outputs: if input not in outputs:
# Avoid counting the duplicate input, # Avoid counting the duplicate input,
# (i.e. the in-place output). # (i.e. the in-place output).
inputs_count[input] += 1 inputs_count[input] += 1
else:
is_dag = False
# PLAY for the backward. # PLAY for the backward.
for forward_op in forward_ops[::-1]: backward_defs = []
for op_def in op_defs[::-1]:
# Collect inputs and outputs. # Collect inputs and outputs.
is_skip, gen_grads = cls.check( is_skip, gen_grads = cls.check(
forward_op=forward_op, op_def=op_def,
inputs_to_grads=inputs_to_grads, inputs_to_grads=inputs_to_grads,
blacklist=blacklist, blacklist=blacklist,
targets=targets, targets=targets,
) )
# Missing grads are represented by empty strings. # Missing grads are represented by empty strings.
g_outputs = [inputs_to_grads.get(name, '') g_outputs = [inputs_to_grads.get(name, '') for name in op_def.output]
for name in forward_op.output] grad_defs, g_inputs, defaults = cls.gen_def(op_def, g_outputs)
g_ops, g_inputs, defaults = cls.gen_def(forward_op, g_outputs)
# Append operators. # Append operators.
if not is_skip: if not is_skip:
...@@ -127,17 +111,17 @@ class GradientMaker(object): ...@@ -127,17 +111,17 @@ class GradientMaker(object):
outputs=op_outputs, outputs=op_outputs,
defaults=values, defaults=values,
) )
if forward_op.HasField('device_option'): if op_def.HasField('device_option'):
gen_op.device_option.CopyFrom(forward_op.device_option) gen_op.device_option.CopyFrom(op_def.device_option)
backward_ops.append(gen_op) backward_defs.append(gen_op)
# GradientOp # GradientOp
for g_op in g_ops: for grad_def in grad_defs:
g_op.name = OpDef.get_name() grad_def.name = OpDef.get_name()
backward_ops.append(g_op) backward_defs.append(grad_def)
# Split and gather grads for multi-used input. # Split and gather grads for multi-used input.
for g_op in g_ops: for grad_def in grad_defs:
for g_output_idx, g_output in enumerate(g_op.output): for g_output_idx, g_output in enumerate(grad_def.output):
original_idx = -1 original_idx = -1
for g_input_idx, g_input in enumerate(g_inputs): for g_input_idx, g_input in enumerate(g_inputs):
if g_output == g_input: if g_output == g_input:
...@@ -145,10 +129,10 @@ class GradientMaker(object): ...@@ -145,10 +129,10 @@ class GradientMaker(object):
# Ignore the unused or in-place gradient inputs. # Ignore the unused or in-place gradient inputs.
if original_idx == -1: if original_idx == -1:
continue continue
if g_output in g_op.input: if g_output in grad_def.input:
continue continue
# Found a split branch. # Found a split branch.
original_name = forward_op.input[original_idx] original_name = op_def.input[original_idx]
if inputs_count[original_name] > 1: if inputs_count[original_name] > 1:
# Split. # Split.
split_name = g_output + '_autosplit_%d' % grads_count[g_output] split_name = g_output + '_autosplit_%d' % grads_count[g_output]
...@@ -161,21 +145,21 @@ class GradientMaker(object): ...@@ -161,21 +145,21 @@ class GradientMaker(object):
for idx in range(grads_count[g_output]): for idx in range(grads_count[g_output]):
if '%s_autosplit_%d' % (g_output, idx) in all_split_grads: if '%s_autosplit_%d' % (g_output, idx) in all_split_grads:
split_inputs.append('%s_autosplit_%d' % (g_output, idx)) split_inputs.append('%s_autosplit_%d' % (g_output, idx))
gather_op = proto_util.make_operator_def( gather_def = proto_util.make_operator_def(
name=OpDef.get_name(), name=OpDef.get_name(),
op_type='GradientGather', op_type='GradientGather',
inputs=split_inputs, inputs=split_inputs,
outputs=[g_output], outputs=[g_output],
) )
if g_op.HasField('device_option'): if grad_def.HasField('device_option'):
gather_op.device_option.CopyFrom(g_op.device_option) gather_def.device_option.CopyFrom(grad_def.device_option)
backward_ops.append(gather_op) backward_defs.append(gather_def)
g_op.output[g_output_idx] = split_name grad_def.output[g_output_idx] = split_name
# Done. # Done.
if not is_skip: if not is_skip:
for name, grad in zip(forward_op.input, g_inputs): for name, grad in zip(op_def.input, g_inputs):
if grad != '': if grad != '':
inputs_to_grads[name] = grad inputs_to_grads[name] = grad
return forward_ops, backward_ops, is_dag return backward_defs
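Since ``make`` now returns only the backward defs (the forward defs and the ``is_dag`` flag are no longer echoed back), the caller is expected to extend its graph with both lists itself. A rough usage sketch, assuming a ``Relu`` gradient is registered in the backend and reusing the ``proto_util.make_operator_def`` helper seen above; ``graph_def`` stands in for the caller's GraphDef.

relu_def = proto_util.make_operator_def(
    name=OpDef.get_name(), op_type='Relu',
    inputs=['x'], outputs=['y'])
backward_defs = GradientMaker.make(
    op_defs=[relu_def], targets=['y'],
    input_grads={'y': 'y_grad'})
# The caller concatenates forward and backward defs when building the graph.
graph_def.op.extend([relu_def] + backward_defs)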
...@@ -30,9 +30,9 @@ class OpInfo(object): ...@@ -30,9 +30,9 @@ class OpInfo(object):
self._defs = dict() self._defs = dict()
self._targets = set() self._targets = set()
def add_def(self, idx, op_def): def add_def(self, index, op_def):
"""Add a operator definition.""" """Add a operator definition."""
self._defs[idx] = op_def self._defs[index] = op_def
def add_target(self, target): def add_target(self, target):
"""Add an extra target relied by inputs.""" """Add an extra target relied by inputs."""
...@@ -74,13 +74,14 @@ class OpDef(object): ...@@ -74,13 +74,14 @@ class OpDef(object):
# Create outputs. # Create outputs.
if outputs is None: if outputs is None:
outputs = [] outputs = []
current_ws = workspace.get_workspace()
name_scope = context.get_name_scope() name_scope = context.get_name_scope()
for i in range(num_outputs): for i in range(num_outputs):
outputs.append(TensorRef( outputs.append(TensorRef(
workspace.get_dummy_name( current_ws.unique_name(
name_scope + (name if name else op_type), name_scope + (name if name else op_type),
suffix=':{}'.format(i), suffix=':{}'.format(i),
domain='Tensor'))) namespace='Tensor')))
else: else:
outputs = nest.flatten(outputs) outputs = nest.flatten(outputs)
num_outputs = len(outputs) num_outputs = len(outputs)
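The per-workspace ``unique_name`` call replaces the removed ``get_dummy_name`` singleton API. A rough sketch of the naming it is expected to produce for two ops created under the same scope; the exact de-duplication suffix is an assumption, not taken from the source.

current_ws = workspace.get_workspace()
# First request is returned as-is, later requests get a de-duplicating suffix.
name_0 = current_ws.unique_name('conv1', suffix=':0', namespace='Tensor')  # e.g. 'conv1:0'
name_1 = current_ws.unique_name('conv1', suffix=':0', namespace='Tensor')  # e.g. 'conv1_1:0'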
...@@ -124,13 +125,13 @@ class OpDef(object): ...@@ -124,13 +125,13 @@ class OpDef(object):
return spec_func(arguments, inputs, outputs) return spec_func(arguments, inputs, outputs)
@staticmethod @staticmethod
def get_index_and_name(prefix='Op'): def get_index_and_name():
"""Return an unique op name and index.""" """Return an unique op name and index."""
name = workspace.get_dummy_name( name = workspace.get_workspace().unique_name(
prefix, domain='Operator', zero_based=False) 'Op', namespace='Op', zero_based=False)
return int(name.split('_')[-1]), name return int(name.split('_')[-1]), name
@staticmethod @staticmethod
def get_name(prefix='Op'): def get_name():
"""Return an unique op name.""" """Return an unique op name."""
return OpDef.get_index_and_name(prefix)[1] return OpDef.get_index_and_name()[1]
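Given the ``int(name.split('_')[-1])`` parsing and ``zero_based=False``, the operator namespace is expected to hand out names of the form ``Op_<index>``; a small sketch of the assumed contract, with the concrete values only illustrative.

index, name = OpDef.get_index_and_name()
# e.g. index == 1 and name == 'Op_1' for the first operator in the workspace.
assert name == 'Op_%d' % index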
...@@ -190,24 +190,28 @@ def conv_spec(args, inputs, outputs): ...@@ -190,24 +190,28 @@ def conv_spec(args, inputs, outputs):
out_shape = None out_shape = None
try: try:
out_shape = inputs[0].shape[:] out_shape = inputs[0].shape[:]
num_axes = len(out_shape) - 2
channel_axis = 1 if args['data_format'] == 'NCHW' else -1 channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1 spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
if 'out_channels' in args: if 'out_channels' in args:
out_shape[channel_axis] = args['out_channels'] out_shape[channel_axis] = args['out_channels']
else: else:
out_shape[channel_axis] = inputs[1].shape[0] out_shape[channel_axis] = inputs[1].shape[0]
for i in range(len(out_shape) - 2): for i in range(num_axes):
input_size = out_shape[i + spatial_axis] try:
k = args['kernel_shape'][i] k = args['kernel_shape'][i]
s = args['strides'][i] s = args['strides'][i]
pl, pr = args['pads'][i], args['pads'][i + 2] d = args['dilations'][i]
dk, dp = (k - 1) + 1, pl + pr in_size = out_shape[i + spatial_axis]
if 'SAME' not in args['padding']: k_size = d * (k - 1) + 1
out_shape[i + spatial_axis] = \ if 'SAME' not in args['padding']:
int(float(input_size + dp - dk) / s) + 1 pad_size = args['pads'][i] + args['pads'][i + num_axes]
else: out_size = (in_size + pad_size - k_size) // s + 1
out_shape[i + spatial_axis] = \ else:
int(float(input_size + s - 1) / s) out_size = (in_size + s - 1) // s
except IndexError:
out_size = None
out_shape[i + spatial_axis] = out_size
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
outputs[0].shape = out_shape outputs[0].shape = out_shape
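The rewritten loop is the standard convolution output-size rule: ``(in + pads - dilated_kernel) // stride + 1`` for explicit padding and ``ceil(in / stride)`` for SAME padding. A self-contained check with illustrative numbers:

def conv_out_size(in_size, k, s, d, pads, padding='VALID'):
    # Mirrors the arithmetic in conv_spec above.
    k_size = d * (k - 1) + 1
    if 'SAME' not in padding:
        return (in_size + sum(pads) - k_size) // s + 1
    return (in_size + s - 1) // s  # ceil(in_size / s)

assert conv_out_size(224, k=7, s=2, d=1, pads=(3, 3)) == 112
assert conv_out_size(224, k=7, s=2, d=1, pads=(0, 0), padding='SAME') == 112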
...@@ -220,30 +224,33 @@ def conv_transpose_spec(args, inputs, outputs): ...@@ -220,30 +224,33 @@ def conv_transpose_spec(args, inputs, outputs):
out_shape = None out_shape = None
try: try:
out_shape = inputs[0].shape[:] out_shape = inputs[0].shape[:]
num_axes = len(out_shape) - 2
channel_axis = 1 if args['data_format'] == 'NCHW' else -1 channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1 spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
if 'out_channels' in args: if 'out_channels' in args:
out_shape[channel_axis] = args['out_channels'] out_shape[channel_axis] = args['out_channels']
else: else:
out_shape[channel_axis] = inputs[1].shape[1] out_shape[channel_axis] = inputs[1].shape[1]
for i in range(len(out_shape) - 2): for i in range(num_axes):
k = args['kernel_shape'][i] try:
s = args['strides'][i] k = args['kernel_shape'][i]
d = args['dilations'][i] s = args['strides'][i]
pl, pr = args['pads'][i], args['pads'][i + 2] d = args['dilations'][i]
dk, dp = d * (k - 1) + 1, pl + pr in_size = out_shape[i + spatial_axis]
input_size = out_shape[i + spatial_axis] k_size = d * (k - 1) + 1
if 'SAME' not in args['padding']: if 'SAME' not in args['padding']:
out_shape[i + spatial_axis] = s * \ pad_size = args['pads'][i] + args['pads'][i + num_axes]
(input_size - 1) + dk - dp out_size = s * (in_size - 1) + k_size - pad_size
else: if 'output_padding' in args and args['output_padding']:
out_shape[i + spatial_axis] = None out_size += args['output_padding'][i]
if args['output_padding'] is not None: else:
out_shape[i + spatial_axis] = \ if 'output_shape' in args and args['output_shape']:
s * (input_size - 1) + dk + \ out_size = args['output_shape'][i]
args['output_padding'][i] else:
elif args['output_shape'] is not None: out_size = None
out_shape[i + spatial_axis] = args['output_shape'][i] except IndexError:
out_size = None
out_shape[i + spatial_axis] = out_size
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
outputs[0].shape = out_shape outputs[0].shape = out_shape
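For the transposed convolution the explicit-padding branch reduces to ``stride * (in - 1) + dilated_kernel - pads (+ output_padding)``, which inverts the convolution rule above. A quick check with illustrative numbers:

def conv_transpose_out_size(in_size, k, s, d, pads, output_padding=0):
    # Mirrors the explicit-padding branch in conv_transpose_spec above.
    k_size = d * (k - 1) + 1
    return s * (in_size - 1) + k_size - sum(pads) + output_padding

assert conv_transpose_out_size(112, k=4, s=2, d=1, pads=(1, 1)) == 224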
...@@ -606,21 +613,24 @@ def pool_spec(args, inputs, outputs): ...@@ -606,21 +613,24 @@ def pool_spec(args, inputs, outputs):
out_shape = None out_shape = None
try: try:
out_shape = inputs[0].shape[:] out_shape = inputs[0].shape[:]
num_axes = len(out_shape) - 2
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1 spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
for i in range(len(out_shape) - 2): for i in range(num_axes):
k = args['kernel_shape'][i]
s = args['strides'][i]
pl, pr = args['pads'][i], args['pads'][i + 2]
if not args['global_pooling']: if not args['global_pooling']:
floor_or_ceil = math.ceil if args['ceil_mode'] else math.floor try:
if 'SAME' not in args['padding']: k = args['kernel_shape'][i]
in_size = out_shape[i + spatial_axis] + pl + pr s = args['strides'][i]
out_size = int(floor_or_ceil(float(in_size - k) / s) + 1)
out_shape[i + spatial_axis] = out_size
else:
in_size = out_shape[i + spatial_axis] in_size = out_shape[i + spatial_axis]
out_size = int(floor_or_ceil(float(in_size) / s)) if 'SAME' not in args['padding']:
out_shape[i + spatial_axis] = out_size floor_or_ceil = math.ceil if args['ceil_mode'] else math.floor
pad_size = args['pads'][i] + args['pads'][i + num_axes]
out_size = float(in_size + pad_size - k) / float(s) + 1
out_size = floor_or_ceil(out_size)
else:
out_size = math.ceil(float(in_size) / float(s))
except IndexError:
out_size = None
out_shape[i + spatial_axis] = out_size
else: else:
out_shape[i + spatial_axis] = 1 out_shape[i + spatial_axis] = 1
except (TypeError, IndexError): except (TypeError, IndexError):
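The pooling branch follows the same pattern, with ``ceil_mode`` choosing between floor and ceil rounding under explicit padding. A self-contained check with illustrative numbers:

import math

def pool_out_size(in_size, k, s, pads, padding='VALID', ceil_mode=False):
    # Mirrors the arithmetic in pool_spec above.
    if 'SAME' in padding:
        return int(math.ceil(in_size / float(s)))
    round_fn = math.ceil if ceil_mode else math.floor
    return int(round_fn((in_size + sum(pads) - k) / float(s))) + 1

assert pool_out_size(224, k=3, s=2, pads=(0, 0)) == 111
assert pool_out_size(224, k=3, s=2, pads=(0, 0), ceil_mode=True) == 112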
...@@ -959,14 +969,14 @@ def stack_spec(args, inputs, outputs): ...@@ -959,14 +969,14 @@ def stack_spec(args, inputs, outputs):
@register('Tile') @register('Tile')
def tile_spec(args, inputs, outputs): def tile_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
multiples = args['multiples'] repeats = args['repeats']
if multiples is not None: if repeats is not None:
try: try:
out_shape = inputs[0].shape[:] out_shape = inputs[0].shape[:]
for i, multiple in enumerate(multiples): for i, size in enumerate(repeats):
if i < len(out_shape): if i < len(out_shape):
try: try:
out_shape[i] *= multiple out_shape[i] *= size
except TypeError: except TypeError:
out_shape[i] = None out_shape[i] = None
outputs[0].shape = out_shape outputs[0].shape = out_shape
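After the ``multiples`` -> ``repeats`` rename, the shape rule is simply element-wise multiplication, with unknown dimensions propagated as ``None``; a small sketch with illustrative values:

in_shape, repeats = [2, 3, None], [2, 2, 2]
out_shape = [None if dim is None else dim * rep
             for dim, rep in zip(in_shape, repeats)]
assert out_shape == [4, 6, None]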
......