Commit 02ad90d5 by Ting PAN

Remove the duplicate workspace singletons

Summary:
This commit moves the workspace api into the current workspace instance.
For this reason, the namespace ``dragon.workspace`` is removed for simplicity.
1 parent adb6fa64
Showing with 1163 additions and 1799 deletions
<p align="center">
<img width="40%" src="http://dragon.seetatech.com/static/images/styles-dragon.png"/>
<img width="40%" src="https://dragon.seetatech.com/static/images/styles-dragon.png"/>
</p>
[Dragon](http://dragon.seetatech.com) is a **C**(Computation)**G**(Graph)**V**(Virtual)**M**(Machine) based distributed deep learning framework.
[Dragon](https://dragon.seetatech.com) is a **C**(Computation)**G**(Graph)**V**(Virtual)**M**(Machine) based distributed deep learning framework.
It fuses several modern frameworks and integrations together, powered by a unified engine.
The computation between different programming styles is deterministic and reproduceable.
......@@ -11,7 +11,7 @@ promoting internal interfaces. We will always learn from the AI community to evo
## Installation
See the [install guide](http://dragon.seetatech.com/install) for the pip package
See the [install guide](https://dragon.seetatech.com/install) for the pip package
or how to build from source.
## License
......
......@@ -15,9 +15,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy
from dragon.core.autograph.tensor import TensorRef
from dragon.core.eager import context as eager_context
from dragon.core.framework import context
from dragon.core.util import logging
from dragon.vm.caffe.proto import caffe_pb2
class Layer(object):
......@@ -34,24 +38,26 @@ class Layer(object):
"""
self._proto = layer_param
self._name = layer_param.name
self._arguments, self.arguments = {'name': self._name}, {}
self._arguments, self.arguments = {'name': 'output'}, {}
# Store the inputs, outputs and trainable parameters.
self._bottom, self._top, self._blobs = [], [], []
for blob in layer_param.bottom:
self._bottom.append(blob)
for blob in layer_param.top:
self._top.append(blob)
# Store the loss weight to apply gradients.
self._loss_weight = layer_param.loss_weight \
if len(layer_param.loss_weight) > 0 else None
# Optional mirror stage argument for memory optimization.
if layer_param.HasField('mirror_stage'):
self._arguments['mirror_stage'] = layer_param.mirror_stage
@property
def blobs(self):
"""Return the blobs."""
return self._blobs
@property
def bottom(self):
"""Return the bottom names."""
return self._bottom
......@@ -62,49 +68,91 @@ class Layer(object):
return self._loss_weight
@property
def name(self):
"""Return the layer name."""
return self._name
@property
def top(self):
"""Return the top names."""
return self._top
def add_blob(self, value=None, filler=None, no_grad=False):
"""Add a weight blob into this layer."""
# Use a fixed name in the current workspace.
# Note that a non-empty tensor scope will make it
# impossible to load/save models. You should use
# a new workspace instead of the terrible name scope.
scoped_name = context.get_name_scope() + self._name
param_name = scoped_name + '/param:{}'.format(len(self._blobs))
# Set the name explicitly.
variable = TensorRef(param_name)
variable_grad = TensorRef(param_name + '_grad')
"""Add a blob into this layer."""
# Set the name for reference explicitly.
data_name = context.get_name_scope() + 'param:{}'.format(len(self._blobs))
data, diff = TensorRef(data_name), TensorRef(data_name + '_grad')
if filler is not None:
variable._register_as(**filler)
data._register_as(**filler)
else:
# Register a constant filler by default.
value = value if value else 0
variable.constant(value=value)
data.constant(value=value)
# Append to the blobs.
self._blobs.append({'data': data, 'diff': None if no_grad else diff})
# Determine whether to disable the gradients explicitly.
if no_grad is True:
variable_grad = None
def from_proto(self, proto):
"""Deserialize from the proto.
# Append to the blobs.
self._blobs.append({'data': variable, 'diff': variable_grad})
Parameters
----------
proto : LayerParameter
The ``LayerParameter`` protocol buffer.
"""
for i in range(len(self._blobs)):
if i < len(proto.blobs):
blob_proto = proto.blobs[i]
if len(blob_proto.data) > 0:
value = numpy.array(blob_proto.data, dtype='float32')
elif len(blob_proto.double_data) > 0:
value = numpy.array(blob_proto.double_data, dtype='float64')
else:
raise ValueError('Neither <data> or <double_data> in blob proto.')
if len(blob_proto.shape.dim) > 0:
value = value.reshape([dim for dim in blob_proto.shape.dim])
self._blobs[i]['data'].set_value(value)
logging.info('Blob({}/param:{}) loaded, shape: {}, size: {}'
.format(self._name, i, value.shape, value.size))
def setup(self, bottom):
# Merge the arguments, then setup up the specific layer.
"""Setup the layer."""
self.arguments = dict(self.arguments, **self._arguments)
bottom = bottom[0] if len(bottom) == 1 else bottom
with eager_context.graph_mode():
return self.__call__(bottom)
@classmethod
def get_filler(cls, layer_param, filler_name):
"""Construct a filler from the parameter."""
if layer_param.HasField(filler_name):
filler = getattr(layer_param, filler_name)
def to_proto(self):
"""Serialize to the proto.
Returns
-------
LayerParameter
The ``LayerParameter`` protocol buffer.
"""
proto = caffe_pb2.LayerParameter()
proto.CopyFrom(self._proto)
for blob in self._blobs:
value = blob['data'].get_value()
if str(value.dtype) == 'float32':
blob_proto = caffe_pb2.BlobProto(
data=value.flatten(),
shape=caffe_pb2.BlobShape(dim=value.shape))
elif str(value.dtype) == 'float64':
blob_proto = caffe_pb2.BlobProto(
double_data=value.flatten(),
shape=caffe_pb2.BlobShape(dim=value.shape))
else:
raise ValueError('Either float32 or float64 blob is required.')
proto.blobs.extend([blob_proto])
return proto
@staticmethod
def get_filler(proto, filler_name):
"""Return the filler from proto."""
if proto.HasField(filler_name):
filler = getattr(proto, filler_name)
return {
'type': filler.type.lower(),
'value': filler.value,
......
......@@ -16,14 +16,10 @@ from __future__ import print_function
from dragon.vm.caffe.layers.common import Accuracy
from dragon.vm.caffe.layers.common import ArgMax
from dragon.vm.caffe.layers.common import BatchNorm
from dragon.vm.caffe.layers.common import Cast
from dragon.vm.caffe.layers.common import Concat
from dragon.vm.caffe.layers.common import Crop
from dragon.vm.caffe.layers.common import Eltwise
from dragon.vm.caffe.layers.common import Flatten
from dragon.vm.caffe.layers.common import FusedBatchNorm
from dragon.vm.caffe.layers.common import FusedGroupNorm
from dragon.vm.caffe.layers.common import GroupNorm
from dragon.vm.caffe.layers.common import InnerProduct
from dragon.vm.caffe.layers.common import Input
from dragon.vm.caffe.layers.common import Normalize
......@@ -46,12 +42,10 @@ from dragon.vm.caffe.layers.neuron import ELU
from dragon.vm.caffe.layers.neuron import Power
from dragon.vm.caffe.layers.neuron import PReLU
from dragon.vm.caffe.layers.neuron import ReLU
from dragon.vm.caffe.layers.neuron import SELU
from dragon.vm.caffe.layers.neuron import Sigmoid
from dragon.vm.caffe.layers.neuron import TanH
from dragon.vm.caffe.layers.vision import Convolution
from dragon.vm.caffe.layers.vision import Deconvolution
from dragon.vm.caffe.layers.vision import DepthwiseConv2d
from dragon.vm.caffe.layers.vision import LRN
from dragon.vm.caffe.layers.vision import Pooling
from dragon.vm.caffe.layers.vision import ROIAlign
......
......@@ -33,8 +33,9 @@ class _DataPlugin(object):
def forward(self, inputs, outputs):
blobs = self.iterator.next()
current_ws = workspace.get_workspace()
for i, blob in enumerate(blobs):
workspace.feed_tensor(outputs[i], blob)
current_ws.feed_tensor(outputs[i], blob)
class Data(Layer):
......@@ -44,42 +45,46 @@ class Data(Layer):
```python
layer {
type: "Data"
top: "data"
top: "label"
include { phase: TRAIN }
data_param {
source: "/data/imagenet/train"
batch_size: 128
shuffle: true
num_chunks: 0
prefetch: 5
}
transform_param {
mirror: true
random_crop_size: 224
augment_color: true
mean_value: 104.00698793
mean_value: 116.66876762
mean_value: 122.67891434
}
type: "Data"
top: "data"
top: "label"
include {
phase: TRAIN
}
data_param {
source: "/data/train"
batch_size: 128
shuffle: true
num_chunks: 0
prefetch: 5
}
transform_param {
mirror: true
random_crop_size: 224
augment_color: true
mean_value: 104.00698793
mean_value: 116.66876762
mean_value: 122.67891434
}
}
layer {
type: "Data"
top: "data"
top: "label"
include { phase: TEST }
data_param {
source: "/data/imagenet/val"
batch_size: 100
}
transform_param {
resize: 256
crop_size: 224
mean_value: 104.00698793
mean_value: 116.66876762
mean_value: 122.67891434
}
type: "Data"
top: "data"
top: "label"
include {
phase: TEST
}
data_param {
source: "/data/val"
batch_size: 64
}
transform_param {
resize: 256
crop_size: 224
mean_value: 104.00698793
mean_value: 116.66876762
mean_value: 122.67891434
}
}
```
......
......@@ -30,13 +30,13 @@ class EuclideanLoss(Layer):
```python
layer {
type: "EuclideanLoss"
bottom: "bbox_pred"
bottom: "bbox_target"
top: "bbox_loss"
loss_param {
normalization: BATCH_SIZE
}
type: "EuclideanLoss"
bottom: "bbox_pred"
bottom: "bbox_target"
top: "bbox_loss"
loss_param {
normalization: BATCH_SIZE
}
}
```
......@@ -67,13 +67,13 @@ class SigmoidCrossEntropyLoss(Layer):
```python
layer {
type: "SigmoidCrossEntropyLoss"
bottom: "rpn_cls_score"
bottom: "rpn_labels"
top: "rpn_loss"
loss_param {
normalization: VALID
}
type: "SigmoidCrossEntropyLoss"
bottom: "rpn_cls_score"
bottom: "rpn_labels"
top: "rpn_loss"
loss_param {
normalization: VALID
}
}
```
......@@ -106,15 +106,15 @@ class SmoothL1Loss(Layer):
```python
layer {
type: "SmoothL1Loss"
bottom: "bbox_pred"
bottom: "bbox_targets"
bottom: "bbox_inside_weights"
bottom: "bbox_outside_weights"
top: "bbox_loss"
loss_param {
normalization: BATCH_SIZE
}
type: "SmoothL1Loss"
bottom: "bbox_pred"
bottom: "bbox_targets"
bottom: "bbox_inside_weights"
bottom: "bbox_outside_weights"
top: "bbox_loss"
loss_param {
normalization: BATCH_SIZE
}
}
```
......@@ -155,15 +155,17 @@ class SoftmaxWithLoss(Layer):
```python
layer {
type: "SoftmaxWithLoss"
bottom: "cls_score"
bottom: "labels"
top: "cls_loss"
softmax_param { axis: 1 }
loss_param {
ignore_label: -1
normalization: VALID
}
type: "SoftmaxWithLoss"
bottom: "cls_score"
bottom: "labels"
top: "cls_loss"
softmax_param {
axis: 1
}
loss_param {
ignore_label: -1
normalization: VALID
}
}
```
......
......@@ -32,12 +32,12 @@ class Dropout(Layer):
```python
layer {
type: "Dropout"
bottom: "fc6"
top: "fc6"
dropout_param {
dropout_ratio: 0.5
}
type: "Dropout"
bottom: "fc6"
top: "fc6"
dropout_param {
dropout_ratio: 0.5
}
}
```
......@@ -73,12 +73,12 @@ class ELU(Layer):
```python
layer {
type: "ELU"
bottom: "conv2"
top: "conv2"
elu_param {
alpha: 1.
}
type: "ELU"
bottom: "conv2"
top: "conv2"
elu_param {
alpha: 1.
}
}
```
......@@ -101,14 +101,14 @@ class Power(Layer):
```python
layer {
type: "Power"
bottom: "x"
top: "y"
power_param {
scale: 1.
shift: 0.
power: 2.
}
type: "Power"
bottom: "x"
top: "y"
power_param {
scale: 1.
shift: 0.
power: 2.
}
}
```
......@@ -148,16 +148,16 @@ class PReLU(Layer):
```python
layer {
type: "PReLU"
bottom: "conv2"
top: "conv2/relu"
prelu_param {
channel_shared: false
filler {
type: "constant"
value: 0.25
}
type: "PReLU"
bottom: "conv2"
top: "conv2/relu"
prelu_param {
channel_shared: false
filler {
type: "constant"
value: 0.25
}
}
}
```
......@@ -194,12 +194,12 @@ class ReLU(Layer):
```python
layer {
type: "ReLU"
bottom: "conv2"
top: "conv2/relu"
relu_param {
negative_slope: 0.
}
type: "ReLU"
bottom: "conv2"
top: "conv2/relu"
relu_param {
negative_slope: 0.
}
}
```
......@@ -215,38 +215,6 @@ class ReLU(Layer):
return activation_ops.relu(bottom, **self.arguments)
class SELU(Layer):
r"""Apply the scaled exponential linear unit.
`[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
The **SELU** function is defined as:
.. math::
\text{SELU}(x) = 1.0507 *
\begin{cases}
x, & \text{ if } x \geq 0 \\
1.6733 * (e^{x} - 1), & \text{ otherwise }
\end{cases}
Examples:
```python
layer {
type: "SELU"
bottom: "conv2"
top: "conv2/relu"
}
```
"""
def __init__(self, layer_param):
super(SELU, self).__init__(layer_param)
def __call__(self, bottom):
return activation_ops.selu(bottom, **self.arguments)
class Sigmoid(Layer):
r"""Apply the sigmoid function.
......@@ -258,9 +226,9 @@ class Sigmoid(Layer):
```python
layer {
type: "Sigmoid"
bottom: "rpn_cls_score"
top: "rpn_cls_prob"
type: "Sigmoid"
bottom: "rpn_cls_score"
top: "rpn_cls_prob"
}
```
......@@ -284,9 +252,9 @@ class TanH(Layer):
```python
layer {
type: "TanH"
bottom: "g/conv5"
top: "g/image"
type: "TanH"
bottom: "g/conv5"
top: "g/image"
}
```
......
......@@ -23,39 +23,29 @@ from dragon.vm.caffe.layer import Layer
class Convolution(Layer):
r"""Apply the n-dimension convolution.
The spatial output dimension is computed as:
.. math::
\begin{cases}
\text{DK}_{size} = dilation *
(\text{K}_{size} - 1) + 1 \\
\text{Dim}_{out} = (\text{Dim}_{in} +
2 * pad - \text{DK}_{size}) / stride + 1
\end{cases}
Examples:
```python
layer {
type: "Convolution"
bottom: "input"
top: "conv1"
convolution_param {
num_output: 32
bias_term: true
kernel_size: 3
pad: 1
stride: 1
dilation: 1
group: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
value: 0
}
type: "Convolution"
bottom: "input"
top: "conv1"
convolution_param {
num_output: 32
bias_term: true
kernel_size: 3
pad: 1
stride: 1
dilation: 1
group: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
value: 0
}
}
}
```
......@@ -83,7 +73,6 @@ class Convolution(Layer):
if param.HasField('pad_h'):
assert param.HasField('pad_w')
self.arguments['pads'] = [param.pad_h, param.pad_w]
self.add_blob(filler=self.get_filler(param, 'weight_filler'))
if param.bias_term:
self.add_blob(filler=self.get_filler(param, 'bias_filler'))
......@@ -96,39 +85,29 @@ class Convolution(Layer):
class Deconvolution(Convolution):
r"""Apply the 2d deconvolution.
The spatial output dimension is computed as:
.. math::
\begin{cases}
\text{DK}_{size} = dilation *
(\text{K}_{size} - 1) + 1 \\
\text{Dim}_{out} = (\text{Dim}_{in} - 1) *
stride + \text{DK}_{size} - 2 * pad
\end{cases}
Examples:
```python
layer {
type: "Deconvolution"
bottom: "conv5"
top: "conv5/upscale"
convolution_param {
num_output: 256
bias_term: true
kernel_size: 2
pad: 0
stride: 2
dilation: 1
group: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
value: 0
}
type: "Deconvolution"
bottom: "conv5"
top: "conv5/upscale"
convolution_param {
num_output: 256
bias_term: true
kernel_size: 2
pad: 0
stride: 2
dilation: 1
group: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
value: 0
}
}
}
```
......@@ -142,77 +121,6 @@ class Deconvolution(Convolution):
return vision_ops.conv2d_transpose(inputs, **self.arguments)
class DepthwiseConv2d(Layer):
r"""Apply the 2d depthwise convolution.
`[Chollet, 2016] <https://arxiv.org/abs/1610.02357>`_.
The spatial output dimension is computed as:
.. math::
\begin{cases}
\text{DK}_{size} = dilation *
(\text{K}_{size} - 1) + 1 \\
\text{Dim}_{out} = (\text{Dim}_{in} +
2 * pad - \text{DK}_{size}) / stride + 1
\end{cases}
Examples:
```python
layer {
type: "DepthwiseConv2d"
bottom: "input"
top: "conv1"
convolution_param {
num_output: 32
bias_term: true
kernel_size: 3
pad: 1
stride: 1
dilation: 1
weight_filler {
type: "xavier"
variance_norm: FAN_OUT
}
bias_filler {
type: "constant"
value: 0
}
}
}
```
"""
def __init__(self, layer_param):
super(DepthwiseConv2d, self).__init__(layer_param)
param = layer_param.convolution_param
self.arguments = {
'out_channels': param.num_output,
'kernel_shape': [int(e) for e in param.kernel_size],
'strides': [int(e) for e in param.stride] if len(param.stride) > 0 else [1],
'pads': [int(e) for e in param.pad] if len(param.pad) > 0 else [0],
'padding': 'VALID',
'data_format': 'NCHW',
}
if param.HasField('kernel_h'):
assert param.HasField('kernel_w')
self.arguments['kernel_shape'] = [param.kernel_h, param.kernel_w]
if param.HasField('stride_h'):
assert param.HasField('stride_w')
self.arguments['strides'] = [param.stride_h, param.stride_w]
if param.HasField('pad_h'):
assert param.HasField('pad_w')
self.arguments['pads'] = [param.pad_h, param.pad_w]
self.add_blob(filler=self.get_filler(param, 'weight_filler'))
if param.bias_term:
self.add_blob(filler=self.get_filler(param, 'bias_filler'))
def __call__(self, bottom):
inputs = [bottom] + [blob['data'] for blob in self._blobs]
return vision_ops.depthwise_conv2d(inputs, **self.arguments)
class LRN(Layer):
r"""Apply the local response normalization.
`[Krizhevsky et.al, 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_.
......@@ -221,15 +129,15 @@ class LRN(Layer):
```python
layer {
type: "LRN"
bottom: "conv2"
top: "conv2/norm"
lrn_param {
local_size: 5
alpha: 1.
beta: 0.75
k: 1.
}
type: "LRN"
bottom: "conv2"
top: "conv2/norm"
lrn_param {
local_size: 5
alpha: 1.
beta: 0.75
k: 1.
}
}
```
......@@ -255,24 +163,18 @@ class LRN(Layer):
class Pooling(Layer):
r"""Apply the n-dimension pooling.
The spatial output dimension is computed as:
.. math::
\text{Dim}_{out} = (\text{Dim}_{in} +
2 * pad - \text{K}_{size}) / stride + 1
Examples:
```python
layer {
type: "Pooling"
bottom: "conv2"
top: "pool2"
pooling_param {
kernel_size: 3
stride: 2
pool: AVG
}
type: "Pooling"
bottom: "conv2"
top: "pool2"
pooling_param {
kernel_size: 3
stride: 2
pool: AVG
}
}
```
......@@ -311,14 +213,14 @@ class ROIAlign(Layer):
```python
layer {
type: "ROIAlign"
bottom: "conv5_3"
top: "roi_pool4"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625
}
type: "ROIAlign"
bottom: "conv5_3"
top: "roi_pool4"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625
}
}
```
......@@ -345,14 +247,14 @@ class ROIPooling(Layer):
```python
layer {
type: "ROIPooling"
bottom: "conv5_3"
top: "roi_pool4"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625
}
type: "ROIPooling"
bottom: "conv5_3"
top: "roi_pool4"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625
}
}
```
......
......@@ -16,15 +16,14 @@ from __future__ import division
from __future__ import print_function
import time
from google.protobuf import text_format
from dragon.core.autograph import def_function
from dragon.core.framework import workspace
from dragon.core.training.adam import Adam
from dragon.core.training.rmsprop import RMSprop
from dragon.core.training.sgd import SGD
from dragon.core.training.sgd import Nesterov
from dragon.core.util import logging
from dragon.vm.caffe.net import Net
from dragon.vm.caffe.proto import caffe_pb2
......@@ -99,8 +98,9 @@ class Solver(object):
if self._current_step < len(self._param.stepvalue) \
and self.iter >= self._param.stepvalue[self._current_step]:
self._current_step = self._current_step + 1
print('MultiStep Status: Iteration {}, step = {}'
.format(self.iter, self._current_step))
logging.info(
'MultiStep Status: Iteration {}, step = {}'
.format(self.iter, self._current_step))
new_lr = self._param.base_lr * \
pow(self._param.gamma, self._current_step)
self.base_lr = new_lr
......@@ -112,8 +112,9 @@ class Solver(object):
else:
if self._current_step + 1 < len(stage_iters):
self._current_step = self._current_step + 1
print('MultiFixed Status: Iteration {}, stage = {}'
.format(self.iter, self._current_step))
logging.info(
'MultiFixed Status: Iteration {}, stage = {}'
.format(self.iter, self._current_step))
self.base_lr = stage_lrs[self._current_step]
elif policy == 'inv':
power = self._param.power
......@@ -130,8 +131,7 @@ class Solver(object):
def _apply_update(self):
"""Apply the weights update."""
for blob in self.net._layer_blobs:
if blob.lr_multiplier > 0 and \
blob.diff is not None:
if blob.lr_multiplier > 0 and blob.diff is not None:
self._optimizer.apply_gradients(
values_and_grads=[(blob.data, blob.diff)],
lr_mult=blob.lr_multiplier,
......@@ -211,80 +211,18 @@ class Solver(object):
"""
return self._test_nets
def one_step(self):
"""One step run the train net.
Returns
-------
dict
The stats.
"""
if self._param.test_interval and \
self.iter % self._param.test_interval == 0:
if (self.iter == 0 and
self._param.test_initialization) or self.iter != 0:
for test_idx in range(len(self._test_nets)):
self.test(test_idx)
# Forward, backward and compute loss.
run_time, stats = 0., {'loss': {'total': 0.}, 'iter': self.iter}
for i in range(self._param.iter_size):
tic = time.time()
self._net.forward_backward(return_outputs=False)
run_time += (time.time() - tic)
# Total loss.
for e in self.net.losses:
values = e.get_value().flatten()
if values.size == 1:
stats['loss']['total'] += values[0]
# Partial loss.
for key in self.net.outputs:
values = self.net.blobs[key].data
values = values.get_value().flatten()
if values.size != 1:
continue
if key not in stats['loss']:
stats['loss'][key] = 0.
stats['loss'][key] += values[0]
# Apply Update.
self._get_learning_rate()
tic = time.time()
self._apply_update()
run_time += (time.time() - tic)
self.iter = self.iter + 1
# Snapshot.
if self._param.snapshot:
if self.iter % self._param.snapshot == 0:
self.snapshot()
# Average loss by the iter size.
for k in stats['loss'].keys():
stats['loss'][k] /= self._param.iter_size
# Misc stats.
stats['lr'] = self.base_lr
stats['time'] = run_time
return stats
def snapshot(self):
"""Snapshot the parameters of train net."""
workspace.save(
tensors=[blob.data for blob in self.net._layer_blobs],
filename='_iter_%d' % self.iter,
prefix=self._param.snapshot_prefix,
suffix='.caffemodel',
format='caffe',
)
def step(self, num_iterations):
self._net.save(
'%s_iter_%d.caffemodel'
% (self._param.snapshot_prefix, self._iter))
def step(self, num_iterations=1):
"""Step the train net.
Parameters
----------
num_iterations : int
num_iterations : int, optional, default=1
The number of iterations to step.
"""
......@@ -293,19 +231,18 @@ class Solver(object):
loss_vec, smoothed_loss = [], 0.
tic = time.time()
while self.iter < stop_step:
# Test if necessary.
if self._param.test_interval and \
self.iter % self._param.test_interval == 0:
if (self.iter == 0 and
self._param.test_initialization) or self.iter != 0:
if self._is_root and self._param.test_interval > 0 and \
self.iter % self._param.test_interval == 0:
if (self.iter == 0 and self._param.test_initialization) or \
self.iter != 0:
for test_idx in range(len(self._test_nets)):
self.test(test_idx)
# Forward, backward and compute loss.
loss = 0.
for i in range(self._param.iter_size):
self._net.forward_backward(return_outputs=False)
self._net.forward_backward()
if self._is_root:
for e in self.net.losses:
values = e.get_value().flatten()
......@@ -322,24 +259,23 @@ class Solver(object):
idx = (self.iter - start_step) % self._param.average_loss
smoothed_loss += ((loss - loss_vec[idx]) / self._param.average_loss)
loss_vec[idx] = loss
# Apply Update.
self._get_learning_rate()
self._apply_update()
# Display.
# Display iteration info.
if self._is_root and self._param.display:
if self.iter % self._param.display == 0:
print('Iteration %d, lr = %s, loss = %f, time = %.2fs' % (
self.iter, str(self.base_lr), smoothed_loss, time.time() - tic))
logging.info(
'Iteration %d, lr = %s, loss = %f, time = %.2fs'
% (self.iter, str(self.base_lr), smoothed_loss, time.time() - tic))
tic = time.time()
for idx, net_output in enumerate(self.net.outputs):
values = self.net.blobs[net_output].data.get_value().flatten()
for v in values:
print(' ' * 10 + 'Train net output #{}({}): {}'
.format(idx, net_output, v))
logging.info(
' ' * 10 + 'Train net output #{}({}): {}'
.format(idx, net_output, v))
self.iter = self.iter + 1
# Snapshot if necessary.
if self._param.snapshot:
if self.iter % self._param.snapshot == 0:
......@@ -359,7 +295,7 @@ class Solver(object):
test_iter = self._param.test_iter[test_idx]
for iter in range(test_iter):
net.forward_backward(return_outputs=False)
net.forward()
if not self._is_root:
continue
if iter == 0:
......@@ -376,27 +312,25 @@ class Solver(object):
test_score[i] += value
i += 1
if not self._is_root:
return
print('Iteration {}, Test net #{}'.format(self.iter, test_idx))
logging.info('Iteration {}, Test net #{}'.format(self.iter, test_idx))
for i, score in enumerate(test_score):
print(' ' * 10 + 'Test net output #%d(%s): %.4f'
% (i, output_id[i], score / test_iter))
logging.info(
' ' * 10 + 'Test net output #%d(%s): %.4f'
% (i, output_id[i], score / test_iter))
class AdamSolver(Solver):
r"""The Adam solver.
`[Kingma & Ba, 2014] <https://arxiv.org/abs/1412.6980>`_.
Following hyper parameters will be taken:
Examples:
```python
caffe_pb2.SolverParameter(
base_lr=0.,
momentum=0.,
momentum2=0.999,
delta=1e-8,
solver {
base_lr=0.001,
momentum=0.9,
momentum2=0.999,
delta=1e-8,
)
```
......@@ -425,13 +359,13 @@ class NesterovSolver(Solver):
r"""The Nesterov-SGD solver.
`[Sutskever et.al, 2013] <http://www.cs.toronto.edu/~hinton/absps/momentum.pdf>`_.
Following hyper parameters will be taken:
Examples:
```python
caffe_pb2.SolverParameter(
base_lr=0.,
momentum=0.,
)
solver {
base_lr: 0.01
momentum: 0.9
}
```
"""
......@@ -457,13 +391,13 @@ class RMSPropSolver(Solver):
r"""The RMSProp solver.
`[Hinton et.al, 2013] <http://www.cs.utoronto.ca/~bonner/courses/2016s/csc321/lectures/lec6.pdf>`_.
Following hyper parameters will be taken:
Examples:
```python
caffe_pb2.SolverParameter(
base_lr=0.,
rms_decay=0.99,
delta=1e-8,
solver {
base_lr=0.01,
rms_decay=0.99,
delta=1e-8,
)
```
......@@ -491,12 +425,12 @@ class SGDSolver(Solver):
r"""The Momentum-SGD solver.
`[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_.
Following hyper parameters will be taken:
Examples:
```python
caffe_pb2.SolverParameter(
base_lr=0.,
momentum=0.,
solver {
base_lr=0.01,
momentum=0.9,
)
```
......
......@@ -3,9 +3,9 @@ Building Dragon Documentation
This page will help you to build the following documentations:
Dragon C++ API: http://dragon.seetatech.com/api/cc
Dragon C++ API: https://dragon.seetatech.com/api/cc
Dragon Python API: http://dragon.seetatech.com/api/python
Dragon Python API: https://dragon.seetatech.com/api/python
Build Documentation of C++ API
------------------------------
......
......@@ -34,10 +34,6 @@ vm.caffe.layers
`class Deconvolution <layers/Deconvolution.html>`_
: Apply the n-dimension deconvolution.
`class DepthwiseConv2d <layers/DepthwiseConv2d.html>`_
: Apply the 2d depthwise convolution.
`[Chollet, 2016] <https://arxiv.org/abs/1610.02357>`_.
`class Dropout <layers/Dropout.html>`_
: Set the elements of the input to zero randomly.
`[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
......@@ -58,18 +54,6 @@ vm.caffe.layers
`class Flatten <layers/Flatten.html>`_
: Flatten the input along the given axes.
`class FusedBatchNorm <layers/FusedBatchNorm.html>`_
: Apply the fused batch normalization.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`class FusedGroupNorm <layers/FusedBatchNorm.html>`_
: Apply the fused group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
`class GroupNorm <layers/FusedBatchNorm.html>`_
: Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
`class InnerProduct <layers/InnerProduct.html>`_
: Compute the dense matrix multiplication along the given axes.
......@@ -121,10 +105,6 @@ vm.caffe.layers
`class Scale <layers/Scale.html>`_
: Compute the affine transformation along the given axes.
`class SELU <layers/SELU.html>`_
: Apply the scaled exponential linear unit.
`[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`class Sigmoid <layers/Sigmoid.html>`_
: Apply the sigmoid function.
......@@ -145,7 +125,7 @@ vm.caffe.layers
: Apply the tanh function.
`class Tile <layers/Tile.html>`_
: Tile the input according to the given multiples.
: Repeat the input according to the given axis.
.. toctree::
:hidden:
......@@ -153,21 +133,16 @@ vm.caffe.layers
layers/Accuracy
layers/ArgMax
layers/BatchNorm
layers/Cast
layers/Concat
layers/Convolution
layers/Crop
layers/Data
layers/Deconvolution
layers/DepthwiseConv2d
layers/Dropout
layers/Eltwise
layers/ELU
layers/EuclideanLoss
layers/Flatten
layers/FusedBatchNorm
layers/FusedGroupNorm
layers/GroupNorm
layers/InnerProduct
layers/Input
layers/LRN
......@@ -183,7 +158,6 @@ vm.caffe.layers
layers/ROIAlign
layers/ROIPooling
layers/Scale
layers/SELU
layers/Sigmoid
layers/SigmoidCrossEntropyLoss
layers/SmoothL1Loss
......
DepthwiseConv2d
===============
.. autoclass:: dragon.vm.caffe.layers.DepthwiseConv2d
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
FusedBatchNorm
==============
.. autoclass:: dragon.vm.caffe.layers.FusedBatchNorm
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
FusedGroupNorm
==============
.. autoclass:: dragon.vm.caffe.layers.FusedGroupNorm
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
GroupNorm
=========
.. autoclass:: dragon.vm.caffe.layers.GroupNorm
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
SELU
====
.. autoclass:: dragon.vm.caffe.layers.SELU
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
color: #103d3e;
}
</style>
......@@ -18,8 +18,8 @@ dragon
`class TensorSpec <dragon/TensorSpec.html>`_
: Spec to describe properties of a tensor.
`class Workspace <dragon/Workspace_.html>`_
: Space to isolate computations that share resources.
`class Workspace <dragon/Workspace.html>`_
: Sandbox to isolate the resources and computations.
Functions
---------
......@@ -151,7 +151,7 @@ dragon
: Return the identity of input with truncated gradient-flow.
`tile(...) <dragon/tile.html>`_
: Tile the input according to the given multiples.
: Tile the input according to the given repeats.
`transpose(...) <dragon/transpose.html>`_
: Permute the dimensions of input.
......@@ -217,7 +217,7 @@ dragon
dragon/tile
dragon/transpose
dragon/where
dragon/Workspace_
dragon/Workspace
dragon/zeros
dragon/zeros_like
......
......@@ -14,10 +14,6 @@ gradient
########
.. automethod:: dragon.GradientTape.gradient
replay
######
.. automethod:: dragon.GradientTape.replay
reset
#####
.. automethod:: dragon.GradientTape.reset
......
......@@ -30,6 +30,10 @@ shape
#####
.. autoattribute:: dragon.Tensor.shape
size
#####
.. autoattribute:: dragon.Tensor.size
Methods
-------
......
Workspace
=========
.. autoclass:: dragon.Workspace
__init__
--------
.. automethod:: dragon.Workspace.__init__
Methods
-------
as_default
##########
.. automethod:: dragon.Workspace.as_default
clear
#####
.. automethod:: dragon.Workspace.clear
merge_from
##########
.. automethod:: dragon.Workspace.merge_from
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
dragon.workspace
================
Workspace
=========
.. only:: html
.. autoclass:: dragon.Workspace
Functions
---------
__init__
--------
.. automethod:: dragon.Workspace.__init__
`feed_tensor(...) <workspace/feed_tensor.html>`_
: Copy the value to tensor.
Methods
-------
`fetch_tensor(...) <workspace/fetch_tensor.html>`_
: Return the value of tensor.
as_default
##########
.. automethod:: dragon.Workspace.as_default
`has_tensor(...) <workspace/has_tensor.html>`_
: Return a bool indicating if tensor is in current workspace.
feed_tensor
###########
.. automethod:: dragon.Workspace.feed_tensor
`load(...) <workspace/load.html>`_
: Load tensors from a binary file.
fetch_tensor
############
.. automethod:: dragon.Workspace.fetch_tensor
`reset_tensor(...) <workspace/reset_tensor.html>`_
: Reset the memory of tensor.
has_tensor
##########
.. automethod:: dragon.Workspace.has_tensor
`run_operator(...) <workspace/run_operator.html>`_
: Run the operators in current workspace.
merge_from
##########
.. automethod:: dragon.Workspace.merge_from
`save(...) <workspace/save.html>`_
: Serialize tensors into a binary file.
.. toctree::
:hidden:
workspace/feed_tensor
workspace/fetch_tensor
workspace/has_tensor
workspace/load
workspace/reset_tensor
workspace/run_operator
workspace/save
reset_tensor
############
.. automethod:: dragon.Workspace.reset_tensor
.. raw:: html
<style>
h1:before {
content: "Module: ";
color: #103d3e;
}
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
feed_tensor
===========
.. autofunction:: dragon.workspace.feed_tensor
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
fetch_tensor
============
.. autofunction:: dragon.workspace.fetch_tensor
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
has_tensor
==========
.. autofunction:: dragon.workspace.has_tensor
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
load
====
.. autofunction:: dragon.workspace.load
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
reset_tensor
============
.. autofunction:: dragon.workspace.reset_tensor
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
run_operator
============
.. autofunction:: dragon.workspace.run_operator
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
save
====
.. autofunction:: dragon.workspace.save
.. raw:: html
<style>
h1:before {
content: "dragon.workspace.";
color: #103d3e;
}
</style>
......@@ -40,7 +40,6 @@ Dragon
* `dragon.nn <dragon/nn.html>`_
* `dragon.optimizers <dragon/optimizers.html>`_
* `dragon.random <dragon/random.html>`_
* `dragon.workspace <dragon/workspace.html>`_
* `dragon.vision <dragon/vision.html>`_
Caffe
......@@ -112,6 +111,7 @@ PyTorch
This style involves the following components:
* `torch <torch.html>`_
* `torch.autograd <torch/autograd.html>`_
* `torch.distributed <torch/distributed.html>`_
* `torch.jit <torch/jit.html>`_
* `torch.nn <torch/nn.html>`_
......@@ -206,15 +206,9 @@ Modules
`Module random <dragon/random.html>`_
: Native API for ``dragon.random`` namespace.
`Module workspace <dragon/workspace.html>`_
: Native API for ``dragon.workspace`` namespace.
`Module vision <dragon/vision.html>`_
: Native API for ``dragon.vision`` namespace.
`Module workspace <dragon/workspace.html>`_
: Native API for ``dragon.workspace`` namespace.
`Module vm.caffe <caffe.html>`_
: Virtual API for ``caffe`` namespace.
......@@ -278,6 +272,9 @@ Modules
`Module vm.torch <torch.html>`_
: Virtual API for ``torch`` namespace.
`Module vm.torch.autograd <torch/autograd.html>`_
: Virtual API for ``torch.autograd`` namespace.
`Module vm.torch.distributed <torch/distributed.html>`_
: Virtual API for ``torch.distributed`` namespace.
......@@ -319,7 +316,6 @@ Modules
dragon/nn
dragon/optimizers
dragon/random
dragon/workspace
dragon/vision
caffe
caffe/layers
......@@ -343,6 +339,7 @@ Modules
tensorrt
tensorrt/backend
torch
torch/autograd
torch/distributed
torch/jit
torch/nn
......
......@@ -15,11 +15,6 @@ gradient
.. automethod:: dragon.GradientTape.gradient
:noindex:
replay
######
.. automethod:: dragon.GradientTape.replay
:noindex:
reset
#####
.. automethod:: dragon.GradientTape.reset
......
vm.torch.autograd
==================
.. only:: html
Functions
---------
`backward(...) <autograd/backward.html>`_
: Compute the derivatives of tensors w.r.t. graph leaves.
.. toctree::
:hidden:
autograd/backward
.. raw:: html
<style>
h1:before {
content: "Module: dragon.";
color: #103d3e;
}
</style>
Cast
====
backward
========
.. autoclass:: dragon.vm.caffe.layers.Cast
.. autofunction:: dragon.vm.torch.autograd.backward
.. raw:: html
<style>
h1:before {
content: "caffe.layers.";
content: "torch.autograd.";
color: #103d3e;
}
</style>
......@@ -7,7 +7,7 @@ all_reduce
<style>
h1:before {
content: "torch.nn.distributed.";
content: "torch.distributed.";
color: #103d3e;
}
</style>
......@@ -7,7 +7,7 @@ broadcast
<style>
h1:before {
content: "torch.nn.distributed.";
content: "torch.distributed.";
color: #103d3e;
}
</style>
......@@ -7,7 +7,7 @@ trace
<style>
h1:before {
content: "torch.nn.jit.";
content: "torch.jit.";
color: #103d3e;
}
</style>
......@@ -10,25 +10,25 @@ __init__
Methods
-------
accumulate_grad
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate_grad
:noindex:
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex:
add_param_group
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
:noindex:
:noindex:
step
####
.. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex:
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
:noindex:
:noindex:
.. raw:: html
......
......@@ -10,9 +10,9 @@ __init__
Methods
-------
accumulate_grad
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate_grad
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
add_param_group
###############
......
......@@ -10,25 +10,25 @@ __init__
Methods
-------
accumulate_grad
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate_grad
:noindex:
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex:
add_param_group
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
:noindex:
:noindex:
step
####
.. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex:
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
:noindex:
:noindex:
.. raw:: html
......
......@@ -10,25 +10,25 @@ __init__
Methods
-------
accumulate_grad
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate_grad
:noindex:
accumulate
##########
.. automethod:: dragon.vm.torch.optim.Optimizer.accumulate
:noindex:
add_param_group
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
:noindex:
:noindex:
step
####
.. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex:
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
:noindex:
:noindex:
.. raw:: html
......
......@@ -7,7 +7,7 @@ namespace dragon {
GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
: def_(def), ws_(ws), name_(def.name()), phase_("TEST") {
// Scan the defined arguments
// Collect arguments
for (auto& arg : def_.arg()) {
CHECK_GT(arg.name().size(), 0);
CHECK_EQ(args_.count(arg.name()), 0);
......@@ -18,32 +18,31 @@ GraphBase::GraphBase(const GraphDef& def, Workspace* ws)
// Collect outputs
Set<string> outputs;
for (const auto& op : def.op()) {
for (const auto& in : op.input())
CHECK(outputs.count(in) || ws_->HasTensor(in))
<< "\nInput: " << in << " for op: " << op.name() << " is unknown.";
for (const auto& out : op.output())
outputs.insert(out);
for (const auto& input : op.input())
CHECK(outputs.count(input) || ws_->HasTensor(input))
<< "\nThe input <" << input << "> is not in graph.";
for (const auto& output : op.output()) {
outputs.insert(output);
}
}
// Check targets
Set<string> targets;
for (const auto& target : def.output()) {
CHECK(outputs.count(target) || ws_->HasTensor(target))
<< "\nTarget: " << target << " does not exist in the graph.";
<< "\nThe output <" << target << "> is not in graph.";
targets.insert(target);
}
// Check gradients
for (const auto& gradient : def.gradient()) {
const auto& cost = gradient.cost();
const auto& wrt = gradient.wrt();
CHECK(outputs.count(cost) || ws_->HasTensor(cost))
<< "\nTarget: " << cost << "does not exist in the graph.";
CHECK(outputs.count(wrt) || ws_->HasTensor(wrt))
<< "\nTarget: " << wrt << "does not exist in the graph.";
CHECK_GT(targets.count(cost), 0)
<< "\nTo solve d(" << cost << ")/d(" << wrt << "),\n"
<< cost << " should be set as a target.";
for (const auto& grad_info : def.grad_info()) {
const auto& y = grad_info.y();
CHECK_GT(targets.count(y), 0)
<< "\nThe derivative target <" << y << "> is not in outputs.";
for (const auto& x : grad_info.xs()) {
CHECK(outputs.count(x) || ws_->HasTensor(x))
<< "\nThe differentiated input <" << x << "> is not in graph.";
}
}
}
......@@ -54,21 +53,18 @@ bool Graph::Create(const GraphDef& def, Workspace* ws) {
auto op_def(def.op(i));
LOG(DEBUG) << "Create Operator " << op_def.name() << ": " << op_def.type();
// Inherit device option if necessary
if (!op_def.has_device_option() && has_device_option)
if (!op_def.has_device_option() && has_device_option) {
op_def.mutable_device_option()->CopyFrom(def.device_option());
}
Argument arg;
arg.set_name("allow_recomp");
arg.set_i(1);
op_def.add_arg()->CopyFrom(arg);
// For the last operator, enforce the synchronization
if (i == def.op_size() - 1) {
arg.set_name("do_sync");
arg.set_i(1);
op_def.add_arg()->CopyFrom(arg);
}
ops_.push_back(NewOperator(op_def, ws));
// Attatch the output aliases info
ops_.back()->set_output_aliases(output_aliases_);
cached_ops_.push_back(NewOperator(op_def, ws));
cached_ops_.back()->set_output_aliases(output_aliases_);
}
return true;
}
......@@ -80,7 +76,7 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
GraphGradientMaker gradient_maker;
Map<string, vec32_t> subgraph_indices;
int opt = 3; // defaults: O3
if (args().count("optimization_level")) opt = arg("optimization_level").i();
if (args().count("optimization")) opt = arg("optimization").i();
if (opt >= 1) opt_def = graph_optim.PruneNodes(def);
if (opt >= 2) graph_optim.AddInplace(opt_def, output_aliases_);
if (opt >= 3) {
......@@ -101,22 +97,23 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
for (const auto& it : subgraph_indices) {
subgraph[it.first] = vector<OperatorBase*>();
for (const auto& idx : subgraph_indices[it.first])
subgraph[it.first].push_back(ops_[idx]);
subgraph[it.first].push_back(cached_ops_[idx]);
}
for (const auto& op : ops_)
for (auto* op : cached_ops_) {
op->set_subgraph(subgraph);
}
}
}
bool Graph::Run(const string& incl, const string& excl, int stream_id) {
bool Graph::Run(const string& include, const string& exclude, int stream) {
LOG(DEBUG) << "Run Graph: " << name();
for (auto op : ops_) {
if (!incl.empty() && !str::find(op->type(), incl)) continue;
if (!excl.empty() && str::find(op->type(), excl)) continue;
for (auto* op : cached_ops_) {
if (!include.empty() && !str::find(op->type(), include)) continue;
if (!exclude.empty() && str::find(op->type(), exclude)) continue;
op->SwitchToPhase(phase());
LOG(DEBUG) << "$ Before Operator: " << op->name();
op->Run(stream_id);
LOG(DEBUG) << "$ After Operator: " << op->name();
LOG(DEBUG) << "Run Op: " << op->name();
op->Run(stream);
LOG(DEBUG) << "Finish Op: " << op->name();
}
return true;
}
......
......@@ -88,8 +88,8 @@ class Graph : public GraphBase {
/*! \brief Default Destructor */
virtual ~Graph() {
for (auto* op : ops_) {
delete op;
for (auto* cached_op : cached_ops_) {
delete cached_op;
}
}
......@@ -100,8 +100,8 @@ class Graph : public GraphBase {
bool Run(const string&, const string&, int = 0) override;
protected:
/*! \brief Store the internal operators */
vector<OperatorBase*> ops_;
/*! \brief The cached operators */
vector<OperatorBase*> cached_ops_;
/*! \brief Store the candidate output aliases */
Map<string, Set<string>> output_aliases_;
......
......@@ -4,23 +4,24 @@
namespace dragon {
bool GraphGradientMaker::CheckGrad(
const OperatorDef& forward_op,
const OperatorDef& op_def,
const Set<string>& targets,
vector<pair<string, int>>& gen_grads) {
if (NoGradientRegistry()->Has(forward_op.type())) {
for (auto& input : forward_op.input())
if (NoGradientRegistry()->Has(op_def.type())) {
for (auto& input : op_def.input()) {
blacklist_set_.insert(input);
}
return true;
}
for (int i = 0; i < forward_op.output_size(); ++i) {
const auto& output = forward_op.output(i);
for (int i = 0; i < op_def.output_size(); ++i) {
const auto& output = op_def.output(i);
if (!inputs_to_grads_.count(output)) {
if (blacklist_set_.count(output)) return true;
if (targets.count(output)) {
// Consider to generate virtual gradient for targets
gen_grads.push_back({output, i});
inputs_to_grads_[output] = output + "_grad";
} else if (forward_op.output_size() == 1) {
} else if (op_def.output_size() == 1) {
return true; // We can skip this op, obviously
}
}
......@@ -30,7 +31,7 @@ bool GraphGradientMaker::CheckGrad(
}
void GraphGradientMaker::Make(
const vector<OperatorDef*>& forward_ops,
const vector<OperatorDef*>& op_defs,
const vector<string>& targets,
const vector<string>& input_grads,
GraphDef& backward_def) {
......@@ -39,11 +40,11 @@ void GraphGradientMaker::Make(
Map<string, string> targets_to_grads;
// PLAY for the forward
for (auto* op : forward_ops) {
if (NoGradientRegistry()->Has(op->type())) continue;
for (const auto& input : op->input()) {
for (auto* op_def : op_defs) {
if (NoGradientRegistry()->Has(op_def->type())) continue;
for (const auto& input : op_def->input()) {
bool input_in_outputs = false;
for (auto& output : op->output())
for (auto& output : op_def->output())
if (output == input) {
input_in_outputs = true;
break;
......@@ -62,9 +63,9 @@ void GraphGradientMaker::Make(
targets_set.insert(targets[i]);
}
for (int op_idx = (int)forward_ops.size() - 1; op_idx >= 0; --op_idx) {
for (int op_idx = (int)op_defs.size() - 1; op_idx >= 0; --op_idx) {
// Collect inputs and outputs, generate raw gradient ops
const OperatorDef& op = *forward_ops[op_idx];
const OperatorDef& op = *op_defs[op_idx];
vector<pair<string, int>> gen_grads;
bool is_skip = CheckGrad(op, targets_set, gen_grads);
vector<string> g_outputs;
......@@ -183,9 +184,9 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
// Flag the gathering gradients
if (op.type() == "GradientGather") {
invalid_ops.insert(op_idx);
if (ignored_grads_.count(op.output(0))) {
if (empty_grads_.count(op.output(0))) {
for (const auto& input : op.input()) {
ignored_grads_.insert(input);
empty_grads_.insert(input);
}
continue;
} else {
......@@ -200,7 +201,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
}
// Count the references to detect leafs
for (const auto& input : op.input()) {
if (str::find(input, "grad")) {
if (str::endswith(input, "_grad")) {
ref_count[input] += 1;
}
}
......@@ -293,21 +294,17 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
// Rewrite output gradients
for (int i = 0; i < op->output_size(); ++i) {
if (str::startswith(op->type(), "Python")) continue;
const string& output = op->output(i);
if (output.empty() || str::startswith(output, "/share/")) continue;
if (ignored_grads_.count(output) > 0) {
// Prune for non-trainable leafs
if (output.empty() || str::startswith(output, "/share/buffer")) continue;
if (empty_grads_.count(output) > 0) {
*op->mutable_output(i) = "";
continue;
}
if (hooked_grads_.empty()) {
// Protection for leafs
if (ref_count.count(output) == 0) continue;
} else {
// Protection for sources
if (hooked_grads_.count(output) > 0) continue;
}
if (op->type() == "PythonPluginGradient") continue;
// Protection for leafs
if (ref_count.count(output) == 0) continue;
// Protection for sources and leafs
if (retained_grads_.count(output) > 0) continue;
string new_output = output;
if (inplace_flags[i] >= 0) {
new_output = op->input(inplace_flags[i]);
......
......@@ -21,22 +21,22 @@ class DRAGON_API GraphGradientMaker {
public:
/*! \brief Generate a backward graph from the forward ops */
void Make(
const vector<OperatorDef*>& forward_ops,
const vector<OperatorDef*>& op_defs,
const vector<string>& targets,
const vector<string>& input_grads,
GraphDef& backward_def);
GraphDef& graph_def);
/*! \brief Rewrite a graph to share the intermediate grads */
GraphDef Share(const GraphDef& input_def);
/*! \brief Add a hooked gradient */
void add_hooked_grad(const string& name) {
hooked_grads_.insert(name);
/*! \brief Add an empty gradient */
void add_empty_grad(const string& name) {
empty_grads_.insert(name);
}
/*! \brief Add an ignored gradient */
void add_ignored_grad(const string& name) {
ignored_grads_.insert(name);
/*! \brief Add a retained gradient */
void add_retained_grad(const string& name) {
retained_grads_.insert(name);
}
/*! \brief Set the prefix of backward op name */
......@@ -47,32 +47,32 @@ class DRAGON_API GraphGradientMaker {
private:
/*! \brief Check the missing grads of backward procedure */
bool CheckGrad(
const OperatorDef& forward_op,
const OperatorDef& op_def,
const Set<string>& targets,
vector<pair<string, int>>& gen_grads);
/*! \brief Return a dummy operator name */
string GetOperatorName() {
if (op_prefix_.empty()) return "Generic";
if (op_prefix_.empty()) return "GradientOp";
return op_prefix_ + str::to(op_index_++);
}
/*! \brief Store the mapping of intermediate grads */
/*! \brief The mapping from input to grad */
Map<string, string> inputs_to_grads_;
/*! \brief Store the non-gradient outputs */
/*! \brief The non-gradient outputs */
Set<string> blacklist_set_;
/*! \brief Store the non-shared gradients */
Set<string> hooked_grads_;
/*! \brief The gradients should be retained */
Set<string> retained_grads_;
/*! \brief Store the gradients that are not required */
Set<string> ignored_grads_;
/*! \brief The gradients should be set to empty */
Set<string> empty_grads_;
/*! \brief Store the prefix of dummy operator name */
/*! \brief The prefix of op name */
string op_prefix_;
/*! \brief Store the counter of dummy operator name */
/*! \brief The counter of op name */
int64_t op_index_ = 0;
};
......
......@@ -39,14 +39,12 @@ GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) {
BackwardPrunePass(target);
}
// Forward pass from gradients
for (const auto& gradient : input_def.gradient()) {
auto u = gradient.cost() + "_grad";
auto v = gradient.wrt() + "_grad";
if (ws_->HasTensor(u)) u = ws_->GetTensor(u)->name();
if (ws_->HasTensor(v)) v = ws_->GetTensor(v)->name();
visited_.clear();
ForwardPrunePass(u, v, vector<string>({u}));
for (const auto& grad_info : input_def.grad_info()) {
const auto u = grad_info.y() + "_grad";
for (const auto& x : grad_info.xs()) {
visited_.clear();
ForwardPrunePass(u, x + "_grad", std::deque<string>({u}));
}
}
// Select all colored operators
......@@ -64,7 +62,6 @@ GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) {
// Generate the final op sequence
map<int, OperatorDef> final_sequence;
for (auto op_idx : selected_op_indices) {
const auto& op = input_def.op(op_idx);
auto new_op(input_def.op(op_idx));
......@@ -308,11 +305,13 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) {
void GraphOptimizer::ForwardPrunePass(
const string& u,
const string& leaf,
const vector<string>& path) {
const std::deque<string>& path) {
if (visited_.count(u)) {
if (visited_[u])
for (const auto& node : path)
if (visited_[u]) {
for (const auto& node : path) {
visited_[node] = colored_[node] = true;
}
}
return;
}
visited_[u] = false;
......@@ -321,8 +320,9 @@ void GraphOptimizer::ForwardPrunePass(
auto new_path(path);
new_path.push_back(v);
if (v == leaf) {
for (const auto& node : new_path)
for (const auto& node : new_path) {
visited_[node] = colored_[node] = true;
}
return;
}
ForwardPrunePass(v, leaf, new_path);
......
......@@ -56,7 +56,7 @@ class GraphOptimizer {
void ForwardPrunePass(
const string& u,
const string& leaf,
const vector<string>& path);
const std::deque<string>& path);
/*! \brief Pass from targets to remove unused nodes */
void BackwardPrunePass(const string& v);
......
......@@ -41,14 +41,11 @@ OperatorBase::OperatorBase(const OperatorDef& def, Workspace* ws)
}
}
template <class Context>
Operator<Context>::Operator(const OperatorDef& def, Workspace* ws)
: OperatorBase(def, ws),
ctx_(def.device_option()),
do_sync_(OpArg<bool>("do_sync", false)),
allow_recomp_(OpArg<bool>("allow_recomp", false)) {
allow_run_ = (!(OutputSize() == 1 && !Output(0)->has_name()));
}
// template <class Context>
// Operator<Context>::Operator(const OperatorDef& def, Workspace* ws)
// : OperatorBase(def, ws),
// ctx_(def.device_option()),
// do_sync_(OpArg<bool>("do_sync", false)) {}
Tensor& OperatorBase::Input(int i) {
CHECK_LT(i, (int)inputs_.size());
......@@ -112,32 +109,32 @@ OperatorBase* OperatorBase::UpdateFrom(const OperatorDef& def) {
handle_ = def.name();
inputs_.resize(def.input_size());
outputs_.resize(def.output_size());
for (int i = 0; i < inputs_.size(); i++)
for (int i = 0; i < inputs_.size(); i++) {
inputs_[i] = ws()->GetTensor(def.input(i));
for (int i = 0; i < outputs_.size(); i++)
}
for (int i = 0; i < outputs_.size(); i++) {
outputs_[i] = ws()->CreateTensor(def.output(i));
}
return this;
}
template <class Context>
void Operator<Context>::Prepare() {
string tensor_name;
size_t ver_pos;
int version;
for (int i = 0; i < InputSize(); i++) {
if (Input(i).version() >= 0) {
tensor_name = def().input(i);
ver_pos = tensor_name.find("/ver:");
version = std::atoi(tensor_name.substr(ver_pos + 5).c_str());
const auto& name = def().input(i);
auto ver_pos = name.find("/ver:");
auto version = std::atoi(name.substr(ver_pos + 5).c_str());
if (version == Input(i).version()) continue;
LOG(DEBUG) << "Excepted version of Tensor(" + Input(i).name() + ") "
<< "is " << version << ", got " << Input(i).version()
<< ". Recompute.";
Tensor* flag = ws()->GetTensor("/share/flag/recomputing");
flag->mutable_data<bool, CPUContext>()[0] = true;
vector<OperatorBase*>& chain = subgraph()[tensor_name];
for (auto* op : chain)
vector<OperatorBase*>& chain = subgraph()[name];
for (auto* op : chain) {
op->Run(ctx()->stream_id());
}
flag->mutable_data<bool, CPUContext>()[0] = false;
}
}
......@@ -145,14 +142,11 @@ void Operator<Context>::Prepare() {
template <class Context>
void Operator<Context>::Release() {
string tensor_name;
size_t ver_pos;
int version;
for (int i = 0; i < OutputSize(); i++) {
if (Output(i)->version() >= 0) {
tensor_name = def().output(i);
ver_pos = tensor_name.find("/ver:");
version = std::atoi(tensor_name.substr(ver_pos + 5).c_str());
const auto& name = def().output(i);
auto ver_pos = name.find("/ver:");
auto version = std::atoi(name.substr(ver_pos + 5).c_str());
Output(i)->set_version(version);
}
}
......@@ -195,8 +189,7 @@ TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) {
OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws) {
auto* schema = OpSchemaRegistry::Schema(def.type());
if (schema) {
// Check the Inputs and Outputs if necessary
if (schema != nullptr) {
CHECK(schema->Verify(def))
<< "\nOperator failed to pass the schema checking.";
}
......@@ -219,7 +212,7 @@ Gradient MakeGradientForOp(
<< "not implemented.";
Gradient grad = maker->Make();
OperatorDef reference_def(def);
// Map the cache key
// Set the cache key
if (reference_def.has_cache_key()) {
for (int i = 0; i < grad.ops.size(); ++i) {
grad.ops[i].set_cache_key(
......
......@@ -40,7 +40,7 @@ class DRAGON_API OperatorBase {
}
/*! \brief Run operator on the specified stream */
virtual void Run(int stream_id = 0) {
virtual void Run(int stream = 0) {
NOT_IMPLEMENTED;
}
......@@ -154,12 +154,12 @@ class DRAGON_API OperatorBase {
}
/*! \brief Set the output aliases for in-place */
void set_output_aliases(const Map<string, Set<string>>& aliases_map) {
void set_output_aliases(const Map<string, Set<string>>& alias_map) {
output_aliases_.resize(outputs_.size());
for (int i = 0; i < outputs_.size(); ++i) {
auto aliases_iter = aliases_map.find(outputs_[i]->name());
if (aliases_iter != aliases_map.end()) {
output_aliases_[i] = aliases_iter->second;
const auto& it = alias_map.find(outputs_[i]->name());
if (it != alias_map.end()) {
output_aliases_[i] = it->second;
} else {
output_aliases_[i].clear();
}
......@@ -196,7 +196,10 @@ template <class Context>
class DRAGON_API Operator : public OperatorBase {
public:
/*! \brief Default constructor */
Operator(const OperatorDef& def, Workspace* ws);
Operator(const OperatorDef& def, Workspace* ws)
: OperatorBase(def, ws),
ctx_(def.device_option()),
do_sync_(OperatorBase::Arg<bool>("do_sync", false)) {}
/*! \brief Prepare the content of inputs */
virtual void Prepare();
......@@ -207,36 +210,32 @@ class DRAGON_API Operator : public OperatorBase {
/*! \brief Coordinate the context of inputs and outputs */
virtual void SwitchToDevice();
/*! \brief Implement the detailed execution */
/*! \brief The detailed execution on device */
virtual void RunOnDevice() = 0;
/*! \brief Run this operator on the specified stream */
void Run(int stream_id = 0) final {
if (!allow_run_) return;
if (allow_recomp_) Prepare();
ctx()->SwitchToDevice(stream_id);
/*! \brief Run this operator */
void Run(int stream = 0) final {
Prepare();
ctx()->SwitchToDevice(stream);
SwitchToDevice();
RunOnDevice();
if (do_sync_ || stream_id > 0) {
if (do_sync_ || stream > 0) {
ctx()->FinishDeviceComputation();
}
if (allow_recomp_) Release();
Release();
}
/*! \brief Return a bool indicating the run is available */
bool allow_run() const {
return allow_run_;
}
/*! \brief Return the internal context */
/*! \brief Return the context */
Context* ctx() {
return &ctx_;
}
protected:
/*! \brief Store the internal context */
/*! \brief The context */
Context ctx_;
bool do_sync_, allow_run_, allow_recomp_;
/*! \brief The executing flags */
bool do_sync_;
};
/*! \brief New a operator from the raw def */
......@@ -266,9 +265,8 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
using OperatorBase::def; \
using OperatorBase::ws
#define USE_OPERATOR_FUNCTIONS \
USE_OPERATOR_BASE_FUNCTIONS; \
using Operator<Context>::allow_run; \
#define USE_OPERATOR_FUNCTIONS \
USE_OPERATOR_BASE_FUNCTIONS; \
using Operator<Context>::ctx
#define STORE_INPUT_SPEC(i) \
......@@ -342,46 +340,46 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType);
/* Fillers */
#define TENSOR_FILL_WITH_TYPE(tensor, shape, type) \
if (tensor.count() == 0) { \
CHECK(ws()->GetFiller(tensor.name())) \
<< "\nTensor(" << tensor.name() << ") is empty. \n" \
<< "may be specify a filler for it?"; \
tensor.Reshape(shape); \
unique_ptr<Filler<type, Context>> filler( \
CreateFiller<type, Context>(*ws()->GetFiller(tensor.name()))); \
filler->Fill(&tensor, ctx()); \
} else { \
int64_t count = 1; \
for (int i = 0; i < shape.size(); i++) \
count *= shape[i]; \
CHECK_EQ(count, tensor.count()) \
<< "\nExcepted Tensor(" << tensor.name() << ")'s " \
<< "size is " << count << ", \n" \
<< "but now is " << tensor.count() << ", " \
<< "did you feed the incorrect data before?"; \
tensor.Reshape(shape); \
#define TENSOR_FILL_WITH_TYPE(tensor, shape, type) \
if (tensor.count() == 0) { \
auto* filler_info = ws()->GetFillerInfo(tensor.name()); \
CHECK(filler_info) << "\nTensor(" << tensor.name() << ") is empty.\n" \
<< "May be specify a filler for it?"; \
tensor.Reshape(shape); \
unique_ptr<Filler<type, Context>> filler( \
CreateFiller<type, Context>(*filler_info)); \
filler->Fill(&tensor, ctx()); \
} else { \
int64_t count = 1; \
for (int i = 0; i < shape.size(); i++) \
count *= shape[i]; \
CHECK_EQ(count, tensor.count()) \
<< "\nExcepted Tensor(" << tensor.name() << ")'s " \
<< "size is " << count << ", \n" \
<< "but now is " << tensor.count() << ", " \
<< "did you feed the incorrect data before?"; \
tensor.Reshape(shape); \
}
#define TENSOR_FILL(tensor, shape) \
if (tensor.count() == 0) { \
CHECK(ws()->GetFiller(tensor.name())) \
<< "\nTensor(" << tensor.name() << ") is empty. \n" \
<< "Maybe specify a filler for it?"; \
tensor.Reshape(shape); \
unique_ptr<Filler<T, Context>> filler( \
CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \
filler->Fill(&tensor, ctx()); \
} else { \
int64_t count = 1; \
for (int i = 0; i < shape.size(); i++) \
count *= shape[i]; \
CHECK_EQ(count, tensor.count()) \
<< "\nExcepted Tensor(" << tensor.name() << ")'s " \
<< "size is " << count << ", \n" \
<< "but now is " << tensor.count() << ", " \
<< "did you feed the incorrect data before?"; \
tensor.Reshape(shape); \
#define TENSOR_FILL(tensor, shape) \
if (tensor.count() == 0) { \
auto* filler_info = ws()->GetFillerInfo(tensor.name()); \
CHECK(filler_info) << "\nTensor(" << tensor.name() << ") is empty.\n" \
<< "May be specify a filler for it?"; \
tensor.Reshape(shape); \
unique_ptr<Filler<T, Context>> filler( \
CreateFiller<T, Context>(*filler_info)); \
filler->Fill(&tensor, ctx()); \
} else { \
int64_t count = 1; \
for (int i = 0; i < shape.size(); i++) \
count *= shape[i]; \
CHECK_EQ(count, tensor.count()) \
<< "\nExcepted Tensor(" << tensor.name() << ")'s " \
<< "size is " << count << ", \n" \
<< "but now is " << tensor.count() << ", " \
<< "did you feed the incorrect data before?"; \
tensor.Reshape(shape); \
}
/* Arguments */
......
......@@ -4,176 +4,150 @@
namespace dragon {
vector<string> Workspace::tensors() const {
vector<string> locals;
// Search the local workspace
for (const auto& it : tensor_map_)
locals.push_back(it.first);
// Search the remote workspaces
for (const auto& it : external_tensor_map_) {
locals.push_back(it.first);
}
return locals;
}
vector<string> Workspace::graphs() const {
vector<string> names;
for (const auto& it : graph_map_) {
names.push_back(it.first);
}
return names;
}
void Workspace::Initialize() {
Workspace::Workspace(const string& name) : name_(name) {
CreateTensor(""); // Empty placeholder
CreateTensor("/share/flag/recomputing")
->Reshape({1})
->Reshape({})
->mutable_data<bool, CPUContext>()[0] = false;
}
void Workspace::Clear() {
// Remove and Initialize again
tensor_map_.clear();
Initialize();
}
void Workspace::MergeFrom(Workspace* ws) {
CHECK(ws) << "\nThe given Workspace is invalid.";
for (const auto& it : ws->tensor_map_) {
if (!it.first.empty() && !str::startswith(it.first, "/")) {
external_tensor_map_[it.first] = it.second.get();
void Workspace::MergeFrom(Workspace* other) {
if (other != nullptr) {
// Add the external tensors
for (const auto& it : other->tensor_map_) {
if (!it.first.empty() && !str::startswith(it.first, "/")) {
external_tensor_map_[it.first] = it.second.get();
}
}
// Recount the unique index to avoid duplicate names
for (const auto& i : other->unique_index_map_) {
auto& index_map = unique_index_map_[i.first];
for (const auto& j : i.second) {
index_map[j.first] = std::max(index_map[j.first], j.second);
}
}
}
}
string Workspace::GetTensorName(const string& name) const {
const auto& it = alias_active_map_.find(name);
if (it != alias_active_map_.end()) return it->second;
return name;
}
Tensor* Workspace::TryGetTensor(const string& name, bool use_remote) const {
// Check the proxy of this tensor firstly
string query = GetTensorName(name);
// Search the local workspace
const auto& it = tensor_map_.find(query);
Tensor* Workspace::TryGetTensor(const string& name, bool external) const {
// Check the alias firstly
const auto& alias_it = alias_map_.find(name);
auto name_v2 = alias_it != alias_map_.end() ? alias_it->second : name;
// Search this workspace
const auto& it = tensor_map_.find(name_v2);
if (it != tensor_map_.end()) return it->second.get();
if (use_remote) {
// Search the remote workspaces
const auto& it = external_tensor_map_.find(query);
if (external) {
// Search external workspaces
const auto& it = external_tensor_map_.find(name_v2);
if (it != external_tensor_map_.end()) return it->second;
}
return nullptr;
}
Tensor* Workspace::CreateTensor(const string& name) {
Tensor* tensor = TryGetTensor(name);
if (!tensor) {
tensor_map_[name] = unique_ptr<Tensor>(new Tensor(name));
return tensor_map_[name].get();
Tensor* Workspace::CreateTensor(const string& name, FillerInfo* filler) {
auto* tensor = TryGetTensor(name);
// Create only if name not existed
if (tensor == nullptr) {
tensor = new Tensor(name);
tensor_map_[name] = unique_ptr<Tensor>(tensor);
}
// Maybe bind it with a filler
if (filler != nullptr) {
filler_map_[tensor->name()] = std::move(FillerInfo(*filler));
}
return tensor;
}
Tensor* Workspace::GetTensor(const string& name, bool use_remote) const {
Tensor* tensor = TryGetTensor(name, use_remote);
CHECK(tensor) << "\nTensor(" << name << ") does not "
<< "exist in current workspace.";
Tensor* Workspace::GetTensor(const string& name, bool external) const {
auto* tensor = TryGetTensor(name, external);
CHECK(tensor) << "\nTensor(" << name << ") is not in current workspace.";
return tensor;
}
void Workspace::ResetTensor(const string& name) {
Tensor* tensor = TryGetTensor(name, false);
CHECK(tensor) << "\nTensor(" << name << ") does not "
<< "belong to current workspace.";
auto* tensor = TryGetTensor(name, false);
CHECK(tensor) << "\nTensor(" << name << ") is not in current workspace.";
tensor->Reset();
}
bool Workspace::HasFiller(const string& name) const {
return tensor_filler_map_.count(name) > 0;
}
void Workspace::CreateFiller(const TensorFillerProto& filler) {
CHECK_GT(filler.tensor().size(), 0)
<< "\nTensor with an empty name can not be filled.";
if (HasFiller(filler.tensor())) return;
tensor_filler_map_[filler.tensor()] = filler;
}
TensorFillerProto* Workspace::GetFiller(const string& name) {
const auto& it = tensor_filler_map_.find(name);
if (it != tensor_filler_map_.end()) return &it->second;
FillerInfo* Workspace::GetFillerInfo(const string& name) {
const auto& it = filler_map_.find(name);
if (it != filler_map_.end()) return &it->second;
return nullptr;
}
OperatorBase* Workspace::CreateOperator(const OperatorDef& def) {
const auto& it = operator_map_.find(def.cache_key());
if (it == operator_map_.end()) {
auto* new_op = NewOperator(def, this);
operator_map_[def.cache_key()] = unique_ptr<OperatorBase>(new_op);
return new_op;
}
return it->second.get();
}
void Workspace::RunOperator(const OperatorDef& def) {
if (def.has_cache_key()) {
CreateOperator(def)->UpdateFrom(def)->Run(0);
OperatorBase* cached_op = nullptr;
const auto& it = operator_map_.find(def.cache_key());
if (it == operator_map_.end()) {
cached_op = NewOperator(def, this);
operator_map_[def.cache_key()] = unique_ptr<OperatorBase>(cached_op);
} else {
cached_op = it->second.get();
}
cached_op->UpdateFrom(def)->Run();
} else {
unique_ptr<OperatorBase> op(NewOperator(def, this));
op->Run(0);
OperatorBase* temporal_op = NewOperator(def, this);
temporal_op->Run();
delete temporal_op;
}
}
GraphBase* Workspace::CreateGraph(const GraphDef& def) {
CHECK(def.has_name()) << "\nGraph name is missing.";
auto name = GetDummyName(def.name(), "", "Graph", false);
LOG(DEBUG) << "Create Graph: " << name << "(" << def.name() << ")";
GraphDef _def(def);
_def.set_name(name);
graph_map_[name] = unique_ptr<GraphBase>(NewGraph(_def, this));
return graph_map_[name].get();
CHECK(def.has_name()) << "\nExcepted non-empty graph name.";
GraphDef def_v2(def); // Copy to set an unique name
def_v2.set_name(UniqueName(def.name(), "", "Graph", false));
LOG(DEBUG) << "Create Graph: " << def_v2.name() << "(" << def.name() << ")";
auto* cached_graph = NewGraph(def_v2, this);
graph_map_[def_v2.name()] = unique_ptr<GraphBase>(cached_graph);
return cached_graph;
}
void Workspace::RunGraph(
const string& graph_name,
const string& incl,
const string& excl,
int stream_id) {
if (!graph_map_.count(graph_name)) {
LOG(FATAL) << "Graph(" << graph_name << ") does not exist.";
}
graph_map_[graph_name]->Run(incl, excl, stream_id);
const string& name,
const string& include,
const string& exclude,
const int stream) {
CHECK(graph_map_.count(name))
<< "\nGraph(" << name << ") is not in current workspace.";
graph_map_[name]->Run(include, exclude, stream);
}
bool Workspace::ActivateAlias(const string& name, const string& alias) {
bool status = alias_active_map_.count(alias) > 0;
alias_active_map_[alias] = name;
return status; // True if activated otherwise false
void Workspace::RegisterAlias(const string& target, const string& alias) {
alias_map_[alias] = target;
}
string Workspace::GetDummyName(
const string& base_name,
string Workspace::UniqueName(
const string& name,
const string& suffix,
const string& domain,
const string& scope,
bool zero_based) {
string accepted_name;
int64_t index;
const auto required_name = base_name + suffix;
auto& dmap = dummy_name_map_[domain];
while (1) {
index = dmap[required_name]++;
accepted_name = index ? base_name + "_" + str::to(index) + suffix
: zero_based
? required_name
: base_name + "_" + str::to(dmap[required_name]++) + suffix;
if (external_tensor_map_.empty()) break;
if (!HasTensor(accepted_name)) break;
auto& index_map = unique_index_map_[scope];
auto required_name = name + suffix;
auto index = index_map[required_name]++;
if (index > 0) return name + "_" + str::to(index) + suffix;
if (zero_based) return required_name;
return name + "_" + str::to(index_map[required_name]++) + suffix;
}
vector<string> Workspace::tensors() const {
vector<string> names;
for (const auto& it : tensor_map_) {
names.push_back(it.first);
}
return accepted_name;
for (const auto& it : external_tensor_map_) {
names.push_back(it.first);
}
return names;
}
vector<string> Workspace::graphs() const {
vector<string> names;
for (const auto& it : graph_map_) {
names.push_back(it.first);
}
return names;
}
} // namespace dragon
......@@ -20,83 +20,63 @@ namespace dragon {
class Workspace {
public:
/*! \brief Constructor */
explicit Workspace(const string& name) : name_(name) {
Initialize();
}
/*! \brief Create some internal tensors */
DRAGON_API void Initialize();
DRAGON_API explicit Workspace(const string& name);
/*! \brief Merge tensors from a external workspace */
/*! \brief Merge resources from other */
DRAGON_API void MergeFrom(Workspace*);
/*! \brief Destory all the tensors */
DRAGON_API void Clear();
/* \brief Return a unique dummy name within this workspace */
DRAGON_API string GetDummyName(
const string& base_name,
/* \brief Return an unique name */
DRAGON_API string UniqueName(
const string& name,
const string& suffix,
const string& domain = "",
bool zero_based = true);
const string& scope = "",
const bool zero_based = false);
/*! \brief Whether the specified tensor is in this workspace */
DRAGON_API bool HasTensor(const string& name, bool use_remote = true) const {
return TryGetTensor(name, use_remote) ? true : false;
}
/* \brief Register an alias for the target */
DRAGON_API void RegisterAlias(const string& target, const string& alias);
/*! \brief Query the real name of specified tensor */
DRAGON_API string GetTensorName(const string&) const;
/* \brief Activate an alias for the target */
DRAGON_API bool ActivateAlias(const string& name, const string& alias);
/*! \brief Return whether tensor is existing */
DRAGON_API bool HasTensor(const string& name, bool external = true) const {
return TryGetTensor(name, external) == nullptr ? false : true;
}
/*! \brief Create a tensor in this workspace */
DRAGON_API Tensor* CreateTensor(const string&);
/*! \brief Create the tensor */
DRAGON_API Tensor* CreateTensor(const string&, FillerInfo* = nullptr);
/*! \brief Try to search the specified tensor in this workspace */
/*! \brief Try to return the tensor */
DRAGON_API Tensor* TryGetTensor(const string&, bool = true) const;
/*! \brief Return the specified tensor */
/*! \brief Return the tensor */
DRAGON_API Tensor* GetTensor(const string&, bool = true) const;
/*! \brief Reset the specified tensor */
/*! \brief Reset the tensor */
DRAGON_API void ResetTensor(const string&);
/* \brief Whether the specified filler is existing */
DRAGON_API bool HasFiller(const string&) const;
/*! \brief Create a filler in this workspace */
DRAGON_API void CreateFiller(const TensorFillerProto&);
/*! \brief Return the filler info */
DRAGON_API FillerInfo* GetFillerInfo(const string&);
/*! \brief Return the specified filler */
DRAGON_API TensorFillerProto* GetFiller(const string&);
/*! \brief Create an operator in this workspace */
DRAGON_API OperatorBase* CreateOperator(const OperatorDef&);
/*! \brief Run an operator in this workspace */
/*! \brief Run the operator */
DRAGON_API void RunOperator(const OperatorDef&);
/*! \brief Create a graph in this workspace */
/*! \brief Create the graph */
DRAGON_API GraphBase* CreateGraph(const GraphDef&);
/*! \brief Run the specifed graph by name and rules */
/*! \brief Run the graph */
DRAGON_API void RunGraph(
const string& graph_name,
const string& incl = "",
const string& excl = "",
int stream_id = 0);
const string& include = "",
const string& exclude = "",
const int stream = 0);
/*! \brief Return the name of this workspace */
/*! \brief Return the workspace name */
const string& name() {
return name_;
}
/*! \brief Return the name of stored tensors */
/*! \brief Return the name of cached tensors */
DRAGON_API vector<string> tensors() const;
/*! \brief Return the name of stored graphs */
/*! \brief Return the name of cached graphs */
DRAGON_API vector<string> graphs() const;
/*! \brief Provide a group of the shared byte data */
......@@ -127,28 +107,28 @@ class Workspace {
}
private:
/*! \brief The unique workspace name */
/*! \brief The workspace name */
string name_;
/*! \brief The dummy name indices */
Map<string, Map<string, int64_t>> dummy_name_map_;
/*! \brief The external tensors */
Map<string, Tensor*> external_tensor_map_;
/*! \brief Store the created tensors */
Map<string, unique_ptr<Tensor>> tensor_map_;
/*! \brief The unique indices */
Map<string, Map<string, int64_t>> unique_index_map_;
/*! \brief Store the external tensors */
Map<string, Tensor*> external_tensor_map_;
/*! \brief The registered fillers */
Map<string, FillerInfo> filler_map_;
/*! \brief Store the registered tensor fillers */
Map<string, TensorFillerProto> tensor_filler_map_;
/*! \brief The registered aliases */
Map<string, string> alias_map_;
/*! \brief Store the active aliases */
Map<string, string> alias_active_map_;
/*! \brief The cached tensors */
Map<string, unique_ptr<Tensor>> tensor_map_;
/*! \brief Store the registered operators for dynamic graph */
/*! \brief The cached operators */
Map<string, unique_ptr<OperatorBase>> operator_map_;
/*! \brief Store the registered graphs for static graph */
/*! \brief The cached graphs */
Map<string, unique_ptr<GraphBase>> graph_map_;
};
......
......@@ -425,7 +425,7 @@ void PReluWGrad<float16, CUDAContext>(
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
N * C * S,
N * S,
C,
S,
reinterpret_cast<const half*>(dy),
......@@ -437,7 +437,7 @@ void PReluWGrad<float16, CUDAContext>(
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
N * C * S,
N * S,
C,
reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(x),
......@@ -536,13 +536,13 @@ void PReluWGrad<float16, CUDAContext>(
CUDA_2D_BLOCKS(C), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(N * C * S, C, S, dy, x, dw); \
ctx->cuda_stream()>>>(N * S, C, S, dy, x, dw); \
} else if (data_format == "NHWC") { \
_PReluWGradNHWC<<< \
CUDA_2D_BLOCKS(C), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(N * C * S, C, dy, x, dw); \
ctx->cuda_stream()>>>(N * S, C, dy, x, dw); \
} else { \
LOG(FATAL) << "Unknown data format: " << data_format; \
} \
......
......@@ -25,7 +25,6 @@
#include "dragon/core/workspace.h"
#include "dragon/modules/python/types.h"
#include "dragon/onnx/onnx_backend.h"
#include "dragon/utils/caffemodel.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
......
......@@ -74,7 +74,7 @@ class DLPackWrapper {
}
Tensor* From(py::object obj) {
CHECK(PyCapsule_CheckExact(obj.ptr())) << "\nExpected DLPack capsule";
CHECK(PyCapsule_CheckExact(obj.ptr())) << "\nExpected DLPack capsule.";
auto* managed_tensor =
(DLManagedTensor*)PyCapsule_GetPointer(obj.ptr(), "dltensor");
CHECK(managed_tensor) << "\nInvalid DLPack capsule";
......
......@@ -44,48 +44,38 @@ PYBIND11_MODULE(libdragon_python, m) {
/*! \brief Return the name of stored graphs */
.def_property_readonly("graphs", &Workspace::graphs)
/*! \brief Destory all the tensors */
.def("Clear", &Workspace::Clear)
/*! \brief Merge a external workspace into self */
/*! \brief Merge resources from another workspace */
.def("MergeFrom", &Workspace::MergeFrom)
/*! \brief Return a unique dummy name */
.def("GetDummyName", &Workspace::GetDummyName)
/*! \brief Return the unique name of given tensor */
.def("GetTensorName", &Workspace::GetTensorName)
/*! \brief Return an unique name */
.def("UniqueName", &Workspace::UniqueName)
/*! \brief Reset a tensor with the given name */
/*! \brief Reset the tensor */
.def("ResetTensor", &Workspace::ResetTensor)
/*! \brief Indicate whether the given tensor is existing */
/*! \brief Return whether the tensor is existing */
.def(
"HasTensor",
[](Workspace* self, const string& name) {
return self->HasTensor(name);
})
/*! \brief Create a tensor with the given name */
/*! \brief Create the tensor */
.def(
"CreateTensor",
[](Workspace* self, const string& name) {
[](Workspace* self, const string& name, const string& filler_str) {
if (!filler_str.empty()) {
FillerInfo filler_info;
if (!filler_info.ParseFromString(filler_str)) {
LOG(FATAL) << "Failed to parse the FillerInfo.";
}
return self->CreateTensor(name, &filler_info);
}
return self->CreateTensor(name);
},
py::return_value_policy::reference_internal)
/*! \brief Create a tensor from the specified filler */
.def(
"CreateFiller",
[](Workspace* self, const string& serialized) {
TensorFillerProto filler_proto;
if (!filler_proto.ParseFromString(serialized))
LOG(FATAL) << "Failed to parse the TensorFiller.";
self->CreateFiller(filler_proto);
self->CreateTensor(filler_proto.tensor());
})
/*! \brief Return the CXX Tensor reference */
/*! \brief Return the tensor */
.def(
"GetTensor",
[](Workspace* self, const string& name) {
......@@ -93,11 +83,11 @@ PYBIND11_MODULE(libdragon_python, m) {
},
py::return_value_policy::reference_internal)
/* \brief Set an alias for the tensor */
/* \brief Register an alias for the name */
.def(
"SetTensorAlias",
"RegisterAlias",
[](Workspace* self, const string& name, const string& alias) {
return self->ActivateAlias(name, alias);
return self->RegisterAlias(name, alias);
})
/*! \brief Copy the array data to tensor */
......@@ -118,7 +108,7 @@ PYBIND11_MODULE(libdragon_python, m) {
dev, reinterpret_cast<PyArrayObject*>(value.ptr()), tensor);
})
/*! \brief Copy the tensor data to the array */
/*! \brief Copy the tensor data to array */
.def(
"FetchTensor",
[](Workspace* self, const string& name) {
......@@ -142,7 +132,7 @@ PYBIND11_MODULE(libdragon_python, m) {
}
})
/*! \brief Run a operator from the def reference */
/*! \brief Run the operator */
.def(
"RunOperator",
[](Workspace* self, OperatorDef* def, const bool verbose) {
......@@ -156,7 +146,7 @@ PYBIND11_MODULE(libdragon_python, m) {
self->RunOperator(*def);
})
/*! \brief Run operators from the def reference */
/*! \brief Run the operators */
.def(
"RunOperator",
[](Workspace* self, vector<OperatorDef*>& defs, const bool verbose) {
......@@ -172,7 +162,7 @@ PYBIND11_MODULE(libdragon_python, m) {
}
})
/*! \brief Run a operator from the serialized def */
/*! \brief Run the operator from serialized def */
.def(
"RunOperator",
[](Workspace* self, const string& serialized, const bool verbose) {
......@@ -188,7 +178,7 @@ PYBIND11_MODULE(libdragon_python, m) {
self->RunOperator(def);
})
/*! \brief Create a graph from the serialized def */
/*! \brief Create the graph */
.def(
"CreateGraph",
[](Workspace* self, const string& serialized, const bool verbose) {
......@@ -213,89 +203,49 @@ PYBIND11_MODULE(libdragon_python, m) {
return graph->name();
})
/*! \brief Run an existing graph */
/*! \brief Run the graph */
.def(
"RunGraph",
[](Workspace* self,
const string& name,
const string& incl,
const string& excl) {
const string& include,
const string& exclude) {
py::gil_scoped_release g;
self->RunGraph(name, incl, excl);
self->RunGraph(name, include, exclude);
})
/*! \brief Run the backward */
.def(
"RunBackward",
[](Workspace* self,
const vector<OperatorDef*>& forward_ops,
const vector<OperatorDef*>& op_defs,
const vector<string>& targets,
const vector<string>& sources,
const vector<string>& input_grads,
const vector<string>& ignored_grads,
const bool is_sharing,
const vector<string>& empty_grads,
const bool retain_grads,
const bool verbose) {
GraphDef backward_ops;
GraphDef graph_def;
GraphGradientMaker maker;
for (const auto& name : ignored_grads) {
maker.add_ignored_grad(name);
for (const auto& name : empty_grads) {
maker.add_empty_grad(name);
}
for (const auto& name : sources) {
maker.add_hooked_grad(name + "_grad");
maker.add_retained_grad(name + "_grad");
}
maker.Make(forward_ops, targets, input_grads, backward_ops);
maker.Make(op_defs, targets, input_grads, graph_def);
py::gil_scoped_release g;
if (is_sharing) {
backward_ops = maker.Share(backward_ops);
if (!retain_grads) {
graph_def = maker.Share(graph_def);
}
for (const auto& def : backward_ops.op()) {
for (const auto& op_def : graph_def.op()) {
if (verbose) {
auto msg = string("\n") + def.DebugString();
auto msg = string("\n") + op_def.DebugString();
msg.pop_back();
PRINT(INFO)
<< "op {" << str::replace_all(msg, "\n", "\n ") << "\n}\n";
}
self->RunOperator(def);
}
})
/*! \brief Serialize tensors into a binary file */
.def(
"Save",
[](Workspace* self,
const string& filename,
const vector<string>& tensors,
const int format) {
vector<Tensor*> refs;
switch (format) {
case 0: // Pickle
LOG(FATAL) << "Format depends on Pickle. "
<< "Can't be used in C++.";
break;
case 1: // CaffeModel
for (const auto& name : tensors) {
refs.emplace_back(self->GetTensor(name));
}
SavaCaffeModel(filename, refs);
break;
default:
LOG(FATAL) << "Unknown format, code: " << format;
}
})
/*! \brief Load tensors from a binary file */
.def(
"Load",
[](Workspace* self, const string& filename, const int format) {
switch (format) {
case 0: // Pickle
LOG(FATAL) << "Format depends on Pickle. "
<< "Can't be used in C++.";
break;
case 1: // CaffeModel
LoadCaffeModel(filename, self);
break;
default:
LOG(FATAL) << "Unknown format, code: " << format;
self->RunOperator(op_def);
}
})
......
......@@ -20,7 +20,6 @@ PythonPluginInferOp<Context>::PythonPluginInferOp(
class_name_(OpArg<string>("class_name", "")),
kwargs_str_((OpArg<string>("kwargs_str", ""))) {
// Optimization for all python ops
if (!allow_run()) return;
this->do_sync_ = false;
// Initialize interpreter and load module
......
......@@ -24,6 +24,9 @@ namespace tensor {
void RegisterModule(py::module& m) {
/*! \brief Export the Tensor class */
py::class_<Tensor>(m, "Tensor")
/*! \brief Return the tensor name */
.def_property_readonly("name", &Tensor::name)
/*! \brief Return the number of dimensions */
.def_property_readonly("ndim", &Tensor::ndim)
......@@ -46,9 +49,9 @@ void RegisterModule(py::module& m) {
"device",
[](Tensor* self) {
if (self->has_memory()) {
auto mem_info = self->memory()->info();
auto info = self->memory()->info();
return std::tuple<string, int>(
mem_info["device_type"], atoi(mem_info["device_id"].c_str()));
info["device_type"], atoi(info["device_id"].c_str()));
} else {
return std::tuple<string, int>("Unknown", 0);
}
......
......@@ -119,8 +119,6 @@ DRAGON_API void DestroyGraphDef(GraphDef_t graph_def);
* Model API
*/
DRAGON_API void LoadCaffeModel(const std::string& model_file, Workspace_t ws);
DRAGON_API void LoadONNXModel(
const std::string& model_file,
GraphDef_t init_graph,
......
#include "dragon/core/common.h"
#include "dragon/modules/runtime/dragon_runtime.h"
#include "dragon/onnx/onnx_backend.h"
#include "dragon/utils/caffemodel.h"
#include "dragon/utils/proto_utils.h"
namespace dragon {
......@@ -161,46 +160,6 @@ DRAGON_API void DestroyGraphDef(GraphDef_t graph_def) {
if (graph_def) delete graph_def;
}
void LoadCaffeModel(const string& model_file, Workspace_t ws) {
NetParameter net_param;
ReadProtoFromBinaryFile(model_file.c_str(), &net_param);
std::string scope = "";
LOG(INFO) << "Load Model: " << model_file << "......";
LOG(INFO) << "Format: Caffe";
for (int i = 0; i < net_param.layer_size(); i++) {
const LayerParameter& layer = net_param.layer(i);
const string& layer_name = layer.name();
string prefix = scope + layer_name + "/param:";
for (int j = 0; j < layer.blobs_size(); j++) {
string tensor_name = prefix + std::to_string(j);
if (!ws->HasTensor(tensor_name)) ws->CreateTensor(tensor_name);
BlobProto blob = layer.blobs(j);
vector<int64_t> dims;
for (auto dim : blob.shape().dim())
dims.push_back(dim);
Tensor* tensor = ws->GetTensor(tensor_name);
std::stringstream DimString;
if (dims.size() > 0) {
tensor->Reshape(dims);
CHECK_EQ(tensor->count(), blob.data_size())
<< "Tensor(" << tensor_name << ") "
<< "failed to load, except size: " << tensor->count()
<< ", loaded " << blob.data_size();
DimString << tensor->DimString();
} else {
tensor->Reshape(vector<int64_t>(1, blob.data_size()));
DimString << "(missing)";
}
float* Xdata = tensor->mutable_data<float, CPUContext>();
for (int idx = 0; idx < blob.data_size(); idx++)
Xdata[idx] = blob.data(idx);
LOG(INFO) << "Tensor(" << tensor_name << ") "
<< "loaded, shape: " << DimString.str()
<< ", size: " << blob.data_size();
}
}
}
void LoadONNXModel(
const string& model_file,
GraphDef_t init_graph,
......
......@@ -19,7 +19,6 @@ ONNXImporterReturns ONNXBackend::ArgReduceImporter(
auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes;
// Determine the operation
auto* operation = attributes.AddRewrittenAttribute("operation");
if (onnx_node->node.op_type() == "ArgMax") {
......@@ -27,7 +26,6 @@ ONNXImporterReturns ONNXBackend::ArgReduceImporter(
} else if (onnx_node->node.op_type() == "ArgMin") {
operation->set_s("MIN");
}
return GenericImporter(&onnx_node_v2, ctx);
}
......@@ -37,17 +35,13 @@ ONNXImporterReturns ONNXBackend::ATenImporter(
auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes;
auto op_type = attributes.get<string>("op_type", "");
if (op_type.empty()) {
LOG(FATAL) << "op_type is required to evolve "
<< "to the specific operator.";
}
node.set_op_type(op_type);
attributes.remove("op_type");
return GenericImporter(&onnx_node_v2, ctx);
}
......@@ -56,17 +50,13 @@ ONNXImporterReturns ONNXBackend::BatchNormImporter(
const ConversionContext& ctx) {
auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes;
// Enforce to NCHW format
attributes.AddRewrittenAttribute("axis")->set_i(1);
// Remove dummy attributes
attributes.remove("consumed_inputs");
attributes.remove("is_test");
attributes.remove("spatial");
return GenericImporter(&onnx_node_v2, ctx);
}
......@@ -74,12 +64,10 @@ ONNXImporterReturns ONNXBackend::CastImporter(
ONNXNode* onnx_node,
const ConversionContext& ctx) {
auto& attributes = onnx_node->attributes;
// Determine the dtype
auto* dtype = attributes.AddRewrittenAttribute("dtype");
auto onnx_dtype = attributes.get<int64_t>("to", TensorProto::UNDEFINED);
auto supported_dtype = true;
switch (onnx_dtype) {
case ONNX_NAMESPACE::TensorProto::BOOL:
dtype->set_s("bool");
......@@ -138,11 +126,9 @@ ONNXImporterReturns ONNXBackend::CastImporter(
supported_dtype = false;
break;
};
CHECK(supported_dtype) << "\nCasting to " << dtype->s()
<< " is not supported.";
attributes.remove("to");
return GenericImporter(onnx_node, ctx);
}
......@@ -151,17 +137,16 @@ ONNXImporterReturns ONNXBackend::ConvPoolImporter(
const ConversionContext& ctx) {
auto& attributes = onnx_node->attributes;
const auto onnx_op_type = onnx_node->node.op_type();
// Determine the padding
auto mode = attributes.get<string>("auto_pad");
auto* padding = attributes.AddRewrittenAttribute("padding");
// SAME, SAME_LOWER, or SAME_UPPER
if (str::find(mode, "SAME"))
if (str::find(mode, "SAME")) {
padding->set_s(mode);
else
} else {
padding->set_s("VALID"); // Use explicit pads
}
attributes.remove("auto_pad");
// Determine the pooling mode
if (onnx_op_type == "MaxPool") {
attributes.AddRewrittenAttribute("mode")->set_s("MAX");
......@@ -174,14 +159,11 @@ ONNXImporterReturns ONNXBackend::ConvPoolImporter(
attributes.AddRewrittenAttribute("mode")->set_s("AVG");
attributes.AddRewrittenAttribute("global_pooling")->set_i(1);
}
auto returns = GenericImporter(onnx_node, ctx);
// Determine the op type
OperatorDef* op_def = returns.GetOp(0);
auto ks = attributes.get<ONNX_INTS>("kernel_shape");
*(op_def->mutable_type()) += (str::to(ks.size() > 0 ? ks.size() : 2) + "d");
return returns;
}
......@@ -194,11 +176,9 @@ ONNXImporterReturns ONNXBackend::GenericImporter(
op_def->mutable_input()->MergeFrom(node.input());
op_def->mutable_output()->MergeFrom(node.output());
op_def->set_name(node.name());
const auto onnx_op_type = node.op_type();
op_def->set_type(
get_default(get_renamed_nodes(), onnx_op_type, onnx_op_type));
auto mapper = [&, this](const std::string& k) {
const auto it = get_node_renamed_attrs().find(onnx_op_type);
if (it != get_node_renamed_attrs().end()) {
......@@ -224,18 +204,16 @@ ONNXImporterReturns ONNXBackend::GemmImporter(
auto alpha = attributes.get<float>("alpha", 1.f);
auto beta = attributes.get<float>("beta", 1.f);
auto trans_a = attributes.get<int64_t>("transA", 0L);
// Remove the unsupported attributes
if (alpha != 1.f || beta != 1.f) {
LOG(FATAL) << "alpha/beta can not be set currently.";
}
if (trans_a) {
LOG(FATAL) << "Tranposed A is not supported currently.";
}
attributes.remove("alpha");
attributes.remove("beta");
attributes.remove("transA");
return GenericImporter(onnx_node, ctx);
}
......@@ -244,11 +222,9 @@ ONNXImporterReturns ONNXBackend::MaxRoiPoolImporter(
const ConversionContext& ctx) {
auto& attributes = onnx_node->attributes;
auto pooled_shape = attributes.get<ONNX_INTS>("pooled_shape");
attributes.AddRewrittenAttribute("pool_h")->set_i(pooled_shape.Get(0));
attributes.AddRewrittenAttribute("pool_w")->set_i(pooled_shape.Get(1));
attributes.remove("pooled_shape");
return GenericImporter(onnx_node, ctx);
}
......@@ -258,18 +234,16 @@ ONNXImporterReturns ONNXBackend::ReshapeImporter(
auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes;
attributes.remove("consumed_inputs");
// Determine the dims
auto* dims = attributes.AddRewrittenAttribute("dims");
if (ctx.opset_version() < 5) {
const auto& shape = attributes.get<ONNX_INTS>("shape");
CHECK_GT(shape.size(), 0) << "\nExcepted the shape value";
attributes.remove("shape");
for (auto d : shape)
for (auto d : shape) {
dims->add_ints(d);
}
} else {
CHECK_EQ(node.input_size(), 2)
<< "\nExpectd 2 input in upsample after onnx version 5";
......@@ -280,10 +254,10 @@ ONNXImporterReturns ONNXBackend::ReshapeImporter(
Argument shape_dtype, shape_values;
ONNXTensorToArgument(*shape_tensor, &shape_dtype, &shape_values);
CHECK_GT(shape_values.ints_size(), 0) << "\nExcepted the shape value";
for (auto d : shape_values.ints())
for (auto d : shape_values.ints()) {
dims->add_ints(d);
}
}
return GenericImporter(&onnx_node_v2, ctx);
}
......@@ -293,9 +267,7 @@ ONNXImporterReturns ONNXBackend::ResizeImporter(
auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes;
attributes.remove("coordinate_transformation_mode");
if (ctx.opset_version() >= 9) {
node.mutable_input()->Clear();
node.add_input(onnx_node->node.input(0));
......@@ -307,21 +279,22 @@ ONNXImporterReturns ONNXBackend::ResizeImporter(
const auto* scales_tensor = ctx.initializer().at(scales_name);
ONNXTensorToArgument(*scales_tensor, &scales_dtype, &scale_values);
auto* scales = attributes.AddRewrittenAttribute("scales");
for (auto d : scale_values.floats())
for (auto d : scale_values.floats()) {
scales->add_floats(d);
}
if (sizes_idx > 0) {
Argument sizes_dtype, sizes_values;
const auto& sizes_name = onnx_node->node.input(sizes_idx);
const auto* sizes_tensor = ctx.initializer().at(sizes_name);
ONNXTensorToArgument(*sizes_tensor, &sizes_dtype, &sizes_values);
auto* sizes = attributes.AddRewrittenAttribute("sizes");
for (auto d : sizes_values.floats())
for (auto d : sizes_values.floats()) {
sizes->add_ints(d);
}
}
} else {
LOG(FATAL) << "Required opset >= 7";
}
return GenericImporter(&onnx_node_v2, ctx);
}
......@@ -330,12 +303,10 @@ ONNXImporterReturns ONNXBackend::RoiAlignImporter(
const ConversionContext& ctx) {
auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node);
// Remove the batch indices
node.mutable_input()->Clear();
node.add_input(onnx_node->node.input(0));
node.add_input(onnx_node->node.input(1));
return GenericImporter(&onnx_node_v2, ctx);
}
......@@ -345,19 +316,22 @@ ONNXImporterReturns ONNXBackend::TileImporter(
auto node = NodeProto(onnx_node->node);
auto onnx_node_v2 = ONNXNode(node);
auto& attributes = onnx_node_v2.attributes;
// Determine the multiples from repeats
auto* multiples = attributes.AddRewrittenAttribute("multiples");
node.mutable_input()->Clear();
node.add_input(onnx_node->node.input(0));
const auto& repeats_name = onnx_node->node.input(1);
const auto* repeats_tensor = ctx.initializer().at(repeats_name);
Argument multiples_dtype, multiples_values;
ONNXTensorToArgument(*repeats_tensor, &multiples_dtype, &multiples_values);
CHECK_GT(multiples_values.ints_size(), 0) << "\nExcepted the repeats value";
for (auto d : multiples_values.ints())
multiples->add_ints(d);
if (ctx.opset_version() >= 6) {
// Determine repeats from repeats
auto* repeats = attributes.AddRewrittenAttribute("repeats");
node.mutable_input()->Clear();
node.add_input(onnx_node->node.input(0));
const auto& repeats_name = onnx_node->node.input(1);
const auto* repeats_tensor = ctx.initializer().at(repeats_name);
Argument repeats_dtype, repeats_values;
ONNXTensorToArgument(*repeats_tensor, &repeats_dtype, &repeats_values);
CHECK_GT(repeats_values.ints_size(), 0) << "\nExcepted the repeats value";
for (auto repeat : repeats_values.ints()) {
repeats->add_ints(repeat);
}
} else {
LOG(FATAL) << "Required opset >= 6";
}
return GenericImporter(&onnx_node_v2, ctx);
}
......
......@@ -4,13 +4,13 @@
namespace dragon {
#define DEFINE_FILLER_OP_IMPL(name) \
template <class Context> \
template <typename T> \
void name##Op<Context>::DoRunWithType() { \
unique_ptr<Filler<T, Context>> f; \
f.reset(CreateFiller<T, Context>(this->proto_)); \
f->Fill(Output(0), ctx()); \
#define DEFINE_FILLER_OP_IMPL(name) \
template <class Context> \
template <typename T> \
void name##Op<Context>::DoRunWithType() { \
unique_ptr<Filler<T, Context>> f; \
f.reset(CreateFiller<T, Context>(this->filler_info_)); \
f->Fill(Output(0), ctx()); \
}
#define DISPATCH_WITH_TYPES(name, ...) \
......
......@@ -30,7 +30,7 @@ class InitializeOp : public Operator<Context> {
void RunOnDevice() override;
protected:
TensorFillerProto proto_;
FillerInfo filler_info_;
DECLARE_ARGS_WITH_DESC(int64_t, dims);
};
......@@ -142,9 +142,9 @@ class RandomNormalOp final : public InitializeOp<Context> {
: InitializeOp<Context>(def, ws) {
auto mu = OpArg<float>("mean", 0.f);
auto sigma = OpArg<float>("std", 1.f);
this->proto_.set_mean(mu);
this->proto_.set_std(sigma);
this->proto_.set_type("normal");
this->filler_info_.set_mean(mu);
this->filler_info_.set_std(sigma);
this->filler_info_.set_type("normal");
}
USE_OPERATOR_FUNCTIONS;
......@@ -161,9 +161,9 @@ class RandomUniformOp final : public InitializeOp<Context> {
: InitializeOp<Context>(def, ws) {
auto low = OpArg<float>("low", -1.f);
auto high = OpArg<float>("high", 1.f);
this->proto_.set_low(low);
this->proto_.set_high(high);
this->proto_.set_type("uniform");
this->filler_info_.set_low(low);
this->filler_info_.set_high(high);
this->filler_info_.set_type("uniform");
}
USE_OPERATOR_FUNCTIONS;
......@@ -180,11 +180,11 @@ class TruncatedNormalOp final : public InitializeOp<Context> {
: InitializeOp<Context>(def, ws) {
auto mu = OpArg<float>("mean", 0.f);
auto sigma = OpArg<float>("std", 1.f);
this->proto_.set_mean(mu);
this->proto_.set_std(sigma);
this->proto_.set_low(mu - 2 * sigma);
this->proto_.set_high(mu + 2 * sigma);
this->proto_.set_type("truncated_normal");
this->filler_info_.set_mean(mu);
this->filler_info_.set_std(sigma);
this->filler_info_.set_low(mu - 2 * sigma);
this->filler_info_.set_high(mu + 2 * sigma);
this->filler_info_.set_type("truncated_normal");
}
USE_OPERATOR_FUNCTIONS;
......@@ -201,15 +201,15 @@ class GlorotNormalOp final : public InitializeOp<Context> {
: InitializeOp<Context>(def, ws) {
auto scale = OpArg<float>("scale", 2.f);
auto mode = OpArg<string>("mode", "fan_in");
this->proto_.set_type("msra");
this->filler_info_.set_type("glorot_normal");
if (mode == "fan_avg") {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_AVG);
this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_AVG);
} else if (mode == "fan_out") {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_OUT);
this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_OUT);
} else {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_IN);
this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_IN);
}
this->proto_.set_scale(scale);
this->filler_info_.set_scale(scale);
}
USE_OPERATOR_FUNCTIONS;
......@@ -226,15 +226,15 @@ class GlorotUniformOp final : public InitializeOp<Context> {
: InitializeOp<Context>(def, ws) {
auto scale = OpArg<float>("scale", 3.f);
auto mode = OpArg<string>("mode", "fan_in");
this->proto_.set_type("xavier");
this->filler_info_.set_type("glorot_uniform");
if (mode == "fan_avg") {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_AVG);
this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_AVG);
} else if (mode == "fan_out") {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_OUT);
this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_OUT);
} else {
this->proto_.set_variance_norm(TensorFillerProto_VarianceNorm_FAN_IN);
this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_IN);
}
this->proto_.set_scale(scale);
this->filler_info_.set_scale(scale);
}
USE_OPERATOR_FUNCTIONS;
......
......@@ -9,9 +9,12 @@ template <typename T>
void TileOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
int num_repeats;
repeats(0, &num_repeats);
auto Y_dims = X.dims();
for (int i = 0; i < Y_dims.size(); ++i)
Y_dims[i] *= multiples(i);
for (int i = 0; i < num_repeats; ++i) {
Y_dims[i] *= repeats(i);
}
if (X.dims() == Y_dims) {
Y->Reshape(Y_dims)->CopyFrom(X, ctx());
......@@ -49,7 +52,7 @@ void TileGradientOp<Context>::DoRunWithType() {
dx = dest_->template mutable_data<T, Context>();
}
kernel::TileGrad(
dest_->count(0, axis_), dest_->count(axis_), multiple_, dy, dx, ctx());
dest_->count(0, axis_), dest_->count(axis_), repeat_, dy, dx, ctx());
}
template <class Context>
......@@ -57,10 +60,14 @@ void TileGradientOp<Context>::RunOnDevice() {
auto &dY = Input(0), *dX = Output(0);
// Add the axes
int num_repeats;
repeats(0, &num_repeats);
vector<pair<int, int>> dispatch_axes;
for (int i = 0; i < dY.ndim(); i++) {
auto m = multiples(i);
if (m > 1) dispatch_axes.push_back({m, i});
for (int i = 0; i < dY.ndim() && i < num_repeats; i++) {
auto repeat = repeats(i);
if (repeat > 1) {
dispatch_axes.push_back({repeat, i});
}
}
std::sort(dispatch_axes.begin(), dispatch_axes.end());
std::reverse(dispatch_axes.begin(), dispatch_axes.end());
......@@ -76,10 +83,10 @@ void TileGradientOp<Context>::RunOnDevice() {
// Reduce N times along each tiled axis
for (const auto& task : dispatch_axes) {
axis_ = task.second, multiple_ = task.first;
axis_ = task.second, repeat_ = task.first;
vec64_t X_dims(src_->dims());
X_dims[axis_] /= multiple_;
X_dims[axis_] /= repeat_;
dest_->Reshape(X_dims);
DispatchHelper<FloatingTensorTypes>::Call(this, dY);
......
......@@ -21,7 +21,7 @@ template <class Context>
class TileOp final : public Operator<Context> {
public:
TileOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, multiples);
GET_ARGS_WITH_DESC(int64_t, repeats);
}
USE_OPERATOR_FUNCTIONS;
......@@ -31,7 +31,7 @@ class TileOp final : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARGS_WITH_DESC(int64_t, multiples);
DECLARE_ARGS_WITH_DESC(int64_t, repeats);
};
template <class Context>
......@@ -39,7 +39,7 @@ class TileGradientOp final : public Operator<Context> {
public:
TileGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, multiples);
GET_ARGS_WITH_DESC(int64_t, repeats);
}
USE_OPERATOR_FUNCTIONS;
......@@ -50,12 +50,12 @@ class TileGradientOp final : public Operator<Context> {
protected:
Tensor *dest_, *src_, nav_;
int64_t axis_, multiple_;
DECLARE_ARGS_WITH_DESC(int64_t, multiples);
int64_t axis_, repeat_;
DECLARE_ARGS_WITH_DESC(int64_t, repeats);
};
DEFINE_ARGS_WITH_DESC(int64_t, TileOp, multiples);
DEFINE_ARGS_WITH_DESC(int64_t, TileGradientOp, multiples);
DEFINE_ARGS_WITH_DESC(int64_t, TileOp, repeats);
DEFINE_ARGS_WITH_DESC(int64_t, TileGradientOp, repeats);
} // namespace dragon
......
......@@ -9,7 +9,6 @@ void AdamUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
t_++;
auto beta1 = Parameter("beta1"), beta2 = Parameter("beta2");
auto coef = sqrt(1.f - pow(beta2, t_)) / (1.f - pow(beta1, t_));
kernel::AdamUpdate(
dX->count(),
Parameter("base_lr") * coef * this->lr_mult_,
......
......@@ -10,7 +10,6 @@ void SGDUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
auto lr = Parameter("base_lr") * this->lr_mult_;
if (last_lr_ > 0) correction_ = lr / last_lr_;
last_lr_ = lr; // Record the last value
kernel::SGDUpdate(
dX->count(),
lr,
......
......@@ -20,9 +20,7 @@ void BiasAddOp<Context>::DoRunWithType() {
LOG(FATAL) << "Unknown DataFormat: " << data_format();
}
// Maybe fill the bias at the first time
TENSOR_FILL(B, vec64_t({C}));
kernel::BiasAdd(
N,
C,
......
......@@ -4,73 +4,69 @@
namespace dragon {
#define SAME_PADDING(A, B) \
A[i] = padding_needed / 2; \
B[i] = padding_needed - A[i]
#define DETERMINE_SAME_PADDING(l, r) \
if (padding_ != "SAME_UPPER") { \
l[i] = pad_size / 2; \
r[i] = pad_size - l[i]; \
} else { \
r[i] = pad_size / 2; \
l[i] = pad_size - r[i]; \
}
template <class Context>
void ConvOpBase<Context>::ComputeOutShape() {
auto X_dims = Input(0).dims();
out_shape_.clear();
for (int i = 0; i < num_axes_; i++) {
if (!Transposed()) {
auto idm = X_dims[axis_ + i];
auto dk = dilation_[i] * (kshape_[i] - 1) + 1;
if (!str::find(padding_, "SAME")) {
// Explicit pads
auto odm = (idm + pad_l_[i] + pad_r_[i] - dk) / stride_[i] + 1;
out_shape_.push_back(odm);
} else {
// Auto pads
int64_t odm = (idm + stride_[i] - 1) / (float)stride_[i];
auto padding_needed =
std::max(int64_t(0), (odm - 1) * stride_[i] + dk - idm);
out_shape_.push_back(odm);
if (padding_ == "SAME_UPPER") {
SAME_PADDING(pad_l_, pad_r_);
} else {
SAME_PADDING(pad_r_, pad_l_);
} // SAME_LOWER or SAME
vec64_t X_dims = Input(0).dims();
int64_t in_size, out_size, k_size, pad_size;
if (!Transposed()) {
for (int i = 0; i < num_axes_; i++) {
in_size = X_dims[axis_ + i];
k_size = dilation_[i] * (kshape_[i] - 1) + 1;
if (!str::find(padding_, "SAME")) { // Explicit pads
pad_size = pad_l_[i] + pad_r_[i];
out_size = (in_size + pad_size - k_size) / stride_[i] + 1;
} else { // Auto pads
out_size = (in_size + stride_[i] - 1) / stride_[i];
pad_size = (out_size - 1) * stride_[i] + k_size - in_size;
pad_size = std::max(pad_size, int64_t(0));
DETERMINE_SAME_PADDING(pad_l_, pad_r_);
}
out_shape_.push_back(out_size);
}
} else {
int num_output_padding;
output_padding(0, &num_output_padding);
CHECK(num_output_padding == 0 || num_output_padding == num_axes_)
<< "\nExcepted 0 or " << num_axes_ << " ints for <output_padding>.";
if (!str::find(padding_, "SAME")) { // Explicit pads
for (int i = 0; i < num_axes_; i++) {
in_size = X_dims[axis_ + i];
k_size = dilation_[i] * (kshape_[i] - 1) + 1;
pad_size = pad_l_[i] + pad_r_[i];
out_size = stride_[i] * (in_size - 1) + k_size - pad_size;
if (num_output_padding > 0) out_size += output_padding(i);
out_shape_.push_back(out_size);
}
} else {
auto idm = X_dims[axis_ + i];
auto dk = dilation_[i] * (kshape_[i] - 1) + 1;
if (!str::find(padding_, "SAME")) {
// Explicit pads
auto odm = stride_[i] * (idm - 1) + dk - pad_l_[i] - pad_r_[i];
out_shape_.push_back(odm);
} else {
// Auto pads
int output_shape_size;
int output_padding_size;
output_shape(0, &output_shape_size);
output_padding(0, &output_padding_size);
CHECK(output_shape_size == 0 || output_shape_size == num_axes_)
<< "Excepted 0 or " << num_axes_ << " ints for output shape.";
CHECK(output_padding_size == 0 || output_padding_size == num_axes_)
<< "Excepted 0 or " << num_axes_ << " ints for output padding.";
int64_t padding_needed, odm;
if (output_padding_size) {
padding_needed = output_padding(i);
odm = stride_[i] * (idm - 1) + dk + padding_needed;
} else if (output_shape_size) {
odm = output_shape(i);
padding_needed = odm - (stride_[i] * (idm - 1) + dk);
CHECK_GE(padding_needed, 0)
<< "\nThe output shape is incorrect."
<< "\nWith the given stride and kernel, "
<< "dimension of spatial axis " << i << " should be at least "
<< odm - padding_needed << ".";
} else {
LOG(FATAL) << "Excepted the output padding or output shape "
<< "for \"SAME\" padding algorithm.";
}
out_shape_.push_back(odm);
if (padding_ == "SAME_UPPER") {
SAME_PADDING(pad_l_, pad_r_);
} else {
SAME_PADDING(pad_r_, pad_l_);
} // SAME_LOWER or SAME
// Auto pads
int num_output_shape;
output_shape(0, &num_output_shape);
CHECK(num_output_shape == num_axes_)
<< "\nExcepted " << num_axes_ << " ints for <output_shape>.";
for (int i = 0; i < num_axes_; i++) {
in_size = X_dims[axis_ + i];
k_size = dilation_[i] * (kshape_[i] - 1) + 1;
out_size = output_shape(i);
pad_size = stride_[i] * (in_size - 1) + k_size;
if (num_output_padding > 0) pad_size += output_padding(i);
CHECK_GE(pad_size, out_size)
<< "\nThe output shape is incorrect."
<< "\nDimension of spatial axis " << i << " should be at most "
<< pad_size << ".";
pad_size = stride_[i] * (in_size - 1) + k_size - out_size;
pad_size = std::max(pad_size, int64_t(0));
DETERMINE_SAME_PADDING(pad_l_, pad_r_);
out_shape_.push_back(out_size);
}
}
}
......@@ -373,7 +369,7 @@ INSTANTIATE_API(CUDAContext, float);
INSTANTIATE_API(CUDAContext, double);
#endif
#undef SAME_PADDING
#undef INSTANTIATE_API
#undef DETERMINE_SAME_PADDING
} // namespace dragon
......@@ -5,9 +5,14 @@
namespace dragon {
#define SAME_PADDING(A, B) \
A[i] = padding_needed / 2; \
B[i] = padding_needed - A[i]
#define DETERMINE_SAME_PADDING(l, r) \
if (padding_ != "SAME_UPPER") { \
l[i] = pad_size / 2; \
r[i] = pad_size - l[i]; \
} else { \
r[i] = pad_size / 2; \
l[i] = pad_size - r[i]; \
}
template <class Context>
void PoolOpBase<Context>::Setup(int num_axes) {
......@@ -52,41 +57,27 @@ void PoolOpBase<Context>::ComputeOutShape() {
kshape_[i] = in_dims_[i + 2];
}
// Adjust the pads for SAME padding algorithm
if (str::find(padding_, "SAME")) {
for (int i = 0; i < num_axes_; i++) {
auto idm = in_dims_[i + 2];
int64_t odm = (idm + stride_[i] - 1) / (float)stride_[i];
auto padding_needed =
std::max((int64_t)0, (odm - 1) * stride_[i] + kshape_[i] - idm);
if (padding_ == "SAME_UPPER") {
SAME_PADDING(pad_l_, pad_r_);
} else {
SAME_PADDING(pad_r_, pad_l_);
} /*! SAME_LOWER or SAME */
}
}
// Compute the output dimensions
auto floor_or_ceil = ceil_mode_ > 0
? static_cast<float (*)(float)>(&std::ceil)
: static_cast<float (*)(float)>(&std::floor);
out_dims_ = in_dims_;
out_shape_ = Input(0).dims();
int64_t in_size, k_size, pad_size;
for (int i = 0; i < num_axes_; i++) {
auto in_dim = in_dims_[i + 2];
if (!str::find(padding_, "SAME")) {
// Explicit pads
in_dim += pad_l_[i] + pad_r_[i];
out_shape_[i + axis_] = out_dims_[i + 2] =
floor_or_ceil((in_dim - kshape_[i]) / (float)stride_[i]) + 1;
} else {
// Auto pads
out_shape_[i + axis_] = out_dims_[i + 2] =
floor_or_ceil(in_dim / (float)stride_[i]);
float out_size;
in_size = in_dims_[i + 2], k_size = kshape_[i];
if (!str::find(padding_, "SAME")) { // Explicit pads
pad_size = pad_l_[i] + pad_r_[i];
out_size = float(in_size + pad_size - k_size) / float(stride_[i]) + 1.f;
out_size = floor_or_ceil(out_size);
} else { // Auto pads
out_size = std::ceil(float(in_size) / float(stride_[i]));
pad_size = ((int64_t)out_size - 1) * stride_[i] + k_size - in_size;
pad_size = std::max(pad_size, int64_t(0));
DETERMINE_SAME_PADDING(pad_l_, pad_r_);
}
out_shape_[i + axis_] = out_dims_[i + 2] = out_size;
}
}
......@@ -95,6 +86,6 @@ template class PoolOpBase<CPUContext>;
template class PoolOpBase<CUDAContext>;
#endif
#undef SAME_PADDING
#undef DETERMINE_SAME_PADDING
} // namespace dragon
syntax = "proto2";
package dragon;
message BlobShape {
repeated int64 dim = 1 [packed = true];
}
message BlobProto {
optional BlobShape shape = 7;
repeated float data = 5 [packed = true];
optional int32 num = 1 [default = 0];
optional int32 channels = 2 [default = 0];
optional int32 height = 3 [default = 0];
optional int32 width = 4 [default = 0];
}
message NetParameter {
optional string name = 1;
repeated LayerParameter layer = 100;
}
message LayerParameter {
optional string name = 1;
repeated BlobProto blobs = 7;
}
......@@ -51,26 +51,6 @@ message TensorProto {
optional string name = 7;
}
// Record the filler of Tensor.
// This structure is kept for backward compatibility
// with caffe1, which relies implicit initializer.
message TensorFillerProto {
optional string tensor = 1;
optional string type = 2 [default = 'constant'];
optional float value = 3 [default = 0];
optional float low = 4 [default = 0];
optional float high = 5 [default = 1];
optional float mean = 6 [default = 0];
optional float std = 7 [default = 1];
optional float scale = 8 [default = 3];
enum VarianceNorm {
FAN_IN = 0;
FAN_OUT = 1;
FAN_AVG = 2;
}
optional VarianceNorm variance_norm = 9 [default = FAN_IN];
}
// Store multiple TensorProto objects in one single proto.
message TensorProtos {
repeated TensorProto protos = 1;
......@@ -139,16 +119,6 @@ message OperatorDef {
optional string cache_key = 7;
}
// Record the gradient information
message GradientProto {
// The derivative target.
optional string cost = 1;
// The target with respect to?
optional string wrt = 2;
// The external gradient
optional string external = 3;
}
// Graph Definition
message GraphDef {
// The graph name.
......@@ -171,6 +141,33 @@ message GraphDef {
// The name of outputs.
repeated string output = 8;
// The gradients information.
repeated GradientProto gradient = 9;
// The info of gradients.
repeated GradientInfo grad_info = 9;
}
// Record the filler information.
// This structure is kept for backward compatibility
// with caffe, which relies the implicit initializer.
message FillerInfo {
enum VarianceNorm {
FAN_IN = 0;
FAN_OUT = 1;
FAN_AVG = 2;
}
optional string type = 1 [default = 'constant'];
optional float value = 2 [default = 0];
optional float low = 3 [default = 0];
optional float high = 4 [default = 1];
optional float mean = 5 [default = 0];
optional float std = 6 [default = 1];
optional float scale = 7 [default = 3];
optional VarianceNorm variance_norm = 8 [default = FAN_IN];
}
// Record the gradient information.
message GradientInfo {
// The derivative target.
optional string y = 1;
// The differentiated inputs.
repeated string xs = 2;
}
......@@ -30,7 +30,6 @@ from dragon._api import metrics
from dragon._api import nn
from dragon._api import optimizers
from dragon._api import random
from dragon._api import workspace
from dragon._api import vision
# Virtual API
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.core.training.adam import Adam
from dragon.core.training.rmsprop import RMSProp
from dragon.core.training.sgd import Nesterov
from dragon.core.training.sgd import SGD
from dragon.core.training.updater import Updater
__all__ = [_s for _s in dir() if not _s.startswith('_')]
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.core.framework.workspace import feed_tensor
from dragon.core.framework.workspace import fetch_tensor
from dragon.core.framework.workspace import has_tensor
from dragon.core.framework.workspace import load
from dragon.core.framework.workspace import reset_tensor
from dragon.core.framework.workspace import run_operator
from dragon.core.framework.workspace import save
__all__ = [_s for _s in dir() if not _s.startswith('_')]
......@@ -28,6 +28,7 @@ from dragon.core.autograph.tensor import Tensor
from dragon.core.eager import context as eager_context
from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import context
from dragon.core.framework import device_spec
from dragon.core.framework import workspace
from dragon.core.training import optimizer
from dragon.core.util import decorator
......@@ -276,13 +277,16 @@ class FunctionGuard(object):
executables = self.executables
inputs, kwargs = self.canonicalize_inputs(*args, **kwargs)
executables[0](*inputs, return_outputs=False, **kwargs)
_ = [func(return_outputs=False) for func in executables[1:]]
[func(return_outputs=False) for func in executables[1:]]
outputs = []
for obj in self.outputs:
if isinstance(obj, Tensor):
outputs.append(EagerTensor(id=obj.id, own_storage=False))
current_ws = workspace.get_workspace()
for output in self.outputs:
if isinstance(output, Tensor):
impl = current_ws.GetTensor(output.id)
device = device_spec.DeviceSpec(*impl.device)
outputs.append(EagerTensor(impl=impl, device=device))
else:
outputs.append(obj)
outputs.append(output)
return outputs
def __get__(self, instance, owner):
......
......@@ -23,7 +23,6 @@ from dragon.core.autograph.op_def import OpDef
from dragon.core.autograph.op_def import OpInfo
from dragon.core.autograph.tensor import Tensor
from dragon.core.framework import config
from dragon.core.framework import context
from dragon.core.framework import proto_util
from dragon.core.framework import workspace
from dragon.core.proto import dragon_pb2
......@@ -32,7 +31,7 @@ from dragon.core.util import nest
def add_device_option(graph_def):
"""Add the device option for graph."""
"""Add the device option."""
cfg = config.config()
str2idx = {'cpu': 0, 'cuda': 1, 'cnml': 2}
dev_opt = dragon_pb2.DeviceOption()
......@@ -42,69 +41,66 @@ def add_device_option(graph_def):
graph_def.device_option.CopyFrom(dev_opt)
def add_gradient_info(graph_def, targets):
"""Add the gradient info for graph."""
gradients = set()
def add_grad_info(graph_def, targets):
"""Add the gradient info."""
for target in targets:
if target._grad is not None:
gradients.update(target._grad.make_pairs())
for (cost, wrt) in gradients:
gradient = dragon_pb2.GradientProto()
gradient.cost, gradient.wrt = str(cost), str(wrt)
graph_def.gradient.extend([gradient])
info = target._grad
if info is not None:
graph_def.grad_info.extend([
dragon_pb2.GradientInfo(
y=info.y.id,
xs=[x.id for x in info.xs])])
def add_optimization(graph_def, level=None):
"""Add the optimization attribute for graph."""
"""Add the optimization argument."""
cfg = config.config()
if level is None:
level = cfg.graph_optimization
graph_def.arg.add().CopyFrom(
proto_util.make_argument(
'optimization_level', level))
proto_util.make_argument('optimization', level))
graph_def.graph_type = cfg.graph_type
def add_phase(graph_def, targets):
"""Add the phase attribute for graph."""
phase = context.get_graph_phase()
if phase is None:
phase = 'TEST'
for target in targets:
if target._grad is not None and \
target._grad.required():
"""Add the phase argument."""
phase = 'TEST'
for target in targets:
try:
if target._grad and target._grad.required():
phase = 'TRAIN'
break
except AttributeError:
pass
graph_def.arg.extend([proto_util.make_argument('phase', phase)])
def add_update_ops(graph_def, optimizer):
"""Add the update operators for graph."""
def add_update_defs(graph_def, optimizer):
"""Add the update defs."""
if optimizer is None:
return
grads, update_ops = [], []
grads, update_defs = [], []
extra_arguments = optimizer._extra_kwargs
extra_arguments['handle'] = optimizer._op_handle
# Generate update operators according to the updater.
for e in optimizer._param_group:
(param, grad), arguments = e
if workspace.has_tensor(grad):
# Generate op defs according to the collected updates
current_ws = workspace.get_workspace()
for (param, grad), arguments in optimizer._param_group:
if current_ws.has_tensor(grad):
grads.append(grad)
arguments = dict(arguments, **extra_arguments)
update_ops.append(
update_defs.append(
proto_util.make_operator_def(
op_type=optimizer._op_type,
inputs=[grad],
outputs=[param],
name=OpDef.get_name(),
**arguments
))
**arguments))
else:
logging.info('Skip to update Tensor({}).'.format(param))
# Insert a reduce op if the process group is found.
# Insert a reduce def if the process group is found.
process_group = optimizer._process_group
if process_group is not None:
update_ops.insert(
update_defs.insert(
0, proto_util.make_operator_def(
op_type='Collective',
inputs=grads,
......@@ -115,7 +111,7 @@ def add_update_ops(graph_def, optimizer):
**process_group.arguments
)
)
graph_def.op.extend(update_ops)
graph_def.op.extend(update_defs)
class Function(object):
......@@ -128,16 +124,15 @@ class Function(object):
self.graph_name = None # Determined after creating
self.inputs, self.outputs = None, None
def create(self, inputs=None, outputs=None, givens=None, updater=None):
def create(self, inputs=None, outputs=None, givens=None, optimizer=None):
self.inputs = inputs = [] if inputs is None else nest.flatten(inputs)
self.outputs = outputs = [] if outputs is None else nest.flatten(outputs)
if len(outputs) > 0 and updater is not None:
raise ValueError('Specific either <outputs> or <updater>, not both.')
if len(outputs) > 0 and optimizer is not None:
raise ValueError('Specific either <outputs> or <optimizer>, not both.')
# Collect the forward defs.
op_info = OpInfo()
# Collect the forward operators.
requires_grad = False
for i, output in enumerate(outputs):
op_info.merge_from(output)
......@@ -149,7 +144,7 @@ class Function(object):
except AttributeError:
raise ValueError('Output[%d] is not a symbolic tensor.' % i)
# Handle givens.
# Handle the replacements.
if givens is not None:
name_dict = {}
for k, v in givens.items():
......@@ -161,62 +156,61 @@ class Function(object):
'Excepted a Tensor, '
'got {}.'.format(type(v).__name__)
)
# Update original operators.
# Update the original defs.
op_info = copy.deepcopy(op_info)
for k in op_info._defs.keys():
op_def = op_info._defs[k]
op_def.input.extend([
name_dict[input]
if input in name_dict else input
for input in op_def.input
])
for input in op_def.input])
del op_def.input[:len(op_def.input) // 2]
# Sort out the states.
# Sort out the forward defs.
op_defs = sorted(op_info._defs.items(), key=lambda d: d[0])
forward_ops = copy.deepcopy([v for k, v in op_defs])
forward_defs = copy.deepcopy([v for k, v in op_defs])
# Generate the backward operators.
# Generate the backward defs.
if requires_grad:
input_grads, grad_targets = {}, []
for output in outputs:
grad_info = output._grad
if grad_info is not None:
if grad_info.input is not None:
input_grads[output.id] = output._grad.input.id
info = output._grad
if info is not None:
if info.grad_y is not None:
input_grads[output.id] = info.grad_y.id
grad_targets.append(output.id)
forward_ops, gradient_ops, _ = \
grad_maker.GradientMaker.make(
forward_ops=forward_ops,
targets=grad_targets,
input_grads=input_grads,
)
backward_defs = grad_maker.GradientMaker.make(
op_defs=forward_defs,
targets=grad_targets,
input_grads=input_grads,
)
else:
gradient_ops = []
backward_defs = []
# Fill with all known graph elements.
self.graph_def.op.extend(forward_ops + gradient_ops)
# Fill graph elements.
self.graph_def.op.extend(forward_defs + backward_defs)
self.graph_def.input.extend([input.name for input in inputs])
self.graph_def.output.extend(list(op_info._targets))
if len(outputs) > 0:
add_device_option(self.graph_def)
add_optimization(self.graph_def)
add_gradient_info(self.graph_def, outputs)
add_grad_info(self.graph_def, outputs)
add_phase(self.graph_def, outputs)
elif updater is not None:
elif optimizer is not None:
add_device_option(self.graph_def)
add_optimization(self.graph_def, level=0)
add_update_ops(self.graph_def, updater)
add_update_defs(self.graph_def, optimizer)
# Notify the backend to create and optimize.
self.graph_name = workspace.create_graph(self.graph_def)
current_ws = workspace.get_workspace()
self.graph_name = current_ws.create_graph(self.graph_def)
# Bind a callback to run this graph.
self.callback = lambda *args, **kwargs: \
workspace.run_graph(
graph=self.graph_name,
inputs=(inputs, args),
current_ws.run_graph(
name=self.graph_name,
inputs_and_values=(inputs, args),
outputs=outputs,
**kwargs
)
......@@ -273,15 +267,15 @@ class Function(object):
add_phase(graph_def, self.outputs)
# Notify the backend to create and optimize.
current_ws = workspace.get_workspace()
self.graph_def = graph_def
self.graph_name = workspace.create_graph(graph_def)
self.graph_name = current_ws.create_graph(graph_def)
# Bind a callback to run this graph.
callback_inputs = self.inputs if explicit_inputs else []
self.callback = lambda *args, **kwargs: \
workspace.run_graph(
graph=self.graph_name,
inputs=(callback_inputs, args),
current_ws.run_graph(
name=self.graph_name,
inputs_and_values=(self.inputs if explicit_inputs else [], args),
outputs=self.outputs,
**kwargs
)
......
......@@ -21,37 +21,26 @@ from dragon.core.util import nest
class GradientInfo(object):
"""A class to store the known gradient relations."""
def __init__(self, parent):
self._parent = parent
self._cost, self._wrt = [], []
self._input = None
def __init__(self, y, grad_y=None):
self._y, self._grad_y, self._xs = y, grad_y, []
@property
def cost(self):
return self._cost
def grad_y(self):
return self._grad_y
@property
def input(self):
return self._input
def xs(self):
return self._xs
@property
def wrt(self):
return self._wrt
def y(self):
return self._y
def add_cost(self, cost):
self._cost.append(cost)
def add_wrt(self, wrt):
self._wrt.append(wrt)
def make_pairs(self):
return [(self._parent.id, wrt) for wrt in self._wrt]
def add_x(self, x):
self._xs.append(x)
def required(self):
return len(self._wrt) > 0
def set_input(self, input):
self._input = input
return len(self._xs) > 0
def gradients(ys, xs, grad_ys=None):
......@@ -112,18 +101,14 @@ def gradients(ys, xs, grad_ys=None):
if grad_ys is not None:
grad_ys = nest.flatten(grad_ys)
# Record the gradient info (cost, wrt, input),
# Record the gradient info (y, grad_y, xs),
# then, generate the gradient references once.
for i, y in enumerate(ys):
if y._grad is None:
y._grad = GradientInfo(y)
if grad_ys is not None:
y._grad.set_input(grad_ys[i])
grad_y = grad_ys[i] if grad_ys is not None else None
y._grad = GradientInfo(y, grad_y)
for x in xs:
if not hasattr(x, '_grad') or x._grad is None:
x._grad = GradientInfo(x)
y._grad.add_wrt(x.id)
x._grad.add_cost(y)
y._grad.add_x(x)
if i == 0:
dxs.append(TensorRef(x.id + '_grad', x.shape, x.dtype))
......
......@@ -13,16 +13,7 @@
#
# ------------------------------------------------------------
"""Gradient maker implemented in python.
The basic idea of ``GradientMaker`` comes from ``caffe2``,
Jia provided a simple way to bridge the Generator(Python) with OpScheme(C++).
For the efficient C++ implementation, see,
<https://github.com/seetaresearch/Dragon/blob/master/Dragon/src/core/graph_gradient.cc>
"""
"""Python-implemented gradient maker."""
from __future__ import absolute_import
from __future__ import division
......@@ -40,25 +31,25 @@ class GradientMaker(object):
"""Make def for the gradient based on rules."""
@classmethod
def gen_def(cls, forward_op, g_outputs):
def gen_def(cls, op_def, g_outputs):
"""Generate the OperatorDef from forward op."""
g_ops, g_inputs, defaults = backend.CreateGradientDefs(
forward_op.SerializeToString(), g_outputs)
for idx, g_op in enumerate(g_ops):
grad_defs, g_inputs, defaults = backend.CreateGradientDefs(
op_def.SerializeToString(), g_outputs)
for i, grad_def in enumerate(grad_defs):
new_def = dragon_pb2.OperatorDef()
new_def.ParseFromString(g_op)
g_ops[idx] = new_def
return g_ops, g_inputs, defaults
new_def.ParseFromString(grad_def)
grad_defs[i] = new_def
return grad_defs, g_inputs, defaults
@classmethod
def check(cls, forward_op, inputs_to_grads, blacklist, targets):
def check(cls, op_def, inputs_to_grads, blacklist, targets):
"""Check if missing gradients. If missing, skip."""
if forward_op.type in backend.NO_GRADIENT_OPERATORS:
for input in forward_op.input:
if op_def.type in backend.NO_GRADIENT_OPERATORS:
for input in op_def.input:
blacklist.add(input)
return True, None
gen_grads = []
for idx, output in enumerate(forward_op.output):
for idx, output in enumerate(op_def.output):
if output not in inputs_to_grads:
if output in blacklist:
return True, gen_grads
......@@ -66,50 +57,43 @@ class GradientMaker(object):
# Consider to generate virtual gradient for targets.
gen_grads.append((output, idx))
inputs_to_grads[output] = output + '_grad'
elif len(forward_op.output) == 1:
elif len(op_def.output) == 1:
# We can skip this op, obviously.
return True, gen_grads
# Pass, even if missing some grads.
return False, gen_grads
@classmethod
def make(cls, forward_ops, targets, input_grads=None):
"""The making procedure."""
def make(cls, op_defs, targets, input_grads=None):
"""Make the backward op defs."""
inputs_to_grads = {} if input_grads is None else input_grads
inputs_count, grads_count = defaultdict(int), defaultdict(int)
all_split_grads, blacklist = set(), set()
backward_ops = []
# A DAG may not have any in-place operators.
is_dag = True
# PLAY for the forward.
for forward_op in forward_ops:
if forward_op.type in backend.NO_GRADIENT_OPERATORS:
for op_def in op_defs:
if op_def.type in backend.NO_GRADIENT_OPERATORS:
continue
outputs = [o for o in forward_op.output]
for input in forward_op.input:
outputs = [output for output in op_def.output]
for input in op_def.input:
if input not in outputs:
# Avoid to count the duplicate input,
# (i.e. the in-place output).
inputs_count[input] += 1
else:
is_dag = False
# PLAY for the backward.
for forward_op in forward_ops[::-1]:
backward_defs = []
for op_def in op_defs[::-1]:
# Collect inputs and outputs.
is_skip, gen_grads = cls.check(
forward_op=forward_op,
op_def=op_def,
inputs_to_grads=inputs_to_grads,
blacklist=blacklist,
targets=targets,
)
# Missing grads are represented as ``None``.
g_outputs = [inputs_to_grads.get(name, '')
for name in forward_op.output]
g_ops, g_inputs, defaults = cls.gen_def(forward_op, g_outputs)
g_outputs = [inputs_to_grads.get(name, '') for name in op_def.output]
grad_defs, g_inputs, defaults = cls.gen_def(op_def, g_outputs)
# Append operators.
if not is_skip:
......@@ -127,17 +111,17 @@ class GradientMaker(object):
outputs=op_outputs,
defaults=values,
)
if forward_op.HasField('device_option'):
gen_op.device_option.CopyFrom(forward_op.device_option)
backward_ops.append(gen_op)
if op_def.HasField('device_option'):
gen_op.device_option.CopyFrom(op_def.device_option)
backward_defs.append(gen_op)
# GradientOp
for g_op in g_ops:
g_op.name = OpDef.get_name()
backward_ops.append(g_op)
for grad_def in grad_defs:
grad_def.name = OpDef.get_name()
backward_defs.append(grad_def)
# Split and gather grads for multi-used input.
for g_op in g_ops:
for g_output_idx, g_output in enumerate(g_op.output):
for grad_def in grad_defs:
for g_output_idx, g_output in enumerate(grad_def.output):
original_idx = -1
for g_input_idx, g_input in enumerate(g_inputs):
if g_output == g_input:
......@@ -145,10 +129,10 @@ class GradientMaker(object):
# Ignore un-used && in-placed GI(?).
if original_idx == -1:
continue
if g_output in g_op.input:
if g_output in grad_def.input:
continue
# Found a split branch.
original_name = forward_op.input[original_idx]
original_name = op_def.input[original_idx]
if inputs_count[original_name] > 1:
# Split.
split_name = g_output + '_autosplit_%d' % grads_count[g_output]
......@@ -161,21 +145,21 @@ class GradientMaker(object):
for idx in range(grads_count[g_output]):
if '%s_autosplit_%d' % (g_output, idx) in all_split_grads:
split_inputs.append('%s_autosplit_%d' % (g_output, idx))
gather_op = proto_util.make_operator_def(
gather_def = proto_util.make_operator_def(
name=OpDef.get_name(),
op_type='GradientGather',
inputs=split_inputs,
outputs=[g_output],
)
if g_op.HasField('device_option'):
gather_op.device_option.CopyFrom(g_op.device_option)
backward_ops.append(gather_op)
g_op.output[g_output_idx] = split_name
if grad_def.HasField('device_option'):
gather_def.device_option.CopyFrom(grad_def.device_option)
backward_defs.append(gather_def)
grad_def.output[g_output_idx] = split_name
# Done.
if not is_skip:
for name, grad in zip(forward_op.input, g_inputs):
for name, grad in zip(op_def.input, g_inputs):
if grad != '':
inputs_to_grads[name] = grad
return forward_ops, backward_ops, is_dag
return backward_defs
......@@ -30,9 +30,9 @@ class OpInfo(object):
self._defs = dict()
self._targets = set()
def add_def(self, idx, op_def):
def add_def(self, index, op_def):
"""Add a operator definition."""
self._defs[idx] = op_def
self._defs[index] = op_def
def add_target(self, target):
"""Add an extra target relied by inputs."""
......@@ -74,13 +74,14 @@ class OpDef(object):
# Create outputs.
if outputs is None:
outputs = []
current_ws = workspace.get_workspace()
name_scope = context.get_name_scope()
for i in range(num_outputs):
outputs.append(TensorRef(
workspace.get_dummy_name(
current_ws.unique_name(
name_scope + (name if name else op_type),
suffix=':{}'.format(i),
domain='Tensor')))
namespace='Tensor')))
else:
outputs = nest.flatten(outputs)
num_outputs = len(outputs)
......@@ -124,13 +125,13 @@ class OpDef(object):
return spec_func(arguments, inputs, outputs)
@staticmethod
def get_index_and_name(prefix='Op'):
def get_index_and_name():
"""Return an unique op name and index."""
name = workspace.get_dummy_name(
prefix, domain='Operator', zero_based=False)
name = workspace.get_workspace().unique_name(
'Op', namespace='Op', zero_based=False)
return int(name.split('_')[-1]), name
@staticmethod
def get_name(prefix='Op'):
def get_name():
"""Return an unique op name."""
return OpDef.get_index_and_name(prefix)[1]
return OpDef.get_index_and_name()[1]
......@@ -190,24 +190,28 @@ def conv_spec(args, inputs, outputs):
out_shape = None
try:
out_shape = inputs[0].shape[:]
num_axes = len(out_shape) - 2
channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
if 'out_channels' in args:
out_shape[channel_axis] = args['out_channels']
else:
out_shape[channel_axis] = inputs[1].shape[0]
for i in range(len(out_shape) - 2):
input_size = out_shape[i + spatial_axis]
k = args['kernel_shape'][i]
s = args['strides'][i]
pl, pr = args['pads'][i], args['pads'][i + 2]
dk, dp = (k - 1) + 1, pl + pr
if 'SAME' not in args['padding']:
out_shape[i + spatial_axis] = \
int(float(input_size + dp - dk) / s) + 1
else:
out_shape[i + spatial_axis] = \
int(float(input_size + s - 1) / s)
for i in range(num_axes):
try:
k = args['kernel_shape'][i]
s = args['strides'][i]
d = args['dilations'][i]
in_size = out_shape[i + spatial_axis]
k_size = d * (k - 1) + 1
if 'SAME' not in args['padding']:
pad_size = args['pads'][i] + args['pads'][i + num_axes]
out_size = (in_size + pad_size - k_size) // s + 1
else:
out_size = (in_size + s - 1) // s
except IndexError:
out_size = None
out_shape[i + spatial_axis] = out_size
except (TypeError, IndexError):
pass
outputs[0].shape = out_shape
......@@ -220,30 +224,33 @@ def conv_transpose_spec(args, inputs, outputs):
out_shape = None
try:
out_shape = inputs[0].shape[:]
num_axes = len(out_shape) - 2
channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
if 'out_channels' in args:
out_shape[channel_axis] = args['out_channels']
else:
out_shape[channel_axis] = inputs[1].shape[1]
for i in range(len(out_shape) - 2):
k = args['kernel_shape'][i]
s = args['strides'][i]
d = args['dilations'][i]
pl, pr = args['pads'][i], args['pads'][i + 2]
dk, dp = d * (k - 1) + 1, pl + pr
input_size = out_shape[i + spatial_axis]
if 'SAME' not in args['padding']:
out_shape[i + spatial_axis] = s * \
(input_size - 1) + dk - dp
else:
out_shape[i + spatial_axis] = None
if args['output_padding'] is not None:
out_shape[i + spatial_axis] = \
s * (input_size - 1) + dk + \
args['output_padding'][i]
elif args['output_shape'] is not None:
out_shape[i + spatial_axis] = args['output_shape'][i]
for i in range(num_axes):
try:
k = args['kernel_shape'][i]
s = args['strides'][i]
d = args['dilations'][i]
in_size = out_shape[i + spatial_axis]
k_size = d * (k - 1) + 1
if 'SAME' not in args['padding']:
pad_size = args['pads'][i] + args['pads'][i + num_axes]
out_size = s * (in_size - 1) + k_size - pad_size
if 'output_padding' in args and args['output_padding']:
out_size += args['output_padding'][i]
else:
if 'output_shape' in args and args['output_shape']:
out_size = args['output_shape'][i]
else:
out_size = None
except IndexError:
out_size = None
out_shape[i + spatial_axis] = out_size
except (TypeError, IndexError):
pass
outputs[0].shape = out_shape
......@@ -606,21 +613,24 @@ def pool_spec(args, inputs, outputs):
out_shape = None
try:
out_shape = inputs[0].shape[:]
num_axes = len(out_shape) - 2
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
for i in range(len(out_shape) - 2):
k = args['kernel_shape'][i]
s = args['strides'][i]
pl, pr = args['pads'][i], args['pads'][i + 2]
for i in range(num_axes):
if not args['global_pooling']:
floor_or_ceil = math.ceil if args['ceil_mode'] else math.floor
if 'SAME' not in args['padding']:
in_size = out_shape[i + spatial_axis] + pl + pr
out_size = int(floor_or_ceil(float(in_size - k) / s) + 1)
out_shape[i + spatial_axis] = out_size
else:
try:
k = args['kernel_shape'][i]
s = args['strides'][i]
in_size = out_shape[i + spatial_axis]
out_size = int(floor_or_ceil(float(in_size) / s))
out_shape[i + spatial_axis] = out_size
if 'SAME' not in args['padding']:
floor_or_ceil = math.ceil if args['ceil_mode'] else math.floor
pad_size = args['pads'][i] + args['pads'][i + num_axes]
out_size = float(in_size + pad_size - k) / float(s) + 1
out_size = floor_or_ceil(out_size)
else:
out_size = math.ceil(float(in_size) / float(s))
except IndexError:
out_size = None
out_shape[i + spatial_axis] = out_size
else:
out_shape[i + spatial_axis] = 1
except (TypeError, IndexError):
......@@ -959,14 +969,14 @@ def stack_spec(args, inputs, outputs):
@register('Tile')
def tile_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype
multiples = args['multiples']
if multiples is not None:
repeats = args['repeats']
if repeats is not None:
try:
out_shape = inputs[0].shape[:]
for i, multiple in enumerate(multiples):
for i, size in enumerate(repeats):
if i < len(out_shape):
try:
out_shape[i] *= multiple
out_shape[i] *= size
except TypeError:
out_shape[i] = None
outputs[0].shape = out_shape
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!