Commit adb6fa64 by Ting PAN

Add native ops test

Summary:
This commit tests the execution of native ops and verifies the results.
Several bugs were found and fixed based on these tests.
1 parent df172cc8
Showing with 971 additions and 1258 deletions
...@@ -9,13 +9,13 @@ ...@@ -9,13 +9,13 @@
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Implementation for the ``Layer`` C++ class.""" """The base layer class."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.autograph.tensor import RefTensor from dragon.core.autograph.tensor import TensorRef
from dragon.core.eager import context as eager_context from dragon.core.eager import context as eager_context
from dragon.core.framework import context from dragon.core.framework import context
...@@ -76,8 +76,8 @@ class Layer(object): ...@@ -76,8 +76,8 @@ class Layer(object):
param_name = scoped_name + '/param:{}'.format(len(self._blobs)) param_name = scoped_name + '/param:{}'.format(len(self._blobs))
# Set the name explicitly. # Set the name explicitly.
variable = RefTensor(param_name) variable = TensorRef(param_name)
variable_grad = RefTensor(param_name + '_grad') variable_grad = TensorRef(param_name + '_grad')
if filler is not None: if filler is not None:
variable._register_as(**filler) variable._register_as(**filler)
......
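The ``RefTensor`` → ``TensorRef`` rename is mechanical; a minimal sketch of the pattern, assuming ``TensorRef(name)`` merely wraps an existing workspace tensor without allocating storage ('conv1' is a hypothetical layer name):

```python
from dragon.core.autograph.tensor import TensorRef

scoped_name = 'conv1'  # hypothetical scoped layer name
param_name = scoped_name + '/param:0'
variable = TensorRef(param_name)                 # alias of the parameter
variable_grad = TensorRef(param_name + '_grad')  # alias of its gradient
```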
...@@ -455,8 +455,8 @@ class InnerProduct(Layer): ...@@ -455,8 +455,8 @@ class InnerProduct(Layer):
param = layer_param.inner_product_param param = layer_param.inner_product_param
self.arguments = { self.arguments = {
'axis': param.axis, 'axis': param.axis,
'num_output': param.num_output, 'out_channels': param.num_output,
'transW': not param.transpose, 'transpose_w': not param.transpose,
} }
# Add weights and biases # Add weights and biases
self.add_blob(filler=self.get_filler(param, 'weight_filler')) self.add_blob(filler=self.get_filler(param, 'weight_filler'))
...@@ -522,7 +522,7 @@ class Normalize(Layer): ...@@ -522,7 +522,7 @@ class Normalize(Layer):
normalize_param { normalize_param {
across_spatial: false across_spatial: false
channel_shared: false channel_shared: false
eps: 1e-5 eps: 1e-12
scale_filler: { scale_filler: {
type: "constant" type: "constant"
value: 1 value: 1
...@@ -548,7 +548,7 @@ class Normalize(Layer): ...@@ -548,7 +548,7 @@ class Normalize(Layer):
self.add_blob(filler=self.get_filler(param, 'scale_filler'), value=1) self.add_blob(filler=self.get_filler(param, 'scale_filler'), value=1)
def __call__(self, bottom): def __call__(self, bottom):
norm_out = [normalization_ops.l2_normalize(bottom, **self.l2norm_arguments)] norm_out = [normalization_ops.lp_normalize(bottom, **self.l2norm_arguments)]
norm_out += [blob['data'] for blob in self._blobs] norm_out += [blob['data'] for blob in self._blobs]
return math_ops.affine(norm_out, **self.affine_arguments) return math_ops.affine(norm_out, **self.affine_arguments)
......
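The ``Normalize`` layer now routes through the generalized ``lp_normalize`` (with p=2) and adopts Caffe's default ``eps`` of 1e-12 instead of 1e-5. A NumPy sketch of the intended math (the exact eps clamping is an assumption):

```python
import numpy as np

def l2_normalize(x, axis=1, eps=1e-12):
    # y = x / max(||x||_2, eps), reduced over the channel axis.
    norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
    return x / np.maximum(norm, eps)

x = np.random.randn(2, 8, 4, 4).astype('float32')
y = l2_normalize(x)  # every (n, :, h, w) fiber now has unit L2 norm
```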
...@@ -65,7 +65,7 @@ class Convolution(Layer): ...@@ -65,7 +65,7 @@ class Convolution(Layer):
super(Convolution, self).__init__(layer_param) super(Convolution, self).__init__(layer_param)
param = layer_param.convolution_param param = layer_param.convolution_param
self.arguments = { self.arguments = {
'num_output': param.num_output, 'out_channels': param.num_output,
'kernel_shape': [int(e) for e in param.kernel_size], 'kernel_shape': [int(e) for e in param.kernel_size],
'strides': [int(e) for e in param.stride] if len(param.stride) > 0 else [1], 'strides': [int(e) for e in param.stride] if len(param.stride) > 0 else [1],
'pads': [int(e) for e in param.pad] if len(param.pad) > 0 else [0], 'pads': [int(e) for e in param.pad] if len(param.pad) > 0 else [0],
...@@ -187,7 +187,7 @@ class DepthwiseConv2d(Layer): ...@@ -187,7 +187,7 @@ class DepthwiseConv2d(Layer):
super(DepthwiseConv2d, self).__init__(layer_param) super(DepthwiseConv2d, self).__init__(layer_param)
param = layer_param.convolution_param param = layer_param.convolution_param
self.arguments = { self.arguments = {
'num_output': param.num_output, 'out_channels': param.num_output,
'kernel_shape': [int(e) for e in param.kernel_size], 'kernel_shape': [int(e) for e in param.kernel_size],
'strides': [int(e) for e in param.stride] if len(param.stride) > 0 else [1], 'strides': [int(e) for e in param.stride] if len(param.stride) > 0 else [1],
'pads': [int(e) for e in param.pad] if len(param.pad) > 0 else [0], 'pads': [int(e) for e in param.pad] if len(param.pad) > 0 else [0],
......
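Both convolution layers get the same keyword rename (``num_output`` → ``out_channels``), keeping the rest of the Caffe parameter translation intact. The mapping, condensed into a sketch:

```python
def make_conv_arguments(param):
    """Translate a caffe_pb2.ConvolutionParameter-like object (sketch)."""
    return {
        'out_channels': param.num_output,  # renamed from 'num_output'
        'kernel_shape': [int(e) for e in param.kernel_size],
        'strides': [int(e) for e in param.stride] or [1],  # default stride 1
        'pads': [int(e) for e in param.pad] or [0],        # default pad 0
    }
```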
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Implementation for the ``Net`` C++ class.""" """The base net class."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -20,8 +20,8 @@ from google.protobuf import text_format ...@@ -20,8 +20,8 @@ from google.protobuf import text_format
from dragon.core.autograph import def_function from dragon.core.autograph import def_function
from dragon.core.autograph import grad_impl from dragon.core.autograph import grad_impl
from dragon.core.autograph.tensor import RefTensor
from dragon.core.autograph.tensor import Tensor from dragon.core.autograph.tensor import Tensor
from dragon.core.autograph.tensor import TensorRef
from dragon.core.framework import workspace from dragon.core.framework import workspace
from dragon.core.util import nest from dragon.core.util import nest
from dragon.vm.caffe import layers as layer_factory from dragon.vm.caffe import layers as layer_factory
...@@ -84,17 +84,13 @@ class Net(object): ...@@ -84,17 +84,13 @@ class Net(object):
if len(self._net_proto.input) > 0: if len(self._net_proto.input) > 0:
shapes = self._net_proto.input_shape shapes = self._net_proto.input_shape
for i, input in enumerate(self._net_proto.input): for i, input_name in enumerate(self._net_proto.input):
shape = [e for e in shapes[i].dim] if i < len(shapes) else None shape = [e for e in shapes[i].dim] if i < len(shapes) else None
if input not in self._blobs: if input_name not in self._blobs:
data = Tensor(input, shape=shape, dtype='float32').placeholder() data = Tensor(input_name, shape, 'float32').placeholder()
self._blobs[input] = { self._blobs[input_name] = {
'data': data, 'data': data,
'diff': RefTensor( 'diff': TensorRef(data.id + '_grad', shape, data.dtype),
data.id + '_grad',
shape=shape,
dtype=data.dtype
),
} }
for layer in self._net_proto.layer: for layer in self._net_proto.layer:
...@@ -145,7 +141,7 @@ class Net(object): ...@@ -145,7 +141,7 @@ class Net(object):
for i, blob in enumerate(layer._top): for i, blob in enumerate(layer._top):
self._blobs[blob] = { self._blobs[blob] = {
'data': outputs[i], 'data': outputs[i],
'diff': RefTensor(outputs[i].id + '_grad'), 'diff': TensorRef(outputs[i].id + '_grad'),
} }
self._net_outputs.add(blob) self._net_outputs.add(blob)
......
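Both hunks above build the same structure: each named blob pairs its forward tensor with a ``TensorRef`` aliasing the gradient tensor. The pattern, condensed (the input name and shape are hypothetical):

```python
from dragon.core.autograph.tensor import Tensor
from dragon.core.autograph.tensor import TensorRef

input_name, shape = 'data', [1, 3, 224, 224]  # hypothetical net input
data = Tensor(input_name, shape, 'float32').placeholder()
blobs = {input_name: {
    'data': data,
    'diff': TensorRef(data.id + '_grad', shape, data.dtype),
}}
```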
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Implementation for the ``Solver`` C++ class.""" """The solver to update parameters."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -19,9 +19,12 @@ import time ...@@ -19,9 +19,12 @@ import time
from google.protobuf import text_format from google.protobuf import text_format
from dragon import updaters
from dragon.core.autograph import def_function from dragon.core.autograph import def_function
from dragon.core.framework import workspace from dragon.core.framework import workspace
from dragon.core.training.adam import Adam
from dragon.core.training.rmsprop import RMSprop
from dragon.core.training.sgd import SGD
from dragon.core.training.sgd import Nesterov
from dragon.vm.caffe.net import Net from dragon.vm.caffe.net import Net
from dragon.vm.caffe.proto import caffe_pb2 from dragon.vm.caffe.proto import caffe_pb2
...@@ -47,10 +50,10 @@ class Solver(object): ...@@ -47,10 +50,10 @@ class Solver(object):
if self._param.iter_size > 1: if self._param.iter_size > 1:
raise NotImplementedError('GradientAccum is deprecated.') raise NotImplementedError('GradientAccum is deprecated.')
self._arguments = { self._arguments = {
'scale_gradient': 1. / self._param.iter_size, 'scale': 1. / self._param.iter_size,
'clip_gradient': float(self._param.clip_gradients), 'clip_norm': float(self._param.clip_gradients),
'l2_decay': float(self._param.weight_decay) 'weight_decay': float(self._param.weight_decay)
if str(self._param.regularization_type) == 'L2' else -1., if str(self._param.regularization_type) == 'L2' else 0,
} }
self._optimizer = None self._optimizer = None
self._net, self._test_nets = None, [] self._net, self._test_nets = None, []
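The update arguments are renamed for the new optimizer API, and the non-L2 default changes from the -1 sentinel to 0 (i.e., simply no decay). The translation, sketched for a ``SolverParameter``-like object:

```python
def make_update_arguments(param):
    return {
        'scale': 1. / param.iter_size,             # was 'scale_gradient'
        'clip_norm': float(param.clip_gradients),  # was 'clip_gradient'
        'weight_decay': float(param.weight_decay)  # was 'l2_decay'
        if str(param.regularization_type) == 'L2' else 0,
    }
```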
...@@ -415,7 +418,7 @@ class AdamSolver(Solver): ...@@ -415,7 +418,7 @@ class AdamSolver(Solver):
self._arguments['beta1'] = self._param.momentum self._arguments['beta1'] = self._param.momentum
self._arguments['beta2'] = self._param.momentum2 self._arguments['beta2'] = self._param.momentum2
self._arguments['eps'] = self._param.delta self._arguments['eps'] = self._param.delta
self._optimizer = updaters.Adam(**self._arguments) self._optimizer = Adam(**self._arguments)
class NesterovSolver(Solver): class NesterovSolver(Solver):
...@@ -447,7 +450,7 @@ class NesterovSolver(Solver): ...@@ -447,7 +450,7 @@ class NesterovSolver(Solver):
super(NesterovSolver, self).__init__(solver_file, is_root) super(NesterovSolver, self).__init__(solver_file, is_root)
self._arguments['base_lr'] = self._param.base_lr self._arguments['base_lr'] = self._param.base_lr
self._arguments['momentum'] = self._param.momentum self._arguments['momentum'] = self._param.momentum
self._optimizer = updaters.Nesterov(**self._arguments) self._optimizer = Nesterov(**self._arguments)
class RMSPropSolver(Solver): class RMSPropSolver(Solver):
...@@ -481,7 +484,7 @@ class RMSPropSolver(Solver): ...@@ -481,7 +484,7 @@ class RMSPropSolver(Solver):
self._arguments['base_lr'] = self._param.base_lr self._arguments['base_lr'] = self._param.base_lr
self._arguments['decay'] = self._param.rms_decay self._arguments['decay'] = self._param.rms_decay
self._arguments['eps'] = self._param.delta self._arguments['eps'] = self._param.delta
self._optimizer = updaters.RMSProp(**self._arguments) self._optimizer = RMSprop(**self._arguments)
class SGDSolver(Solver): class SGDSolver(Solver):
...@@ -513,4 +516,4 @@ class SGDSolver(Solver): ...@@ -513,4 +516,4 @@ class SGDSolver(Solver):
super(SGDSolver, self).__init__(solver_file, is_root) super(SGDSolver, self).__init__(solver_file, is_root)
self._arguments['base_lr'] = self._param.base_lr self._arguments['base_lr'] = self._param.base_lr
self._arguments['momentum'] = self._param.momentum self._arguments['momentum'] = self._param.momentum
self._optimizer = updaters.SGD(**self._arguments) self._optimizer = SGD(**self._arguments)
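Each solver now imports its optimizer class directly instead of going through the removed ``dragon.updaters`` module; for SGD the wiring reduces to:

```python
from dragon.core.training.sgd import SGD

# Hypothetical hyper-parameter values for illustration.
arguments = {'base_lr': 0.01, 'momentum': 0.9}
optimizer = SGD(**arguments)
```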
...@@ -9,9 +9,6 @@ dragon.math ...@@ -9,9 +9,6 @@ dragon.math
`abs(...) <math/abs.html>`_ `abs(...) <math/abs.html>`_
: Compute the absolute value of input. : Compute the absolute value of input.
`accumulate(...) <math/accumulate.html>`_
: Compute the element-wise accumulation from input to output.
`add(...) <math/add.html>`_ `add(...) <math/add.html>`_
: Compute the element-wise addition. : Compute the element-wise addition.
...@@ -24,6 +21,9 @@ dragon.math ...@@ -24,6 +21,9 @@ dragon.math
`argmin(...) <math/argmin.html>`_ `argmin(...) <math/argmin.html>`_
: Compute the indices of minimum elements along the given axis. : Compute the indices of minimum elements along the given axis.
`axpby(...) <math/axpby.html>`_
: Compute the element-wise addition from input to output.
`ceil(...) <math/ceil.html>`_ `ceil(...) <math/ceil.html>`_
: Compute the smallest integer not less than input. : Compute the smallest integer not less than input.
...@@ -96,9 +96,6 @@ dragon.math ...@@ -96,9 +96,6 @@ dragon.math
`moments(...) <math/moments.html>`_ `moments(...) <math/moments.html>`_
: Compute the mean and variance of input along the given axes. : Compute the mean and variance of input along the given axes.
`moving_average(...) <math/moving_average.html>`_
: Compute the moving average of input to output.
`mul(...) <math/mul.html>`_ `mul(...) <math/mul.html>`_
: Compute the element-wise multiplication. : Compute the element-wise multiplication.
...@@ -148,11 +145,11 @@ dragon.math ...@@ -148,11 +145,11 @@ dragon.math
:hidden: :hidden:
math/abs math/abs
math/accumulate
math/add math/add
math/affine math/affine
math/argmax math/argmax
math/argmin math/argmin
math/axpby
math/ceil math/ceil
math/clip math/clip
math/cos math/cos
...@@ -177,7 +174,6 @@ dragon.math ...@@ -177,7 +174,6 @@ dragon.math
math/min math/min
math/minimum math/minimum
math/moments math/moments
math/moving_average
math/mul math/mul
math/negative math/negative
math/not_equal math/not_equal
......
accumulate axpby
========== =====
.. autofunction:: dragon.math.accumulate .. autofunction:: dragon.math.axpby
.. raw:: html .. raw:: html
......
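``accumulate`` (and ``moving_average`` below) are folded into the BLAS-style ``axpby``, i.e. y = alpha * x + beta * y. A NumPy sketch of the semantics (the default alpha/beta values are assumptions):

```python
import numpy as np

def axpby(x, y, alpha=1.0, beta=0.0):
    # beta=1 reproduces the old 'accumulate'; a moving average with
    # momentum m is alpha=1-m, beta=m.
    return alpha * x + beta * y

x, y = np.ones(4), np.full(4, 2.0)
print(axpby(x, y, alpha=0.1, beta=0.9))  # [1.9 1.9 1.9 1.9]
```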
moving_average
==============
.. autofunction:: dragon.math.moving_average
.. raw:: html
<style>
h1:before {
content: "dragon.math.";
color: #103d3e;
}
</style>
dragon.updaters dragon.optimizers
=============== =================
.. only:: html .. only:: html
Classes Classes
------- -------
`class Adam <updaters/Adam.html>`_ `class Adam <optimizers/Adam.html>`_
: The updater which implements Adam algorithm. : The optimizer to apply the Adam algorithm.
`[Kingma & Ba, 2014] <https://arxiv.org/abs/1412.6980>`_. `[Kingma & Ba, 2014] <https://arxiv.org/abs/1412.6980>`_.
`class Nesterov <updaters/Nesterov.html>`_ `class Nesterov <optimizers/Nesterov.html>`_
: The updater which implements NesterovSGD algorithm. : The optimizer to apply the NesterovSGD algorithm.
`[Sutskever et al., 2013] <http://www.cs.toronto.edu/~hinton/absps/momentum.pdf>`_. `[Sutskever et al., 2013] <http://www.cs.toronto.edu/~hinton/absps/momentum.pdf>`_.
`class RMSProp <updaters/RMSProp.html>`_ `class RMSprop <optimizers/RMSprop.html>`_
: The updater which implements RMSprop algorithm. : The optimizer to apply the RMSprop algorithm.
`[Hinton et al., 2013] <http://www.cs.utoronto.ca/~bonner/courses/2016s/csc321/lectures/lec6.pdf>`_. `[Hinton et al., 2013] <http://www.cs.utoronto.ca/~bonner/courses/2016s/csc321/lectures/lec6.pdf>`_.
`class SGD <updaters/SGD.html>`_ `class SGD <optimizers/SGD.html>`_
: The updater which implements MomentumSGD algorithm. : The optimizer to apply the MomentumSGD algorithm.
`[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_. `[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_.
.. toctree:: .. toctree::
:hidden: :hidden:
updaters/Adam optimizers/Adam
updaters/Nesterov optimizers/Nesterov
updaters/RMSProp optimizers/Optimizer
updaters/SGD optimizers/RMSprop
optimizers/SGD
.. raw:: html .. raw:: html
......
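A usage sketch of the renamed namespace; the ``(grad, var)`` pair format for ``apply_gradients`` is an assumption (TensorFlow-style) based on the method pages below:

```python
import dragon

# Hypothetical tensors; only the class and method names come from the docs.
var = dragon.Tensor('w', shape=[3], dtype='float32')
grad = dragon.Tensor('w_grad', shape=[3], dtype='float32')

optimizer = dragon.optimizers.SGD(base_lr=0.01, momentum=0.9)
optimizer.apply_gradients([(grad, var)])  # assumed pair format
```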
Adam Adam
==== ====
.. autoclass:: dragon.updaters.Adam .. autoclass:: dragon.optimizers.Adam
__init__ __init__
-------- --------
.. automethod:: dragon.updaters.Adam.__init__ .. automethod:: dragon.optimizers.Adam.__init__
Methods Methods
------- -------
apply_gradients apply_gradients
################ ################
.. automethod:: dragon.updaters.Updater.apply_gradients .. automethod:: dragon.optimizers.Optimizer.apply_gradients
:noindex: :noindex:
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.updaters."; content: "dragon.optimizers.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
Nesterov Nesterov
======== ========
.. autoclass:: dragon.updaters.Nesterov .. autoclass:: dragon.optimizers.Nesterov
__init__ __init__
-------- --------
.. automethod:: dragon.updaters.Nesterov.__init__ .. automethod:: dragon.optimizers.Nesterov.__init__
Methods Methods
------- -------
apply_gradients apply_gradients
################ ################
.. automethod:: dragon.updaters.Updater.apply_gradients .. automethod:: dragon.optimizers.Optimizer.apply_gradients
:noindex: :noindex:
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.updaters."; content: "dragon.optimizers.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
Optimizer
=========
.. autoclass:: dragon.optimizers.Optimizer
__init__
--------
.. automethod:: dragon.optimizers.Optimizer.__init__
Methods
-------
apply_gradients
################
.. automethod:: dragon.optimizers.Optimizer.apply_gradients
.. raw:: html
<style>
h1:before {
content: "dragon.optimizers.";
color: #103d3e;
}
</style>
RMSProp RMSprop
======= =======
.. autoclass:: dragon.updaters.RMSProp .. autoclass:: dragon.optimizers.RMSprop
__init__ __init__
-------- --------
.. automethod:: dragon.updaters.RMSProp.__init__ .. automethod:: dragon.optimizers.RMSprop.__init__
Methods Methods
------- -------
apply_gradients apply_gradients
################ ################
.. automethod:: dragon.updaters.Updater.apply_gradients .. automethod:: dragon.optimizers.Optimizer.apply_gradients
:noindex: :noindex:
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.updaters."; content: "dragon.optimizers.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
SGD SGD
=== ===
.. autoclass:: dragon.updaters.SGD .. autoclass:: dragon.optimizers.SGD
__init__ __init__
-------- --------
.. automethod:: dragon.updaters.SGD.__init__ .. automethod:: dragon.optimizers.SGD.__init__
Methods Methods
------- -------
apply_gradients apply_gradients
################ ################
.. automethod:: dragon.updaters.Updater.apply_gradients .. automethod:: dragon.optimizers.Optimizer.apply_gradients
:noindex: :noindex:
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.updaters."; content: "dragon.optimizers.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -14,15 +14,15 @@ For using it, import as follows: ...@@ -14,15 +14,15 @@ For using it, import as follows:
However, it will not help you much because you do not want to learn it. However, it will not help you much because you do not want to learn it.
We have extended it with following programming styles: To resolve this matter, we provide several programming styles to choose from:
Dragon Dragon
###### ######
*Dragon* takes a very light-weight programming style. *Dragon* is the original light-weight yet professional style.
Our goal is to reduce unnecessary structures or interfaces. Therefore, Native interfaces are encouraged to manipulate the backend engine
in addition to feed or fetch, the last thing is designing a function. directly, performing computation flexibly with data feeding or fetching.
This style involves the following components: This style involves the following components:
...@@ -38,15 +38,15 @@ Dragon ...@@ -38,15 +38,15 @@ Dragon
* `dragon.math <dragon/math.html>`_ * `dragon.math <dragon/math.html>`_
* `dragon.metrics <dragon/metrics.html>`_ * `dragon.metrics <dragon/metrics.html>`_
* `dragon.nn <dragon/nn.html>`_ * `dragon.nn <dragon/nn.html>`_
* `dragon.optimizers <dragon/optimizers.html>`_
* `dragon.random <dragon/random.html>`_ * `dragon.random <dragon/random.html>`_
* `dragon.updaters <dragon/updaters.html>`_
* `dragon.vision <dragon/vision.html>`_
* `dragon.workspace <dragon/workspace.html>`_ * `dragon.workspace <dragon/workspace.html>`_
* `dragon.vision <dragon/vision.html>`_
Caffe Caffe
##### #####
*Caffe* is one of the most famous deep learning framework for Computer Vision. *Caffe* is one of the most famous frameworks for vision.
Our work is very different from the official python wrappers, a.k.a, Our work is very different from the official python wrappers, a.k.a,
the *PyCaffe*, which comes from the exports of *BoostPython* the *PyCaffe*, which comes from the exports of *BoostPython*
...@@ -102,7 +102,7 @@ PyTorch ...@@ -102,7 +102,7 @@ PyTorch
*PyTorch* provides straight-forward operations on research prototyping. *PyTorch* provides straight-forward operations on research prototyping.
To bridge it, our *JIT* traces and dispatches the expressions, To bridge it, our *JIT* traces and dispatches the operations,
as well as the rewriting of *GC* (Garbage Collection) to reuse as well as the rewriting of *GC* (Garbage Collection) to reuse
the memories and operators by turns. the memories and operators by turns.
...@@ -168,52 +168,52 @@ Modules ...@@ -168,52 +168,52 @@ Modules
.. only:: html .. only:: html
`Module autograph <dragon/autograph.html>`_ `Module autograph <dragon/autograph.html>`_
: Public API for ``dragon.autograph`` namespace. : Native API for ``dragon.autograph`` namespace.
`Module bitwise <dragon/bitwise.html>`_ `Module bitwise <dragon/bitwise.html>`_
: Public API for ``dragon.bitwise`` namespace. : Native API for ``dragon.bitwise`` namespace.
`Module cuda <dragon/cuda.html>`_ `Module cuda <dragon/cuda.html>`_
: Public API for ``dragon.cuda`` namespace. : Native API for ``dragon.cuda`` namespace.
`Module distributed <dragon/distributed.html>`_ `Module distributed <dragon/distributed.html>`_
: Public API for ``dragon.distributed`` namespace. : Native API for ``dragon.distributed`` namespace.
`Module dlpack <dragon/dlpack.html>`_ `Module dlpack <dragon/dlpack.html>`_
: Public API for ``dragon.dlpack`` namespace. : Native API for ``dragon.dlpack`` namespace.
`Module io <dragon/io.html>`_ `Module io <dragon/io.html>`_
: Public API for ``dragon.io`` namespace. : Native API for ``dragon.io`` namespace.
`Module logging <dragon/logging.html>`_ `Module logging <dragon/logging.html>`_
: Public API for ``dragon.logging`` namespace. : Native API for ``dragon.logging`` namespace.
`Module losses <dragon/losses.html>`_ `Module losses <dragon/losses.html>`_
: Public API for ``dragon.losses`` namespace. : Native API for ``dragon.losses`` namespace.
`Module math <dragon/math.html>`_ `Module math <dragon/math.html>`_
: Public API for ``dragon.math`` namespace. : Native API for ``dragon.math`` namespace.
`Module metrics <dragon/metrics.html>`_ `Module metrics <dragon/metrics.html>`_
: Public API for ``dragon.metrics`` namespace. : Native API for ``dragon.metrics`` namespace.
`Module nn <dragon/nn.html>`_ `Module nn <dragon/nn.html>`_
: Public API for ``dragon.nn`` namespace. : Native API for ``dragon.nn`` namespace.
`Module optimizers <dragon/optimizers.html>`_
: Native API for ``dragon.optimizers`` namespace.
`Module random <dragon/random.html>`_ `Module random <dragon/random.html>`_
: Public API for ``dragon.random`` namespace. : Native API for ``dragon.random`` namespace.
`Module updaters <dragon/updaters.html>`_ `Module workspace <dragon/workspace.html>`_
: Public API for ``dragon.updaters`` namespace. : Native API for ``dragon.workspace`` namespace.
`Module vision <dragon/vision.html>`_ `Module vision <dragon/vision.html>`_
: Public API for ``dragon.vision`` namespace. : Native API for ``dragon.vision`` namespace.
`Module workspace <dragon/workspace.html>`_ `Module workspace <dragon/workspace.html>`_
: Public API for ``dragon.workspace`` namespace. : Native API for ``dragon.workspace`` namespace.
`Module vm.caffe <caffe.html>`_ `Module vm.caffe <caffe.html>`_
: Virtual API for ``caffe`` namespace. : Virtual API for ``caffe`` namespace.
...@@ -317,10 +317,10 @@ Modules ...@@ -317,10 +317,10 @@ Modules
dragon/math dragon/math
dragon/metrics dragon/metrics
dragon/nn dragon/nn
dragon/optimizers
dragon/random dragon/random
dragon/updaters
dragon/vision
dragon/workspace dragon/workspace
dragon/vision
caffe caffe
caffe/layers caffe/layers
dali dali
......
...@@ -30,9 +30,6 @@ vm.torch ...@@ -30,9 +30,6 @@ vm.torch
`abs(...) <torch/abs.html>`_ `abs(...) <torch/abs.html>`_
: Compute the absolute value of input. : Compute the absolute value of input.
`accumulate(...) <torch/accumulate.html>`_
: Compute the element-wise accumulation from input to output.
`add(...) <torch/add.html>`_ `add(...) <torch/add.html>`_
: Compute the element-wise addition. : Compute the element-wise addition.
...@@ -45,6 +42,9 @@ vm.torch ...@@ -45,6 +42,9 @@ vm.torch
`argmin(...) <torch/argmin.html>`_ `argmin(...) <torch/argmin.html>`_
: Return the indices of minimum elements along the given axis. : Return the indices of minimum elements along the given axis.
`axpby(...) <torch/axpby.html>`_
: Compute the element-wise addition from input to output.
`bitwise_not(...) <torch/bitwise_not.html>`_ `bitwise_not(...) <torch/bitwise_not.html>`_
: Compute the element-wise NOT bitwise operation. : Compute the element-wise NOT bitwise operation.
...@@ -254,11 +254,11 @@ vm.torch ...@@ -254,11 +254,11 @@ vm.torch
:hidden: :hidden:
torch/abs torch/abs
torch/accumulate
torch/add torch/add
torch/arange torch/arange
torch/argmax torch/argmax
torch/argmin torch/argmin
torch/axpby
torch/bitwise_not torch/bitwise_not
torch/bitwise_xor torch/bitwise_xor
torch/cat torch/cat
......
accumulate axpby
========== =====
.. autofunction:: dragon.vm.torch.accumulate .. autofunction:: dragon.vm.torch.axpby
.. raw:: html .. raw:: html
......
...@@ -50,18 +50,18 @@ class CUDAObject { ...@@ -50,18 +50,18 @@ class CUDAObject {
*/ */
if (stream) cudaStreamDestroy(stream); if (stream) cudaStreamDestroy(stream);
} }
for (auto& e : cublas_handles_[i]) for (auto& handle : cublas_handles_[i])
if (e) { if (handle) {
CUBLAS_CHECK(cublasDestroy_v2(e)); CUBLAS_CHECK(cublasDestroy(handle));
} }
#ifdef USE_CUDNN #ifdef USE_CUDNN
for (auto& e : cudnn_handles_[i]) for (auto& handle : cudnn_handles_[i])
if (e) { if (handle) {
CUDNN_CHECK(cudnnDestroy(e)); CUDNN_CHECK(cudnnDestroy(handle));
} }
#endif #endif
#ifdef USE_NCCL #ifdef USE_NCCL
for (auto& e : nccl_comms_[i]) { for (auto& comm : nccl_comms_[i]) {
/*! /*!
* Temporarily disable the comm destroying, * Temporarily disable the comm destroying,
* to avoid an unhandled error. * to avoid an unhandled error.
...@@ -74,17 +74,18 @@ class CUDAObject { ...@@ -74,17 +74,18 @@ class CUDAObject {
/*! \brief Return the specified cublas handle */ /*! \brief Return the specified cublas handle */
cublasHandle_t cublas_handle(int device_id, int stream_id) { cublasHandle_t cublas_handle(int device_id, int stream_id) {
auto& handles = cublas_handles_[device_id]; auto& handles = cublas_handles_[device_id];
if (handles.size() <= (unsigned)stream_id) if (handles.size() <= (unsigned)stream_id) {
handles.resize(stream_id + 1, nullptr); handles.resize(stream_id + 1, nullptr);
}
if (!handles[stream_id]) { if (!handles[stream_id]) {
CUDADeviceGuard guard(device_id); CUDADeviceGuard guard(device_id);
CUBLAS_CHECK(cublasCreate_v2(&handles[stream_id])); CUBLAS_CHECK(cublasCreate(&handles[stream_id]));
CUBLAS_CHECK( auto& handle = handles[stream_id];
cublasSetStream_v2(handles[stream_id], stream(device_id, stream_id))); CUBLAS_CHECK(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasSetStream(handle, stream(device_id, stream_id)));
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) { if (TENSOR_CORE_AVAILABLE()) {
CUBLAS_CHECK( CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
cublasSetMathMode(handles[stream_id], CUBLAS_TENSOR_OP_MATH));
} }
#endif #endif
} }
...@@ -95,13 +96,14 @@ class CUDAObject { ...@@ -95,13 +96,14 @@ class CUDAObject {
#ifdef USE_CUDNN #ifdef USE_CUDNN
cudnnHandle_t cudnn_handle(int device_id, int stream_id) { cudnnHandle_t cudnn_handle(int device_id, int stream_id) {
auto& handles = cudnn_handles_[device_id]; auto& handles = cudnn_handles_[device_id];
if (handles.size() <= (unsigned)stream_id) if (handles.size() <= (unsigned)stream_id) {
handles.resize(stream_id + 1, nullptr); handles.resize(stream_id + 1, nullptr);
}
if (!handles[stream_id]) { if (!handles[stream_id]) {
CUDADeviceGuard guard(device_id); CUDADeviceGuard guard(device_id);
CUDNN_CHECK(cudnnCreate(&handles[stream_id])); CUDNN_CHECK(cudnnCreate(&handles[stream_id]));
CUDNN_CHECK( auto& handle = handles[stream_id];
cudnnSetStream(handles[stream_id], stream(device_id, stream_id))); CUDNN_CHECK(cudnnSetStream(handle, stream(device_id, stream_id)));
} }
return handles[stream_id]; return handles[stream_id];
} }
...@@ -144,7 +146,7 @@ class CUDAObject { ...@@ -144,7 +146,7 @@ class CUDAObject {
if (!streams[stream_id]) { if (!streams[stream_id]) {
CUDADeviceGuard guard(device_id); CUDADeviceGuard guard(device_id);
unsigned int flags = unsigned int flags =
!stream_id ? cudaStreamDefault : cudaStreamNonBlocking; stream_id == 0 ? cudaStreamDefault : cudaStreamNonBlocking;
CUDA_CHECK(cudaStreamCreateWithFlags(&streams[stream_id], flags)); CUDA_CHECK(cudaStreamCreateWithFlags(&streams[stream_id], flags));
} }
return streams[stream_id]; return streams[stream_id];
......
...@@ -80,7 +80,7 @@ Tensor* OperatorBase::Output(int i, const vec32_t& inputs) { ...@@ -80,7 +80,7 @@ Tensor* OperatorBase::Output(int i, const vec32_t& inputs) {
} }
Tensor* OperatorBase::Buffer(const string& name) { Tensor* OperatorBase::Buffer(const string& name) {
return ws()->CreateTensor(unique_name(name)); return ws()->CreateTensor("/share/buffer/" + handle_ + "/" + name);
} }
string OperatorBase::TypeString(const Tensor& tensor, const Set<string>& types) string OperatorBase::TypeString(const Tensor& tensor, const Set<string>& types)
......
...@@ -133,11 +133,6 @@ class DRAGON_API OperatorBase { ...@@ -133,11 +133,6 @@ class DRAGON_API OperatorBase {
return handle_; return handle_;
} }
/*! \brief Return the unique name in this operator */
const string unique_name(const string& name) const {
return "/mnt/" + handle_ + "/" + name;
}
/*! \brief Return the stored def */ /*! \brief Return the stored def */
const OperatorDef& def() const { const OperatorDef& def() const {
return def_; return def_;
...@@ -268,7 +263,6 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*); ...@@ -268,7 +263,6 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
using OperatorBase::dtype; \ using OperatorBase::dtype; \
using OperatorBase::data_format; \ using OperatorBase::data_format; \
using OperatorBase::handle; \ using OperatorBase::handle; \
using OperatorBase::unique_name; \
using OperatorBase::def; \ using OperatorBase::def; \
using OperatorBase::ws using OperatorBase::ws
...@@ -278,16 +272,17 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*); ...@@ -278,16 +272,17 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
using Operator<Context>::ctx using Operator<Context>::ctx
#define STORE_INPUT_SPEC(i) \ #define STORE_INPUT_SPEC(i) \
*(ws()->CreateTensor(unique_name("Input[" + std::to_string(i) + "]")) \ *(Buffer("X_spec:" + std::to_string(i)) \
->ReshapeLike(Input(i)) \ ->ReshapeLike(Input(i)) \
->set_meta(Input(i).meta())) ->set_meta(Input(i).meta()))
#define RESTORE_INPUT_SPEC(i) \ #define RESTORE_INPUT_SPEC(i) \
*(ws()->GetTensor(unique_name("Input[" + std::to_string(i) + "]"))) *(ws()->GetTensor( \
"/share/buffer/" + handle() + "/X_spec:" + std::to_string(i)))
/* Dispatchers */ /* Dispatchers */
#define XIsType(x, type) x.template IsType<type>() #define XIsType(X, type) X.template IsType<type>()
template <typename... Types> template <typename... Types>
struct TensorTypes {}; struct TensorTypes {};
......
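The ad-hoc ``unique_name`` ('/mnt/<handle>/<name>') is gone; everything now goes through ``Buffer``, which keys tensors under a single shared prefix. The naming scheme, sketched in Python for clarity:

```python
def buffer_name(handle, name):
    # Naming used by OperatorBase::Buffer and RESTORE_INPUT_SPEC above.
    return '/share/buffer/' + handle + '/' + name

# 'conv1' is a hypothetical operator handle.
print(buffer_name('conv1', 'X_spec:0'))  # /share/buffer/conv1/X_spec:0
```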
...@@ -53,14 +53,11 @@ __global__ void _EluGrad( ...@@ -53,14 +53,11 @@ __global__ void _EluGrad(
const T* y, const T* y,
T* dx) { T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = dy[i] *
(
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
__ldg(y + i) > T(0) ? T(1) : alpha + __ldg(y + i) dx[i] = dy[i] * (__ldg(y + i) > T(0) ? T(1) : alpha + __ldg(y + i));
#else #else
y[i] > T(0) ? T(1) : (alpha + y[i]) dx[i] = dy[i] * (y[i] > T(0) ? T(1) : (alpha + y[i]));
#endif #endif
);
} }
} }
......
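The restructuring moves the arch branch around the whole statement; the math is unchanged. A NumPy reference of the ELU gradient computed from the forward output y:

```python
import numpy as np

def elu_grad(dy, y, alpha=1.0):
    # dELU/dx = 1 where y > 0, else alpha + y (since y = alpha*(e^x - 1)).
    return dy * np.where(y > 0, 1.0, alpha + y)

y = np.array([-0.5, 0.0, 2.0])
print(elu_grad(np.ones(3), y))  # [0.5 1.  1. ]
```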
...@@ -14,28 +14,28 @@ void _Softmax( ...@@ -14,28 +14,28 @@ void _Softmax(
const int inner_dim, const int inner_dim,
const T* x, const T* x,
T* y) { T* y) {
int row_ofs, col_ofs, yi; int row_offset, col_offset, yi;
auto x_stride = axis_dim * inner_dim; auto x_stride = axis_dim * inner_dim;
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
row_ofs = i * axis_dim * inner_dim; row_offset = i * axis_dim * inner_dim;
for (int j = 0; j < inner_dim; ++j) { for (int j = 0; j < inner_dim; ++j) {
col_ofs = row_ofs + j; col_offset = row_offset + j;
T val = x[col_ofs]; T val = x[col_offset];
for (int k = 1; k < axis_dim; ++k) { for (int k = 1; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim; yi = col_offset + k * inner_dim;
val = std::max(val, x[yi]); val = std::max(val, x[yi]);
} }
for (int k = 0; k < axis_dim; ++k) { for (int k = 0; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim; yi = col_offset + k * inner_dim;
y[yi] = std::exp(x[yi] - val); y[yi] = std::exp(x[yi] - val);
} }
val = y[col_ofs]; val = y[col_offset];
for (int k = 1; k < axis_dim; ++k) { for (int k = 1; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim; yi = col_offset + k * inner_dim;
val += y[yi]; val += y[yi];
} }
for (int k = 0; k < axis_dim; ++k) { for (int k = 0; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim; yi = col_offset + k * inner_dim;
y[yi] /= val; y[yi] /= val;
} }
} }
...@@ -60,19 +60,19 @@ void _SoftmaxGrad( ...@@ -60,19 +60,19 @@ void _SoftmaxGrad(
const T* dy, const T* dy,
const T* y, const T* y,
T* dx) { T* dx) {
int row_ofs, col_ofs, yi; int row_offset, col_offset, yi;
auto x_stride = axis_dim * inner_dim; auto x_stride = axis_dim * inner_dim;
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
row_ofs = i * axis_dim * inner_dim; row_offset = i * axis_dim * inner_dim;
for (int j = 0; j < inner_dim; ++j) { for (int j = 0; j < inner_dim; ++j) {
col_ofs = row_ofs + j; col_offset = row_offset + j;
T val = dy[col_ofs] * y[col_ofs]; T val = dy[col_offset] * y[col_offset];
for (int k = 1; k < axis_dim; ++k) { for (int k = 1; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim; yi = col_offset + k * inner_dim;
val += dy[yi] * y[yi]; val += dy[yi] * y[yi];
} }
for (int k = 0; k < axis_dim; ++k) { for (int k = 0; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim; yi = col_offset + k * inner_dim;
dx[yi] = (dy[yi] - val) * y[yi]; dx[yi] = (dy[yi] - val) * y[yi];
} }
} }
......
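A pure rename (``ofs`` → ``offset``); the loops implement max-shifted softmax over the axis dimension of an (outer, axis, inner) layout. A NumPy reference:

```python
import numpy as np

def softmax(x, outer_dim, axis_dim, inner_dim):
    x = x.reshape(outer_dim, axis_dim, inner_dim)
    e = np.exp(x - x.max(axis=1, keepdims=True))  # subtract max for stability
    return (e / e.sum(axis=1, keepdims=True)).ravel()

x = np.random.randn(2 * 3 * 4).astype('float32')
y = softmax(x, 2, 3, 4)  # each (i, :, j) fiber sums to 1
```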
...@@ -53,11 +53,11 @@ void _CumSumReverse( ...@@ -53,11 +53,11 @@ void _CumSumReverse(
CPUContext* ctx) { CPUContext* ctx) {
const int kStart = axis_dim - 1; const int kStart = axis_dim - 1;
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
const int n_ofs = n * axis_dim; const int n_offset = n * axis_dim;
for (int m = kStart; m >= 0; --m) { for (int m = kStart; m >= 0; --m) {
const int nm_ofs = (n_ofs + m) * inner_dim; const int nm_offset = (n_offset + m) * inner_dim;
for (int k = 0; k < inner_dim; ++k) { for (int k = 0; k < inner_dim; ++k) {
const int i = nm_ofs + k; const int i = nm_offset + k;
if (m < kStart) { if (m < kStart) {
const int j = i + inner_dim; const int j = i + inner_dim;
y[i] = y[j] + x[exclusive ? j : i]; y[i] = y[j] + x[exclusive ? j : i];
......
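Also a rename-only hunk; the loop is a reversed, optionally exclusive cumulative sum along the axis dimension. A NumPy reference for one axis fiber:

```python
import numpy as np

def cumsum_reverse(x, exclusive=False):
    y = np.cumsum(x[::-1])[::-1]          # inclusive reverse cumsum
    if exclusive:
        y = np.concatenate([y[1:], [0]])  # drop each element's own term
    return y

print(cumsum_reverse(np.array([1, 2, 3])))                  # [6 5 3]
print(cumsum_reverse(np.array([1, 2, 3]), exclusive=True))  # [5 3 0]
```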
...@@ -25,9 +25,9 @@ void _SetEye(const int n, const int m, const int k, T* y) { ...@@ -25,9 +25,9 @@ void _SetEye(const int n, const int m, const int k, T* y) {
const int n, const int m, const int k, T* y, CPUContext* ctx) { \ const int n, const int m, const int k, T* y, CPUContext* ctx) { \
math::Set(n* m, cast::to<T>(0.f), y, ctx); \ math::Set(n* m, cast::to<T>(0.f), y, ctx); \
if (k > 0) { \ if (k > 0) { \
_SetEye(n - k, m, k, y); \ if (m - k > 0) _SetEye(m - k, m, k, y); \
} else { \ } else { \
_SetEye(n + k, m, 0, y - k * m); \ if (n + k > 0) _SetEye(n + k, m, 0, y - k * m); \
} \ } \
} }
......
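This is one of the actual bug fixes: for a positive offset k, the diagonal of an n x m matrix is limited by the column count (m - k), which the old bound n - k ignored; the new code also skips the call entirely when the bound is non-positive (the CUDA launchers below get the same guard). A NumPy check:

```python
import numpy as np

# np.eye(n, m, k) sets ones at (i, i + k); for k > 0 the count is limited
# by the columns, i.e. min(n, m - k), not n - k.
n, m, k = 3, 5, 2
print(int(np.eye(n, m, k).sum()))  # 3 == min(n, m - k)
print(n - k)                       # 1, the old (incorrect) loop bound
```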
...@@ -20,9 +20,9 @@ __global__ void _SetEye(const int n, const int m, const int k, T* y) { ...@@ -20,9 +20,9 @@ __global__ void _SetEye(const int n, const int m, const int k, T* y) {
template <> template <>
__global__ void _SetEye<half>(const int n, const int m, const int k, half* y) { __global__ void _SetEye<half>(const int n, const int m, const int k, half* y) {
const half kZero = __float2half(1.f); const half kOne = __float2half(1.f);
CUDA_1D_KERNEL_LOOP(i, n) { CUDA_1D_KERNEL_LOOP(i, n) {
y[i * m + k + i] = kZero; y[i * m + k + i] = kOne;
} }
} }
...@@ -39,12 +39,16 @@ void Eye<float16, CUDAContext>( ...@@ -39,12 +39,16 @@ void Eye<float16, CUDAContext>(
CUDAContext* ctx) { CUDAContext* ctx) {
math::Set(n * m, cast::to<float16>(0.f), y, ctx); math::Set(n * m, cast::to<float16>(0.f), y, ctx);
if (k > 0) { if (k > 0) {
_SetEye<<<CUDA_BLOCKS(n - k), CUDA_THREADS, 0, ctx->cuda_stream()>>>( if (m - k > 0) {
n - k, m, k, reinterpret_cast<half*>(y)); _SetEye<<<CUDA_BLOCKS(m - k), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
m - k, m, k, reinterpret_cast<half*>(y));
}
} else { } else {
if (n + k > 0) {
_SetEye<<<CUDA_BLOCKS(n + k), CUDA_THREADS, 0, ctx->cuda_stream()>>>( _SetEye<<<CUDA_BLOCKS(n + k), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n + k, m, 0, reinterpret_cast<half*>(y - k * m)); n + k, m, 0, reinterpret_cast<half*>(y - k * m));
} }
}
} }
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
...@@ -53,12 +57,16 @@ void Eye<float16, CUDAContext>( ...@@ -53,12 +57,16 @@ void Eye<float16, CUDAContext>(
const int n, const int m, const int k, T* y, CUDAContext* ctx) { \ const int n, const int m, const int k, T* y, CUDAContext* ctx) { \
math::Set(n* m, T(0), y, ctx); \ math::Set(n* m, T(0), y, ctx); \
if (k > 0) { \ if (k > 0) { \
_SetEye<<<CUDA_BLOCKS(n - k), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ if (m - k > 0) { \
n - k, m, k, y); \ _SetEye<<<CUDA_BLOCKS(m - k), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
m - k, m, k, y); \
} \
} else { \ } else { \
if (n + k > 0) { \
_SetEye<<<CUDA_BLOCKS(n + k), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _SetEye<<<CUDA_BLOCKS(n + k), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n + k, m, 0, y - k * m); \ n + k, m, 0, y - k * m); \
} \ } \
} \
} }
DEFINE_KERNEL_LAUNCHER(bool); DEFINE_KERNEL_LAUNCHER(bool);
......
...@@ -47,7 +47,8 @@ void _BroadcastLossGrad<float16>( ...@@ -47,7 +47,8 @@ void _BroadcastLossGrad<float16>(
CPUContext* ctx) { \ CPUContext* ctx) { \
float inv_scale = std::max( \ float inv_scale = std::max( \
1e-5F, \ 1e-5F, \
num_masks > 0 ? (float)math::Sum(num_masks, 1.f, mask, ctx) \ num_masks > 0 && normalizer < 0.f \
? (float)math::Sum(num_masks, 1.f, mask, ctx) \
: normalizer); \ : normalizer); \
y[0] = math::Sum(count, 1.f / inv_scale, x, ctx); \ y[0] = math::Sum(count, 1.f / inv_scale, x, ctx); \
} }
...@@ -64,7 +65,8 @@ void _BroadcastLossGrad<float16>( ...@@ -64,7 +65,8 @@ void _BroadcastLossGrad<float16>(
CPUContext* ctx) { \ CPUContext* ctx) { \
float inv_scale = std::max( \ float inv_scale = std::max( \
1e-5F, \ 1e-5F, \
num_masks > 0 ? (float)math::Sum(num_masks, 1.f, mask, ctx) \ num_masks > 0 && normalizer < 0.f \
? (float)math::Sum(num_masks, 1.f, mask, ctx) \
: normalizer); \ : normalizer); \
math::Scale(count, cast::to<float>(dy[0]) / inv_scale, dx, dx, ctx); \ math::Scale(count, cast::to<float>(dy[0]) / inv_scale, dx, dx, ctx); \
} \ } \
......
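A semantics fix: the mask-sum path is now taken only when masks exist and ``normalizer`` is negative (the "divide by valid count" convention); a non-negative ``normalizer`` always wins. A NumPy sketch of ``ReduceLoss``:

```python
import numpy as np

def reduce_loss(x, mask=None, normalizer=-1.0):
    if mask is not None and normalizer < 0:
        scale = float(mask.sum())  # normalize by the valid-element count
    else:
        scale = normalizer
    return x.sum() / max(1e-5, scale)

x = np.array([1.0, 2.0, 3.0])
print(reduce_loss(x, mask=np.array([1, 0, 1])))                  # 3.0
print(reduce_loss(x, mask=np.array([1, 0, 1]), normalizer=3.0))  # 2.0
```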
...@@ -152,15 +152,15 @@ __global__ void _ReduceLossGradWithMask<half>( ...@@ -152,15 +152,15 @@ __global__ void _ReduceLossGradWithMask<half>(
template <typename T> template <typename T>
__global__ void _BroadcastLossGrad( __global__ void _BroadcastLossGrad(
const int nthreads, const int nthreads,
const int rows, const int dim1,
const int cols, const int dim2,
const T* dy, const T* dy,
T* dx) { T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
dx[i] *= __ldg(dy + (i / rows) * cols + (i % cols)); dx[i] *= __ldg(dy + (i / dim1) * dim2 + (i % dim2));
#else #else
dx[i] *= dy[(i / rows) * cols + (i % cols)]; dx[i] *= dy[(i / dim1) * dim2 + (i % dim2)];
#endif #endif
} }
} }
...@@ -168,18 +168,18 @@ __global__ void _BroadcastLossGrad( ...@@ -168,18 +168,18 @@ __global__ void _BroadcastLossGrad(
template <> template <>
__global__ void _BroadcastLossGrad<half>( __global__ void _BroadcastLossGrad<half>(
const int nthreads, const int nthreads,
const int rows, const int dim1,
const int cols, const int dim2,
const half* dy, const half* dy,
half* dx) { half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
dx[i] = __float2half( dx[i] = __float2half(
__half2float(dx[i]) * __half2float(dx[i]) *
__half2float(__ldg(dy + (i / rows) * cols + (i % cols)))); __half2float(__ldg(dy + (i / dim1) * dim2 + (i % dim2))));
#else #else
dx[i] = __float2half( dx[i] = __float2half(
__half2float(dx[i]) * __half2float(dy[(i / rows) * cols + (i % cols)])); __half2float(dx[i]) * __half2float(dy[(i / dim1) * dim2 + (i % dim2)]));
#endif #endif
} }
} }
...@@ -197,7 +197,7 @@ void ReduceLoss<float16, CUDAContext>( ...@@ -197,7 +197,7 @@ void ReduceLoss<float16, CUDAContext>(
const int* mask, const int* mask,
float16* y, float16* y,
CUDAContext* ctx) { CUDAContext* ctx) {
if (num_masks > 0) { if (num_masks > 0 && normalizer < 0.f) {
_ReduceLossWithMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( _ReduceLossWithMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
num_masks, num_masks,
reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(x),
...@@ -221,7 +221,7 @@ void ReduceLossGrad<float16, CUDAContext>( ...@@ -221,7 +221,7 @@ void ReduceLossGrad<float16, CUDAContext>(
const int* mask, const int* mask,
float16* dx, float16* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
if (num_masks > 0) { if (num_masks > 0 && normalizer < 0.f) {
_ReduceMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( _ReduceMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
num_masks, const_cast<int*>(mask)); num_masks, const_cast<int*>(mask));
_ReduceLossGradWithMask<<< _ReduceLossGradWithMask<<<
...@@ -254,16 +254,15 @@ void BroadcastLossGrad<float16, CUDAContext>( ...@@ -254,16 +254,15 @@ void BroadcastLossGrad<float16, CUDAContext>(
const float16* dy, const float16* dy,
float16* dx, float16* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
auto rows = outer_dim * axis_dim, cols = inner_dim; auto nthreads = outer_dim * axis_dim * inner_dim;
auto nthreads = rows * cols;
_BroadcastLossGrad<<< _BroadcastLossGrad<<<
CUDA_BLOCKS(nthreads), CUDA_BLOCKS(nthreads),
CUDA_THREADS, CUDA_THREADS,
0, 0,
ctx->cuda_stream()>>>( ctx->cuda_stream()>>>(
nthreads, nthreads,
rows, axis_dim * inner_dim,
cols, inner_dim,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx)); reinterpret_cast<half*>(dx));
} // BroadcastLossGrad } // BroadcastLossGrad
...@@ -278,7 +277,7 @@ void BroadcastLossGrad<float16, CUDAContext>( ...@@ -278,7 +277,7 @@ void BroadcastLossGrad<float16, CUDAContext>(
const int* mask, \ const int* mask, \
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
if (num_masks > 0) { \ if (num_masks > 0 && normalizer < 0.f) { \
_ReduceLossWithMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _ReduceLossWithMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
num_masks, x, mask, y); \ num_masks, x, mask, y); \
} else { \ } else { \
...@@ -297,7 +296,7 @@ void BroadcastLossGrad<float16, CUDAContext>( ...@@ -297,7 +296,7 @@ void BroadcastLossGrad<float16, CUDAContext>(
const int* mask, \ const int* mask, \
T* dx, \ T* dx, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
if (num_masks > 0) { \ if (num_masks > 0 && normalizer < 0.f) { \
_ReduceMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _ReduceMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
num_masks, const_cast<int*>(mask)); \ num_masks, const_cast<int*>(mask)); \
_ReduceLossGradWithMask<<< \ _ReduceLossGradWithMask<<< \
...@@ -322,13 +321,13 @@ void BroadcastLossGrad<float16, CUDAContext>( ...@@ -322,13 +321,13 @@ void BroadcastLossGrad<float16, CUDAContext>(
const T* dy, \ const T* dy, \
T* dx, \ T* dx, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
auto rows = outer_dim * axis_dim, cols = inner_dim; \ auto nthreads = outer_dim * axis_dim * inner_dim; \
auto nthreads = rows * cols; \
_BroadcastLossGrad<<< \ _BroadcastLossGrad<<< \
CUDA_BLOCKS(nthreads), \ CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>(nthreads, rows, cols, dy, dx); \ ctx->cuda_stream()>>>( \
nthreads, axis_dim * inner_dim, inner_dim, dy, dx); \
} }
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
......
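The dim rename clarifies the broadcast indexing: with ``dim1 = axis_dim * inner_dim`` and ``dim2 = inner_dim``, ``i / dim1`` recovers the outer index, whereas the old divisor ``rows = outer_dim * axis_dim`` did not. A NumPy reference:

```python
import numpy as np

def broadcast_loss_grad(dx, dy, outer_dim, axis_dim, inner_dim):
    dim1, dim2 = axis_dim * inner_dim, inner_dim
    for i in range(outer_dim * dim1):
        dx[i] *= dy[(i // dim1) * dim2 + (i % dim2)]
    return dx

dx = np.ones(2 * 3 * 4)
dy = np.arange(2 * 4, dtype=float)  # one dy per (outer, inner) position
out = broadcast_loss_grad(dx.copy(), dy, 2, 3, 4)
ref = (dx.reshape(2, 3, 4) * dy.reshape(2, 1, 4)).ravel()
assert np.allclose(out, ref)        # dy broadcast over axis_dim
```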
...@@ -7,31 +7,31 @@ namespace dragon { ...@@ -7,31 +7,31 @@ namespace dragon {
namespace kernel { namespace kernel {
template <> template <>
void MixedPrecL2Decay<float16, CPUContext>( void MixedPrecL2Penalty<float16, CPUContext>(
const int count, const int count,
const float alpha, const float alpha,
const float16* w, const float16* x,
float* dx, float* dx,
CPUContext* ctx) { CPUContext* ctx) {
#ifdef USE_OPENMP #ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count)) #pragma omp parallel for num_threads(OMP_THREADS(count))
#endif #endif
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
dx[i] += (cast::to<float>(w[i]) * alpha); dx[i] += (cast::to<float>(x[i]) * alpha);
} }
} }
template <> template <>
void MixedPrecUpdate<float16, CPUContext>( void MixedPrecUpdate<float16, CPUContext>(
const int count, const int count,
const float* updates, const float* dx,
float16* w, float16* x,
CPUContext* ctx) { CPUContext* ctx) {
#ifdef USE_OPENMP #ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count)) #pragma omp parallel for num_threads(OMP_THREADS(count))
#endif #endif
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
w[i] = cast::to<float16>(cast::to<float>(w[i]) - updates[i]); x[i] = cast::to<float16>(cast::to<float>(x[i]) - dx[i]);
} }
} }
......
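Besides the rename (``L2Decay`` → ``L2Penalty``, ``w``/``updates`` → ``x``/``dx``), the pattern is: keep float16 weights, accumulate the penalty into a float32 gradient buffer, apply the update in float32, and cast back. A NumPy sketch:

```python
import numpy as np

def mixed_prec_l2_penalty(x, dx, alpha):
    return dx + x.astype(np.float32) * alpha  # accumulate in fp32

def mixed_prec_update(x, dx):
    return (x.astype(np.float32) - dx).astype(np.float16)

x = np.array([1.0, -2.0], dtype=np.float16)  # fp16 weights
dx = np.zeros(2, dtype=np.float32)           # fp32 gradient buffer
dx = mixed_prec_l2_penalty(x, dx, alpha=0.01)
x = mixed_prec_update(x, dx)
```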
...@@ -9,24 +9,19 @@ namespace kernel { ...@@ -9,24 +9,19 @@ namespace kernel {
namespace { namespace {
__global__ void _MixedPrecL2DecayHalf( __global__ void _MixedPrecL2Penalty(
const int nthreads, const int nthreads,
const float alpha, const float alpha,
const half* w, const half* x,
float* dx) { float* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530 dx[i] += __half2float(x[i]) * alpha;
dx[i] += __half2float(w[i]) * alpha;
#endif
} }
} }
__global__ void __global__ void _MixedPrecUpdate(const int nthreads, const float* dx, half* x) {
_MixedPrecUpdateHalf(const int nthreads, const float* updates, half* w) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530 x[i] = __float2half(__half2float(x[i]) - dx[i]);
w[i] = __float2half(__half2float(w[i]) - updates[i]);
#endif
} }
} }
...@@ -35,30 +30,27 @@ _MixedPrecUpdateHalf(const int nthreads, const float* updates, half* w) { ...@@ -35,30 +30,27 @@ _MixedPrecUpdateHalf(const int nthreads, const float* updates, half* w) {
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> template <>
void MixedPrecL2Decay<float16, CUDAContext>( void MixedPrecL2Penalty<float16, CUDAContext>(
const int count, const int count,
const float alpha, const float alpha,
const float16* w, const float16* x,
float* dx, float* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_MixedPrecL2DecayHalf<<< _MixedPrecL2Penalty<<<
CUDA_BLOCKS(count), CUDA_BLOCKS(count),
CUDA_THREADS, CUDA_THREADS,
0, 0,
ctx->cuda_stream()>>>(count, alpha, reinterpret_cast<const half*>(w), dx); ctx->cuda_stream()>>>(count, alpha, reinterpret_cast<const half*>(x), dx);
} }
template <> template <>
void MixedPrecUpdate<float16, CUDAContext>( void MixedPrecUpdate<float16, CUDAContext>(
const int count, const int count,
const float* updates, const float* dx,
float16* w, float16* x,
CUDAContext* ctx) { CUDAContext* ctx) {
_MixedPrecUpdateHalf<<< _MixedPrecUpdate<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
CUDA_BLOCKS(count), count, dx, reinterpret_cast<half*>(x));
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(count, updates, reinterpret_cast<half*>(w));
} }
} // namespace kernel } // namespace kernel
......
...@@ -116,15 +116,13 @@ __global__ void _AvgPool2dGradNCHW( ...@@ -116,15 +116,13 @@ __global__ void _AvgPool2dGradNCHW(
const T* dy, const T* dy,
T* dx) { T* dx) {
CUDA_1D_KERNEL_LOOP(xi, nthreads) { CUDA_1D_KERNEL_LOOP(xi, nthreads) {
const int w = xi % W; const int w = xi % W + pad_w;
const int h = (xi / W) % H; const int h = (xi / W) % H + pad_h;
const int c = (xi / W / H) % C; const int c = (xi / W / H) % C;
const int n = xi / W / H / C; const int n = xi / W / H / C;
const int phstart = const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
(h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1; const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
const int pwstart =
(w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
const int phend = min(h / stride_h + 1, out_h); const int phend = min(h / stride_h + 1, out_h);
const int pwend = min(w / stride_w + 1, out_w); const int pwend = min(w / stride_w + 1, out_w);
...@@ -164,14 +162,12 @@ __global__ void _AvgPool2dGradNHWC( ...@@ -164,14 +162,12 @@ __global__ void _AvgPool2dGradNHWC(
T* dx) { T* dx) {
CUDA_1D_KERNEL_LOOP(xi, nthreads) { CUDA_1D_KERNEL_LOOP(xi, nthreads) {
const int c = xi % C; const int c = xi % C;
const int w = (xi / C) % W; const int w = (xi / C) % W + pad_w;
const int h = (xi / C / W) % H; const int h = (xi / C / W) % H + pad_h;
const int n = xi / C / W / H; const int n = xi / C / W / H;
const int phstart = const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
(h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1; const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
const int pwstart =
(w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
const int phend = min(h / stride_h + 1, out_h); const int phend = min(h / stride_h + 1, out_h);
const int pwend = min(w / stride_w + 1, out_w); const int pwend = min(w / stride_w + 1, out_w);
......
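The simplification folds the padding into ``h``/``w`` once instead of re-adding it in each bound; the resulting window range is unchanged. A sketch of the output rows whose pooling window covers a given input row:

```python
def pool_range(h_raw, pad_h, kernel_h, stride_h, out_h):
    h = h_raw + pad_h  # fold the padding in, as the fixed kernels do
    phstart = 0 if h < kernel_h else (h - kernel_h) // stride_h + 1
    phend = min(h // stride_h + 1, out_h)
    return phstart, phend  # half-open range of covering output rows

print(pool_range(h_raw=3, pad_h=1, kernel_h=3, stride_h=2, out_h=3))  # (1, 3)
```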
...@@ -30,8 +30,8 @@ void _Im2Col2dNCHW( ...@@ -30,8 +30,8 @@ void _Im2Col2dNCHW(
const T* im, const T* im,
T* col) { T* col) {
int ih, iw; int ih, iw;
const int im_ofs = H * W; const int im_offset = H * W;
for (int c = 0; c < C; ++c, im += im_ofs) { for (int c = 0; c < C; ++c, im += im_offset) {
for (int kh = 0; kh < kernel_h; ++kh) { for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) { for (int kw = 0; kw < kernel_w; ++kw) {
ih = -pad_h + kh * dilation_h; ih = -pad_h + kh * dilation_h;
...@@ -117,8 +117,8 @@ void _Col2Im2dNCHW( ...@@ -117,8 +117,8 @@ void _Col2Im2dNCHW(
const T* col, const T* col,
T* im) { T* im) {
int ih, iw; int ih, iw;
const int im_ofs = H * W; const int im_offset = H * W;
for (int c = 0; c < C; ++c, im += im_ofs) { for (int c = 0; c < C; ++c, im += im_offset) {
for (int kh = 0; kh < kernel_h; ++kh) { for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) { for (int kw = 0; kw < kernel_w; ++kw) {
ih = -pad_h + kh * dilation_h; ih = -pad_h + kh * dilation_h;
......
...@@ -27,13 +27,13 @@ void _DepthwiseConv2dNCHW( ...@@ -27,13 +27,13 @@ void _DepthwiseConv2dNCHW(
T* y) { T* y) {
T sum_val; T sum_val;
int ih, iw, xi, wi; int ih, iw, xi, wi;
int yc_ofs, xc_start, yc_start; int yc_offset, xc_start, yc_start;
int ih_start, yh_start, iw_start; int ih_start, yh_start, iw_start;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) { for (int c = 0; c < C; ++c) {
yc_ofs = n * C + c; yc_offset = n * C + c;
xc_start = yc_ofs * H * W; xc_start = yc_offset * H * W;
yc_start = yc_ofs * out_h; yc_start = yc_offset * out_h;
for (int oh = 0; oh < out_h; ++oh) { for (int oh = 0; oh < out_h; ++oh) {
ih_start = oh * stride_h - pad_h; ih_start = oh * stride_h - pad_h;
yh_start = (yc_start + oh) * out_w; yh_start = (yc_start + oh) * out_w;
......
...@@ -46,7 +46,7 @@ void _ResizeLinearNCHW( ...@@ -46,7 +46,7 @@ void _ResizeLinearNCHW(
std::array<int, 4> idx = {0, 0, 0, 0}; std::array<int, 4> idx = {0, 0, 0, 0};
std::array<int, 4> dims = {N, C, out_h, out_w}; std::array<int, 4> dims = {N, C, out_h, out_w};
float h_in, w_in, u, v, t, b, tl, tr, bl, br; float h_in, w_in, u, v, t, b, tl, tr, bl, br;
int ti, bi, li, ri, ofs, h_max = H - 1, w_max = W - 1; int ti, bi, li, ri, offset, h_max = H - 1, w_max = W - 1;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
h_in = TransformCoordinate(idx[2], scale_h, align_corners); h_in = TransformCoordinate(idx[2], scale_h, align_corners);
w_in = TransformCoordinate(idx[3], scale_w, align_corners); w_in = TransformCoordinate(idx[3], scale_w, align_corners);
...@@ -54,11 +54,11 @@ void _ResizeLinearNCHW( ...@@ -54,11 +54,11 @@ void _ResizeLinearNCHW(
bi = (h_in < h_max) ? std::ceil(h_in) : h_max; bi = (h_in < h_max) ? std::ceil(h_in) : h_max;
ri = (w_in < w_max) ? std::ceil(w_in) : w_max; ri = (w_in < w_max) ? std::ceil(w_in) : w_max;
v = h_in - ti, u = w_in - li; v = h_in - ti, u = w_in - li;
ofs = (idx[0] * C + idx[1]) * H; offset = (idx[0] * C + idx[1]) * H;
tl = (float)x[(ofs + ti) * W + li]; tl = (float)x[(offset + ti) * W + li];
tr = (float)x[(ofs + ti) * W + ri]; tr = (float)x[(offset + ti) * W + ri];
bl = (float)x[(ofs + bi) * W + li]; bl = (float)x[(offset + bi) * W + li];
br = (float)x[(ofs + bi) * W + ri]; br = (float)x[(offset + bi) * W + ri];
t = tl + (tr - tl) * u; t = tl + (tr - tl) * u;
b = bl + (br - bl) * u; b = bl + (br - bl) * u;
y[i] = static_cast<T>(t + (b - t) * v); y[i] = static_cast<T>(t + (b - t) * v);
...@@ -83,7 +83,7 @@ void _ResizeLinearNHWC( ...@@ -83,7 +83,7 @@ void _ResizeLinearNHWC(
std::array<int, 4> idx = {0, 0, 0, 0}; std::array<int, 4> idx = {0, 0, 0, 0};
std::array<int, 4> dims = {N, out_h, out_w, C}; std::array<int, 4> dims = {N, out_h, out_w, C};
float h_in, w_in, u, v, t, b, tl, tr, bl, br; float h_in, w_in, u, v, t, b, tl, tr, bl, br;
int ti, bi, li, ri, ofs, h_max = H - 1, w_max = W - 1; int ti, bi, li, ri, offset, h_max = H - 1, w_max = W - 1;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
h_in = TransformCoordinate(idx[1], scale_h, align_corners); h_in = TransformCoordinate(idx[1], scale_h, align_corners);
w_in = TransformCoordinate(idx[2], scale_w, align_corners); w_in = TransformCoordinate(idx[2], scale_w, align_corners);
...@@ -91,11 +91,11 @@ void _ResizeLinearNHWC( ...@@ -91,11 +91,11 @@ void _ResizeLinearNHWC(
bi = (h_in < h_max) ? std::ceil(h_in) : h_max; bi = (h_in < h_max) ? std::ceil(h_in) : h_max;
ri = (w_in < w_max) ? std::ceil(w_in) : w_max; ri = (w_in < w_max) ? std::ceil(w_in) : w_max;
v = h_in - ti, u = w_in - li; v = h_in - ti, u = w_in - li;
ofs = idx[0] * H; offset = idx[0] * H;
tl = (float)x[((ofs + ti) * W + li) * C + idx[3]]; tl = (float)x[((offset + ti) * W + li) * C + idx[3]];
tr = (float)x[((ofs + ti) * W + ri) * C + idx[3]]; tr = (float)x[((offset + ti) * W + ri) * C + idx[3]];
bl = (float)x[((ofs + bi) * W + li) * C + idx[3]]; bl = (float)x[((offset + bi) * W + li) * C + idx[3]];
br = (float)x[((ofs + bi) * W + ri) * C + idx[3]]; br = (float)x[((offset + bi) * W + ri) * C + idx[3]];
t = tl + (tr - tl) * u; t = tl + (tr - tl) * u;
b = bl + (br - bl) * u; b = bl + (br - bl) * u;
y[i] = static_cast<T>(t + (b - t) * v); y[i] = static_cast<T>(t + (b - t) * v);
...@@ -120,7 +120,7 @@ void _ResizeLinearGradNCHW( ...@@ -120,7 +120,7 @@ void _ResizeLinearGradNCHW(
std::array<int, 4> idx = {0, 0, 0, 0}; std::array<int, 4> idx = {0, 0, 0, 0};
std::array<int, 4> dims = {N, C, out_h, out_w}; std::array<int, 4> dims = {N, C, out_h, out_w};
float h_in, w_in, u, v, dt, db, tl, tr, bl, br; float h_in, w_in, u, v, dt, db, tl, tr, bl, br;
int ti, bi, li, ri, ofs, h_max = H - 1, w_max = W - 1; int ti, bi, li, ri, offset, h_max = H - 1, w_max = W - 1;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
h_in = TransformCoordinate(idx[2], scale_h, align_corners); h_in = TransformCoordinate(idx[2], scale_h, align_corners);
w_in = TransformCoordinate(idx[3], scale_w, align_corners); w_in = TransformCoordinate(idx[3], scale_w, align_corners);
...@@ -128,13 +128,13 @@ void _ResizeLinearGradNCHW( ...@@ -128,13 +128,13 @@ void _ResizeLinearGradNCHW(
bi = (h_in < h_max) ? std::ceil(h_in) : h_max; bi = (h_in < h_max) ? std::ceil(h_in) : h_max;
ri = (w_in < w_max) ? std::ceil(w_in) : w_max; ri = (w_in < w_max) ? std::ceil(w_in) : w_max;
v = h_in - ti, u = w_in - li; v = h_in - ti, u = w_in - li;
ofs = (idx[0] * C + idx[1]) * H; offset = (idx[0] * C + idx[1]) * H;
dt = (1.f - v) * static_cast<float>(dy[i]); dt = (1.f - v) * static_cast<float>(dy[i]);
db = v * static_cast<float>(dy[i]); db = v * static_cast<float>(dy[i]);
dx[(ofs + ti) * W + li] += (1.f - u) * dt; // tl dx[(offset + ti) * W + li] += (1.f - u) * dt; // tl
dx[(ofs + ti) * W + ri] += u * dt; // tr dx[(offset + ti) * W + ri] += u * dt; // tr
dx[(ofs + bi) * W + li] += (1.f - u) * db; // bl dx[(offset + bi) * W + li] += (1.f - u) * db; // bl
dx[(ofs + bi) * W + ri] += u * db; // br dx[(offset + bi) * W + ri] += u * db; // br
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data()); utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
} }
} }
...@@ -156,7 +156,7 @@ void _ResizeLinearGradNHWC( ...@@ -156,7 +156,7 @@ void _ResizeLinearGradNHWC(
std::array<int, 4> idx = {0, 0, 0, 0}; std::array<int, 4> idx = {0, 0, 0, 0};
std::array<int, 4> dims = {N, out_h, out_w, C}; std::array<int, 4> dims = {N, out_h, out_w, C};
float h_in, w_in, u, v, dt, db, tl, tr, bl, br; float h_in, w_in, u, v, dt, db, tl, tr, bl, br;
int ti, bi, li, ri, ofs, h_max = H - 1, w_max = W - 1; int ti, bi, li, ri, offset, h_max = H - 1, w_max = W - 1;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
h_in = TransformCoordinate(idx[1], scale_h, align_corners); h_in = TransformCoordinate(idx[1], scale_h, align_corners);
w_in = TransformCoordinate(idx[2], scale_w, align_corners); w_in = TransformCoordinate(idx[2], scale_w, align_corners);
...@@ -164,13 +164,13 @@ void _ResizeLinearGradNHWC( ...@@ -164,13 +164,13 @@ void _ResizeLinearGradNHWC(
bi = (h_in < h_max) ? std::ceil(h_in) : h_max; bi = (h_in < h_max) ? std::ceil(h_in) : h_max;
ri = (w_in < w_max) ? std::ceil(w_in) : w_max; ri = (w_in < w_max) ? std::ceil(w_in) : w_max;
v = h_in - ti, u = w_in - li; v = h_in - ti, u = w_in - li;
ofs = idx[0] * H; offset = idx[0] * H;
dt = (1.f - v) * static_cast<float>(dy[i]); dt = (1.f - v) * static_cast<float>(dy[i]);
db = v * static_cast<float>(dy[i]); db = v * static_cast<float>(dy[i]);
dx[((ofs + ti) * W + li) * C + idx[3]] += (1.f - u) * dt; // tl dx[((offset + ti) * W + li) * C + idx[3]] += (1.f - u) * dt; // tl
dx[((ofs + ti) * W + ri) * C + idx[3]] += u * dt; // tr dx[((offset + ti) * W + ri) * C + idx[3]] += u * dt; // tr
dx[((ofs + bi) * W + li) * C + idx[3]] += (1.f - u) * db; // bl dx[((offset + bi) * W + li) * C + idx[3]] += (1.f - u) * db; // bl
dx[((ofs + bi) * W + ri) * C + idx[3]] += u * db; // br dx[((offset + bi) * W + ri) * C + idx[3]] += u * db; // br
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data()); utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
} }
} }
......
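For reference, the bilinear sampling these CPU kernels implement can be written as a small standalone routine. A minimal C++ sketch follows; the function names and the exact half-pixel transform are our assumptions, mirroring TransformCoordinate rather than quoting it:

#include <algorithm>
#include <cmath>

// Map an output index back to a fractional input coordinate. With
// align_corners the endpoints line up exactly; otherwise the common
// half-pixel convention is assumed here.
static float TransformCoord(int i, float scale, bool align_corners) {
  return align_corners ? i * scale : (i + 0.5f) * scale - 0.5f;
}

// Bilinear resize of one NCHW tensor: for each output pixel, sample the
// four neighboring input pixels and lerp along width (u) then height (v),
// exactly as the tl/tr/bl/br arithmetic in the kernels above.
void ResizeBilinearNCHW(
    int N, int C, int H, int W, int out_h, int out_w,
    bool align_corners, const float* x, float* y) {
  const float scale_h = align_corners ? float(H - 1) / std::max(out_h - 1, 1)
                                      : float(H) / out_h;
  const float scale_w = align_corners ? float(W - 1) / std::max(out_w - 1, 1)
                                      : float(W) / out_w;
  int i = 0;
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      const float* plane = x + (n * C + c) * H * W;
      for (int oh = 0; oh < out_h; ++oh) {
        const float h_in = std::max(TransformCoord(oh, scale_h, align_corners), 0.f);
        const int ti = int(h_in);
        const int bi = (h_in < H - 1) ? int(std::ceil(h_in)) : H - 1;
        const float v = h_in - ti;
        for (int ow = 0; ow < out_w; ++ow, ++i) {
          const float w_in = std::max(TransformCoord(ow, scale_w, align_corners), 0.f);
          const int li = int(w_in);
          const int ri = (w_in < W - 1) ? int(std::ceil(w_in)) : W - 1;
          const float u = w_in - li;
          const float tl = plane[ti * W + li], tr = plane[ti * W + ri];
          const float bl = plane[bi * W + li], br = plane[bi * W + ri];
          const float t = tl + (tr - tl) * u;  // top-row lerp along width
          const float b = bl + (br - bl) * u;  // bottom-row lerp along width
          y[i] = t + (b - t) * v;              // lerp along height
        }
      }
    }
  }
}

The NHWC variants above differ only in the index arithmetic: the channel index becomes the innermost term, which is why every load gains a "* C + idx[3]" suffix.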
...@@ -61,17 +61,17 @@ __global__ void _ResizeLinearNCHW( ...@@ -61,17 +61,17 @@ __global__ void _ResizeLinearNCHW(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1; const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li; const float u = w_in - li;
const int ofs = (n * C + c) * H; const int offset = (n * C + c) * H;
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
const float tl = __ldg(x + ((ofs + ti) * W + li)); const float tl = __ldg(x + ((offset + ti) * W + li));
const float tr = __ldg(x + ((ofs + ti) * W + ri)); const float tr = __ldg(x + ((offset + ti) * W + ri));
const float bl = __ldg(x + ((ofs + bi) * W + li)); const float bl = __ldg(x + ((offset + bi) * W + li));
const float br = __ldg(x + ((ofs + bi) * W + ri)); const float br = __ldg(x + ((offset + bi) * W + ri));
#else #else
const float tl = x[(ofs + ti) * W + li]; const float tl = x[(offset + ti) * W + li];
const float tr = x[(ofs + ti) * W + ri]; const float tr = x[(offset + ti) * W + ri];
const float bl = x[(ofs + bi) * W + li]; const float bl = x[(offset + bi) * W + li];
const float br = x[(ofs + bi) * W + ri]; const float br = x[(offset + bi) * W + ri];
#endif #endif
const float t = tl + (tr - tl) * u; const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u; const float b = bl + (br - bl) * u;
...@@ -109,11 +109,11 @@ __global__ void _ResizeLinearNCHW<half>( ...@@ -109,11 +109,11 @@ __global__ void _ResizeLinearNCHW<half>(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1; const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li; const float u = w_in - li;
const int ofs = (n * C + c) * H; const int offset = (n * C + c) * H;
const float tl = __half2float(__ldg(x + ((ofs + ti) * W + li))); const float tl = __half2float(__ldg(x + ((offset + ti) * W + li)));
const float tr = __half2float(__ldg(x + ((ofs + ti) * W + ri))); const float tr = __half2float(__ldg(x + ((offset + ti) * W + ri)));
const float bl = __half2float(__ldg(x + ((ofs + bi) * W + li))); const float bl = __half2float(__ldg(x + ((offset + bi) * W + li)));
const float br = __half2float(__ldg(x + ((ofs + bi) * W + ri))); const float br = __half2float(__ldg(x + ((offset + bi) * W + ri)));
const float t = tl + (tr - tl) * u; const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u; const float b = bl + (br - bl) * u;
...@@ -151,17 +151,17 @@ __global__ void _ResizeLinearNHWC( ...@@ -151,17 +151,17 @@ __global__ void _ResizeLinearNHWC(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1; const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li; const float u = w_in - li;
const int ofs = n * H; const int offset = n * H;
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
const float tl = __ldg(x + (((ofs + ti) * W + li) * C + c)); const float tl = __ldg(x + (((offset + ti) * W + li) * C + c));
const float tr = __ldg(x + (((ofs + ti) * W + ri) * C + c)); const float tr = __ldg(x + (((offset + ti) * W + ri) * C + c));
const float bl = __ldg(x + (((ofs + bi) * W + li) * C + c)); const float bl = __ldg(x + (((offset + bi) * W + li) * C + c));
const float br = __ldg(x + (((ofs + bi) * W + ri) * C + c)); const float br = __ldg(x + (((offset + bi) * W + ri) * C + c));
#else #else
const float tl = x[((ofs + ti) * W + li) * C + c]; const float tl = x[((offset + ti) * W + li) * C + c];
const float tr = x[((ofs + ti) * W + ri) * C + c]; const float tr = x[((offset + ti) * W + ri) * C + c];
const float bl = x[((ofs + bi) * W + li) * C + c]; const float bl = x[((offset + bi) * W + li) * C + c];
const float br = x[((ofs + bi) * W + ri) * C + c]; const float br = x[((offset + bi) * W + ri) * C + c];
#endif #endif
const float t = tl + (tr - tl) * u; const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u; const float b = bl + (br - bl) * u;
...@@ -199,11 +199,15 @@ __global__ void _ResizeLinearNHWC<half>( ...@@ -199,11 +199,15 @@ __global__ void _ResizeLinearNHWC<half>(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1; const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li; const float u = w_in - li;
const int ofs = n * H; const int offset = n * H;
const float tl = __half2float(__ldg(x + (((ofs + ti) * W + li) * C + c))); const float tl = __half2float(__ldg(x + (((offset + ti) * W + li) * C + c)));
const float tr = __half2float(__ldg(x + (((ofs + ti) * W + ri) * C + c))); const float tr = __half2float(__ldg(x + (((offset + ti) * W + ri) * C + c)));
const float bl = __half2float(__ldg(x + (((ofs + bi) * W + li) * C + c))); const float bl = __half2float(__ldg(x + (((offset + bi) * W + li) * C + c)));
const float br = __half2float(__ldg(x + (((ofs + bi) * W + ri) * C + c))); const float br = __half2float(__ldg(x + (((offset + bi) * W + ri) * C + c)));
const float t = tl + (tr - tl) * u; const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u; const float b = bl + (br - bl) * u;
...@@ -249,11 +253,11 @@ __global__ void _ResizeLinearGradNCHW( ...@@ -249,11 +253,11 @@ __global__ void _ResizeLinearGradNCHW(
const float db = v * ((float)dy[yi]); const float db = v * ((float)dy[yi]);
#endif #endif
const int ofs = (n * C + c) * H; const int offset = (n * C + c) * H;
atomicAdd(&dx[(ofs + ti) * W + li], (1.f - u) * dt); atomicAdd(&dx[(offset + ti) * W + li], (1.f - u) * dt);
atomicAdd(&dx[(ofs + ti) * W + ri], u * dt); atomicAdd(&dx[(offset + ti) * W + ri], u * dt);
atomicAdd(&dx[(ofs + bi) * W + li], (1.f - u) * db); atomicAdd(&dx[(offset + bi) * W + li], (1.f - u) * db);
atomicAdd(&dx[(ofs + bi) * W + ri], u * db); atomicAdd(&dx[(offset + bi) * W + ri], u * db);
} }
} }
...@@ -290,11 +294,11 @@ __global__ void _ResizeLinearGradNCHW<half>( ...@@ -290,11 +294,11 @@ __global__ void _ResizeLinearGradNCHW<half>(
const float dt = (1.f - v) * __half2float(__ldg(dy + yi)); const float dt = (1.f - v) * __half2float(__ldg(dy + yi));
const float db = v * __half2float(__ldg(dy + yi)); const float db = v * __half2float(__ldg(dy + yi));
const int ofs = (n * C + c) * H; const int offset = (n * C + c) * H;
atomicAdd(&dx[(ofs + ti) * W + li], (1.f - u) * dt); atomicAdd(&dx[(offset + ti) * W + li], (1.f - u) * dt);
atomicAdd(&dx[(ofs + ti) * W + ri], u * dt); atomicAdd(&dx[(offset + ti) * W + ri], u * dt);
atomicAdd(&dx[(ofs + bi) * W + li], (1.f - u) * db); atomicAdd(&dx[(offset + bi) * W + li], (1.f - u) * db);
atomicAdd(&dx[(ofs + bi) * W + ri], u * db); atomicAdd(&dx[(offset + bi) * W + ri], u * db);
#endif #endif
} }
} }
...@@ -336,11 +340,11 @@ __global__ void _ResizeLinearGradNHWC( ...@@ -336,11 +340,11 @@ __global__ void _ResizeLinearGradNHWC(
const float db = v * ((float)dy[yi]); const float db = v * ((float)dy[yi]);
#endif #endif
const int ofs = n * H; const int offset = n * H;
atomicAdd(&dx[((ofs + ti) * W + li) * C + c], (1.f - u) * dt); atomicAdd(&dx[((offset + ti) * W + li) * C + c], (1.f - u) * dt);
atomicAdd(&dx[((ofs + ti) * W + ri) * C + c], u * dt); atomicAdd(&dx[((offset + ti) * W + ri) * C + c], u * dt);
atomicAdd(&dx[((ofs + bi) * W + li) * C + c], (1.f - u) * db); atomicAdd(&dx[((offset + bi) * W + li) * C + c], (1.f - u) * db);
atomicAdd(&dx[((ofs + bi) * W + ri) * C + c], u * db); atomicAdd(&dx[((offset + bi) * W + ri) * C + c], u * db);
} }
} }
......
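The gradient kernels invert that lerp: each upstream value dy is split across the four sampled corners with weights (1-u)(1-v), u(1-v), (1-u)v and uv, which sum to one. Because neighboring output pixels can map to the same input corner, the CPU path uses "+=" and the CUDA path must use atomicAdd to avoid write races. A serial C++ sketch of the scatter (function name ours):

#include <cmath>

// Distribute one upstream gradient value over the four corners sampled
// in the forward pass; (h_in, w_in) are the same fractional coordinates.
inline void ScatterBilinearGrad(
    float dy, float h_in, float w_in, int H, int W, float* dx_plane) {
  const int ti = int(h_in), li = int(w_in);
  const int bi = (h_in < H - 1) ? int(std::ceil(h_in)) : H - 1;
  const int ri = (w_in < W - 1) ? int(std::ceil(w_in)) : W - 1;
  const float v = h_in - ti, u = w_in - li;
  const float dt = (1.f - v) * dy;  // share for the top row
  const float db = v * dy;          // share for the bottom row
  dx_plane[ti * W + li] += (1.f - u) * dt;  // tl
  dx_plane[ti * W + ri] += u * dt;          // tr
  dx_plane[bi * W + li] += (1.f - u) * db;  // bl
  dx_plane[bi * W + ri] += u * db;          // br
}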
...@@ -11,7 +11,6 @@ template <typename T> ...@@ -11,7 +11,6 @@ template <typename T>
void CuDNNEluOp<Context>::DoRunWithType() { void CuDNNEluOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0}); auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims()); CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
CUDNN_CHECK(cudnnActivationForward( CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
act_desc_, act_desc_,
...@@ -33,7 +32,6 @@ template <typename T> ...@@ -33,7 +32,6 @@ template <typename T>
void CuDNNEluGradientOp<Context>::DoRunWithType() { void CuDNNEluGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0); auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims()); CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
CUDNN_CHECK(cudnnActivationBackward( CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
act_desc_, act_desc_,
......
...@@ -7,7 +7,6 @@ template <class Context> ...@@ -7,7 +7,6 @@ template <class Context>
template <typename T> template <typename T>
void ReluOp<Context>::DoRunWithType() { void ReluOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0}); auto &X = Input(0), *Y = Output(0, {0});
if (max_value_ > 0.f) { if (max_value_ > 0.f) {
kernel::ReluN( kernel::ReluN(
X.count(), X.count(),
...@@ -34,7 +33,6 @@ template <class Context> ...@@ -34,7 +33,6 @@ template <class Context>
template <typename T> template <typename T>
void ReluGradientOp<Context>::DoRunWithType() { void ReluGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0); auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
if (max_value_ > 0.f) { if (max_value_ > 0.f) {
kernel::ReluNGrad( kernel::ReluNGrad(
Y.count(), Y.count(),
......
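When max_value_ > 0, ReluOp takes the clipped path: y = min(max(x, 0), N), i.e. ReLU-N (ReLU6 for N = 6), and the gradient flows only through the unsaturated region. A minimal sketch of both kernels; the signatures here are assumptions, not the actual kernel::ReluN interface:

#include <algorithm>

// Clipped ReLU forward: zero below 0, capped at N above.
void ReluN(int count, float N, const float* x, float* y) {
  for (int i = 0; i < count; ++i) {
    y[i] = std::min(std::max(x[i], 0.f), N);
  }
}

// Gradient passes only where the forward output was strictly inside (0, N).
void ReluNGrad(int count, float N, const float* dy, const float* y, float* dx) {
  for (int i = 0; i < count; ++i) {
    dx[i] = (y[i] > 0.f && y[i] < N) ? dy[i] : 0.f;
  }
}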
...@@ -9,7 +9,6 @@ template <typename T> ...@@ -9,7 +9,6 @@ template <typename T>
void CuDNNReluOp<Context>::DoRunWithType() { void CuDNNReluOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0}); auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims()); CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward( CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
...@@ -47,7 +46,6 @@ template <typename T> ...@@ -47,7 +46,6 @@ template <typename T>
void CuDNNReluGradientOp<Context>::DoRunWithType() { void CuDNNReluGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0); auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims()); CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward( CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
......
...@@ -9,7 +9,6 @@ template <typename T> ...@@ -9,7 +9,6 @@ template <typename T>
void CuDNNSigmoidOp<Context>::DoRunWithType() { void CuDNNSigmoidOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0}); auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims()); CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward( CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
...@@ -43,7 +42,6 @@ template <typename T> ...@@ -43,7 +42,6 @@ template <typename T>
void CuDNNSigmoidGradientOp<Context>::DoRunWithType() { void CuDNNSigmoidGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0); auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims()); CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward( CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
......
...@@ -9,10 +9,8 @@ template <typename T> ...@@ -9,10 +9,8 @@ template <typename T>
void CuDNNSoftmaxOp<Context>::DoRunWithType() { void CuDNNSoftmaxOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0}); auto &X = Input(0), *Y = Output(0, {0});
CANONICALIZE_AXIS_WITH_TENSOR(X); CANONICALIZE_AXIS_WITH_TENSOR(X);
CuDNNSetTensorDesc<T>( CuDNNSetTensorDesc<T>(
&input_desc_, {X.count(0, axis), X.dim(axis), X.count(axis + 1)}); &input_desc_, {X.count(0, axis), X.dim(axis), X.count(axis + 1)});
CUDNN_CHECK(cudnnSoftmaxForward( CUDNN_CHECK(cudnnSoftmaxForward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_ACCURATE,
...@@ -35,10 +33,8 @@ template <typename T> ...@@ -35,10 +33,8 @@ template <typename T>
void CuDNNSoftmaxGradientOp<Context>::DoRunWithType() { void CuDNNSoftmaxGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0); auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(Y); CANONICALIZE_AXIS_WITH_TENSOR(Y);
CuDNNSetTensorDesc<T>( CuDNNSetTensorDesc<T>(
&input_desc_, {Y.count(0, axis), Y.dim(axis), Y.count(axis + 1)}); &input_desc_, {Y.count(0, axis), Y.dim(axis), Y.count(axis + 1)});
CUDNN_CHECK(cudnnSoftmaxBackward( CUDNN_CHECK(cudnnSoftmaxBackward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_ACCURATE,
......
...@@ -9,7 +9,6 @@ template <typename T> ...@@ -9,7 +9,6 @@ template <typename T>
void CuDNNTanhOp<Context>::DoRunWithType() { void CuDNNTanhOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0}); auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims()); CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward( CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
...@@ -43,7 +42,6 @@ template <typename T> ...@@ -43,7 +42,6 @@ template <typename T>
void CuDNNTanhGradientOp<Context>::DoRunWithType() { void CuDNNTanhGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0); auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims()); CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward( CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
......
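The Elu/Relu/Sigmoid/Tanh ops above all follow the same cuDNN pattern: describe x (and reuse the descriptor for y, since activations are shape-preserving), then call cudnnActivationForward. A hedged standalone sketch against the public cuDNN API, with error checks elided; x and y are device pointers:

#include <cudnn.h>

// One-shot activation: mode selects CUDNN_ACTIVATION_RELU / _SIGMOID /
// _TANH / _ELU; for ELU the coef argument is the alpha parameter.
void ActivationForward(
    cudnnHandle_t handle, cudnnActivationMode_t mode, double coef,
    int n, int c, int h, int w, const float* x, float* y) {
  cudnnTensorDescriptor_t desc;
  cudnnActivationDescriptor_t act;
  cudnnCreateTensorDescriptor(&desc);
  cudnnCreateActivationDescriptor(&act);
  cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w);
  cudnnSetActivationDescriptor(act, mode, CUDNN_PROPAGATE_NAN, coef);
  const float alpha = 1.f, beta = 0.f;  // y = 1 * act(x) + 0 * y
  cudnnActivationForward(handle, act, &alpha, desc, x, &beta, desc, y);
  cudnnDestroyActivationDescriptor(act);
  cudnnDestroyTensorDescriptor(desc);
}

The gradient ops mirror this with cudnnActivationBackward; note they read the forward output Y as Input(0), since the gradients of these activations can all be expressed in terms of y alone.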
...@@ -64,9 +64,7 @@ void CastOp<Context>::RunOnDevice() { ...@@ -64,9 +64,7 @@ void CastOp<Context>::RunOnDevice() {
STORE_INPUT_SPEC(0); STORE_INPUT_SPEC(0);
DISPATCH_WITH_TENSOR(Input(0)); DISPATCH_WITH_TENSOR(Input(0));
} else { } else {
Buffer("X[" + std::to_string(0) + "]") Buffer("X_spec:0")->ReshapeLike(*Output(0))->set_meta(Output(0)->meta());
->ReshapeLike(*Output(0))
->set_meta(Output(0)->meta());
DISPATCH_WITH_TENSOR((*Output(0))); DISPATCH_WITH_TENSOR((*Output(0)));
}; };
} }
......
...@@ -26,7 +26,9 @@ namespace dragon { ...@@ -26,7 +26,9 @@ namespace dragon {
axes_(OpArgs<int64_t>("axes")), \ axes_(OpArgs<int64_t>("axes")), \
keep_dims_(OpArg<int64_t>("keep_dims", 0)) {} \ keep_dims_(OpArg<int64_t>("keep_dims", 0)) {} \
USE_OPERATOR_FUNCTIONS; \ USE_OPERATOR_FUNCTIONS; \
\
void RunOnDevice() override; \ void RunOnDevice() override; \
\
template <typename T> \ template <typename T> \
void DoRunWithType(); \ void DoRunWithType(); \
\ \
...@@ -41,7 +43,9 @@ namespace dragon { ...@@ -41,7 +43,9 @@ namespace dragon {
public: \ public: \
SIMPLE_CTOR_DTOR(name##GradientOp); \ SIMPLE_CTOR_DTOR(name##GradientOp); \
USE_OPERATOR_FUNCTIONS; \ USE_OPERATOR_FUNCTIONS; \
\
void RunOnDevice() override; \ void RunOnDevice() override; \
\
template <typename T> \ template <typename T> \
void DoRunWithType(); \ void DoRunWithType(); \
}; };
......
...@@ -15,9 +15,8 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -15,9 +15,8 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() {
auto inner_dim = X.count(axis + 1); auto inner_dim = X.count(axis + 1);
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
CHECK_EQ(num_preds, Input(1).count()) CHECK_EQ(X.count(), Input(1).count())
<< "\nNumber of preds must match the number of targets."; << "\nNumber of preds must match the number of targets.";
Buffer("prob")->ReshapeLike(X); Buffer("prob")->ReshapeLike(X);
auto* loss = ws()->template data<T, Context>({X.count()})[0]; auto* loss = ws()->template data<T, Context>({X.count()})[0];
......
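The corrected check compares against X.count() rather than outer_dim * inner_dim: this op takes dense targets, a full probability distribution per prediction, so the target tensor must match the logits element for element. A per-row sketch of the loss (names ours):

#include <algorithm>
#include <cmath>
#include <vector>

// Dense softmax cross entropy for one row of `dim` logits x against a
// distribution t: loss = -sum_j t[j] * log(softmax(x)[j]). The max
// subtraction keeps exp() in range, as CUDNN_SOFTMAX_ACCURATE does.
float SoftmaxCrossEntropyRow(int dim, const float* x, const float* t) {
  float x_max = x[0];
  for (int j = 1; j < dim; ++j) x_max = std::max(x_max, x[j]);
  std::vector<float> p(dim);
  float sum = 0.f;
  for (int j = 0; j < dim; ++j) sum += (p[j] = std::exp(x[j] - x_max));
  float loss = 0.f;
  for (int j = 0; j < dim; ++j) loss -= t[j] * std::log(p[j] / sum);
  return loss;
}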
...@@ -17,6 +17,8 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -17,6 +17,8 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
CHECK_EQ(num_preds, Input(1).count()) CHECK_EQ(num_preds, Input(1).count())
<< "\nNumber of preds must match the number of targets."; << "\nNumber of preds must match the number of targets.";
auto* X_prob = Buffer("prob")->ReshapeLike(X);
auto* prob = X_prob->template mutable_data<LogitType, Context>();
auto scratches = ws()->template data<Context>({ auto scratches = ws()->template data<Context>({
num_preds * sizeof(LogitType), // loss num_preds * sizeof(LogitType), // loss
...@@ -25,10 +27,6 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -25,10 +27,6 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
auto* loss = static_cast<LogitType*>(scratches[0]); auto* loss = static_cast<LogitType*>(scratches[0]);
auto* mask = static_cast<int*>(scratches[1]); auto* mask = static_cast<int*>(scratches[1]);
auto* prob = Buffer("prob")
->ReshapeLike(X)
->template mutable_data<LogitType, Context>();
kernel::Softmax( kernel::Softmax(
outer_dim, outer_dim,
X.dim(axis), X.dim(axis),
......
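The sparse variant's reshuffle above is about lifetime, not math: the "prob" buffer is now reshaped and its pointer taken before the scratch allocations, so the softmax output persists for the backward pass. Its per-prediction loss is simply -log(prob[label]), with ignored labels masked out; in this sketch, ignore_index and the int mask convention are assumptions:

#include <algorithm>
#include <cmath>

// Sparse cross entropy for one prediction: the integer label picks one
// probability; ignored rows contribute zero loss and a zero mask so
// they drop out of the loss normalization.
float SparseCrossEntropyRow(
    const float* prob, int label, int ignore_index, int* mask) {
  if (label == ignore_index) {
    *mask = 0;
    return 0.f;
  }
  *mask = 1;
  return -std::log(std::max(prob[label], 1e-20f));
}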
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MATH_ACCUMULATE_OP_H_
#define DRAGON_OPERATORS_MATH_ACCUMULATE_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class AccumulateOp final : public Operator<Context> {
public:
AccumulateOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OpArg<float>("alpha", 1.f)),
beta_(OpArg<float>("beta", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType(Tensor* X, Tensor* Y);
protected:
float alpha_, beta_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MATH_ACCUMULATE_OP_H_
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/operators/math/accumulate_op.h" #include "dragon/operators/math/elementwise_ops.h"
#include "dragon/utils/math_functions.h" #include "dragon/utils/math_functions.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void AccumulateOp<Context>::DoRunWithType(Tensor* X, Tensor* Y) { void AxpbyOp<Context>::DoRunWithType(Tensor* X, Tensor* Y) {
CHECK_EQ(X->count(), Y->count()); CHECK_EQ(X->count(), Y->count());
auto* x = X->template data<T, Context>(); auto* x = X->template data<T, Context>();
auto* y = Y->template mutable_data<T, Context>(); auto* y = Y->template mutable_data<T, Context>();
...@@ -26,41 +26,42 @@ void AccumulateOp<Context>::DoRunWithType(Tensor* X, Tensor* Y) { ...@@ -26,41 +26,42 @@ void AccumulateOp<Context>::DoRunWithType(Tensor* X, Tensor* Y) {
} }
template <class Context> template <class Context>
void AccumulateOp<Context>::RunOnDevice() { void AxpbyOp<Context>::RunOnDevice() {
for (int i = 0; i < InputSize(); i++) { for (int i = 0; i < InputSize(); i++) {
Output(i)->ReshapeLike(Input(i)); auto &X = Input(i), *Y = Output(i);
if (XIsType(Input(i), int8_t)) { Y->ReshapeLike(X);
DoRunWithType<int8_t>(&Input(i), Output(i)); if (XIsType(X, int8_t)) {
} else if (XIsType(Input(i), uint8_t)) { DoRunWithType<int8_t>(&X, Y);
DoRunWithType<uint8_t>(&Input(i), Output(i)); } else if (XIsType(X, uint8_t)) {
} else if (XIsType(Input(i), int)) { DoRunWithType<uint8_t>(&X, Y);
DoRunWithType<int>(&Input(i), Output(i)); } else if (XIsType(X, int)) {
} else if (XIsType(Input(i), int64_t)) { DoRunWithType<int>(&X, Y);
DoRunWithType<int64_t>(&Input(i), Output(i)); } else if (XIsType(X, int64_t)) {
} else if (XIsType(Input(i), float16)) { DoRunWithType<int64_t>(&X, Y);
DoRunWithType<float16>(&Input(i), Output(i)); } else if (XIsType(X, float16)) {
} else if (XIsType(Input(i), float)) { DoRunWithType<float16>(&X, Y);
DoRunWithType<float>(&Input(i), Output(i)); } else if (XIsType(X, float)) {
} else if (XIsType(Input(i), double)) { DoRunWithType<float>(&X, Y);
DoRunWithType<double>(&Input(i), Output(i)); } else if (XIsType(X, double)) {
DoRunWithType<double>(&X, Y);
} else } else
LOG(FATAL) << TypeString( LOG(FATAL) << TypeString(
Input(i), X,
{"int8", "uint8", "int32", "int64", "float16", "float32", "float64"}); {"int8", "uint8", "int32", "int64", "float16", "float32", "float64"});
} }
} }
DEPLOY_CPU(Accumulate); DEPLOY_CPU(Axpby);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA(Accumulate); DEPLOY_CUDA(Axpby);
#endif #endif
OPERATOR_SCHEMA(Accumulate) OPERATOR_SCHEMA(Axpby)
/* X1, ... */ /* X1, ... */
.NumInputs(1, INT_MAX) .NumInputs(1, INT_MAX)
/* Y1, ... */ /* Y1, ... */
.NumOutputs(1, INT_MAX); .NumOutputs(1, INT_MAX);
NO_GRADIENT(Accumulate); NO_GRADIENT(Axpby);
} // namespace dragon } // namespace dragon
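The rename from Accumulate to Axpby also documents the math: for each input/output pair the op computes the BLAS-style update Y = alpha * X + beta * Y. A one-loop sketch:

// Axpby per tensor pair: beta = 0 degenerates to a scaled copy,
// beta = 1 to a scaled accumulation into Y.
void Axpby(int count, float alpha, const float* x, float beta, float* y) {
  for (int i = 0; i < count; ++i) {
    y[i] = alpha * x[i] + beta * y[i];
  }
}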
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MATH_DOT_OP_H_
#define DRAGON_OPERATORS_MATH_DOT_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class DotOp final : public Operator<Context> {
public:
DotOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
transA_(OpArg<bool>("transA", false)),
transB_(OpArg<bool>("transB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DotImpl();
template <typename T>
void GemmImpl();
template <typename T>
void GemvImpl();
template <typename T>
void DoRunWithType();
protected:
int64_t transA_, transB_;
int64_t M_, K1_, K2_, N_;
int64_t M1_, N1_, M2_, N2_;
};
template <class Context>
class DotGradientOp final : public Operator<Context> {
public:
DotGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
transA_(OpArg<bool>("transA", false)),
transB_(OpArg<bool>("transB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DotImpl();
template <typename T>
void GemmImpl();
template <typename T>
void GemvImpl();
template <typename T>
void DoRunWithType();
protected:
int64_t transA_, transB_;
int64_t M_, K1_, K2_, N_;
int64_t M1_, N1_, M2_, N2_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MATH_DOT_OP_H_
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace dragon { namespace dragon {
#define DECLARE_SIMPLE_UNARY_OP(name) \ #define DECLARE_ELEMENTWISE_OP(name) \
template <class Context> \ template <class Context> \
class name##Op final : public Operator<Context> { \ class name##Op final : public Operator<Context> { \
public: \ public: \
...@@ -31,18 +31,23 @@ namespace dragon { ...@@ -31,18 +31,23 @@ namespace dragon {
void DoRunWithType(); \ void DoRunWithType(); \
}; };
#define DECLARE_SIMPLE_BINARY_OP(name) \ template <class Context>
template <class Context> \ class AxpbyOp final : public Operator<Context> {
class name##Op final : public Operator<Context> { \ public:
public: \ AxpbyOp(const OperatorDef& def, Workspace* ws)
SIMPLE_CTOR_DTOR(name##Op); \ : Operator<Context>(def, ws),
USE_OPERATOR_FUNCTIONS; \ alpha_(OpArg<float>("alpha", 1.f)),
\ beta_(OpArg<float>("beta", 1.f)) {}
void RunOnDevice() override; \ USE_OPERATOR_FUNCTIONS;
\
template <typename T> \ void RunOnDevice() override;
void DoRunWithType(); \
}; template <typename T>
void DoRunWithType(Tensor* X, Tensor* Y);
protected:
float alpha_, beta_;
};
inline vec32_t CheckOutputAliases( inline vec32_t CheckOutputAliases(
const Tensor& A, const Tensor& A,
...@@ -64,87 +69,61 @@ inline vec32_t CheckOutputAliases( ...@@ -64,87 +69,61 @@ inline vec32_t CheckOutputAliases(
return available_aliases; return available_aliases;
} }
inline void IsBroadcast( // Unary ElementwiseOp
const Tensor& A, DECLARE_ELEMENTWISE_OP(Abs);
const Tensor& B, DECLARE_ELEMENTWISE_OP(Ceil);
int& rows, DECLARE_ELEMENTWISE_OP(Cos);
int& cols, DECLARE_ELEMENTWISE_OP(Exp);
int& kind, DECLARE_ELEMENTWISE_OP(Floor);
Tensor* Y = nullptr) { DECLARE_ELEMENTWISE_OP(IsInf);
kind = -2; DECLARE_ELEMENTWISE_OP(IsNaN);
if (A.count() == B.count()) { DECLARE_ELEMENTWISE_OP(Log);
if (Y != nullptr) Y->ReshapeLike(A); DECLARE_ELEMENTWISE_OP(Neg);
kind = -1; DECLARE_ELEMENTWISE_OP(Invert);
} else if (B.count() < A.count()) { DECLARE_ELEMENTWISE_OP(Reciprocal);
if (Y != nullptr) Y->ReshapeLike(A); DECLARE_ELEMENTWISE_OP(Round);
if (utils::math::IsRowwiseBroadcast(A.dims(), B.dims(), &rows, &cols)) { DECLARE_ELEMENTWISE_OP(Rsqrt);
kind = 0; DECLARE_ELEMENTWISE_OP(Sign);
} else if (utils::math::IsColwiseBroadcast( DECLARE_ELEMENTWISE_OP(Sin);
A.dims(), B.dims(), &rows, &cols)) { DECLARE_ELEMENTWISE_OP(Sqrt);
kind = 1; DECLARE_ELEMENTWISE_OP(Square);
} DECLARE_ELEMENTWISE_OP(AbsGradient);
} else { DECLARE_ELEMENTWISE_OP(CosGradient);
if (Y != nullptr) Y->ReshapeLike(B); DECLARE_ELEMENTWISE_OP(ExpGradient);
if (utils::math::IsRowwiseBroadcast(A.dims(), B.dims(), &rows, &cols)) { DECLARE_ELEMENTWISE_OP(LogGradient);
kind = 2; DECLARE_ELEMENTWISE_OP(NegGradient);
} else if (utils::math::IsColwiseBroadcast( DECLARE_ELEMENTWISE_OP(ReciprocalGradient);
A.dims(), B.dims(), &rows, &cols)) { DECLARE_ELEMENTWISE_OP(RsqrtGradient);
kind = 3; DECLARE_ELEMENTWISE_OP(SignGradient);
} DECLARE_ELEMENTWISE_OP(SinGradient);
} DECLARE_ELEMENTWISE_OP(SqrtGradient);
} DECLARE_ELEMENTWISE_OP(SquareGradient);
DECLARE_SIMPLE_UNARY_OP(Abs); // Binary ElementwiseOp
DECLARE_SIMPLE_UNARY_OP(Ceil); DECLARE_ELEMENTWISE_OP(Add);
DECLARE_SIMPLE_UNARY_OP(Cos); DECLARE_ELEMENTWISE_OP(Sub);
DECLARE_SIMPLE_UNARY_OP(Exp); DECLARE_ELEMENTWISE_OP(Mul);
DECLARE_SIMPLE_UNARY_OP(Floor); DECLARE_ELEMENTWISE_OP(Div);
DECLARE_SIMPLE_UNARY_OP(IsInf); DECLARE_ELEMENTWISE_OP(Pow);
DECLARE_SIMPLE_UNARY_OP(IsNaN); DECLARE_ELEMENTWISE_OP(Dot);
DECLARE_SIMPLE_UNARY_OP(Log); DECLARE_ELEMENTWISE_OP(Minimum);
DECLARE_SIMPLE_UNARY_OP(Neg); DECLARE_ELEMENTWISE_OP(Maximum);
DECLARE_SIMPLE_UNARY_OP(Invert); DECLARE_ELEMENTWISE_OP(Equal);
DECLARE_SIMPLE_UNARY_OP(Reciprocal); DECLARE_ELEMENTWISE_OP(NotEqual);
DECLARE_SIMPLE_UNARY_OP(Round); DECLARE_ELEMENTWISE_OP(Less);
DECLARE_SIMPLE_UNARY_OP(Rsqrt); DECLARE_ELEMENTWISE_OP(LessEqual);
DECLARE_SIMPLE_UNARY_OP(Sign); DECLARE_ELEMENTWISE_OP(Greater);
DECLARE_SIMPLE_UNARY_OP(Sin); DECLARE_ELEMENTWISE_OP(GreaterEqual);
DECLARE_SIMPLE_UNARY_OP(Sqrt); DECLARE_ELEMENTWISE_OP(AddGradient);
DECLARE_SIMPLE_UNARY_OP(Square); DECLARE_ELEMENTWISE_OP(SubGradient);
DECLARE_SIMPLE_UNARY_OP(AbsGradient); DECLARE_ELEMENTWISE_OP(MulGradient);
DECLARE_SIMPLE_UNARY_OP(CosGradient); DECLARE_ELEMENTWISE_OP(DivGradient);
DECLARE_SIMPLE_UNARY_OP(ExpGradient); DECLARE_ELEMENTWISE_OP(PowGradient);
DECLARE_SIMPLE_UNARY_OP(LogGradient); DECLARE_ELEMENTWISE_OP(DotGradient);
DECLARE_SIMPLE_UNARY_OP(NegGradient); DECLARE_ELEMENTWISE_OP(MinimumGradient);
DECLARE_SIMPLE_UNARY_OP(ReciprocalGradient); DECLARE_ELEMENTWISE_OP(MaximumGradient);
DECLARE_SIMPLE_UNARY_OP(RsqrtGradient);
DECLARE_SIMPLE_UNARY_OP(SignGradient);
DECLARE_SIMPLE_UNARY_OP(SinGradient);
DECLARE_SIMPLE_UNARY_OP(SqrtGradient);
DECLARE_SIMPLE_UNARY_OP(SquareGradient);
#undef DECLARE_SIMPLE_UNARY_OP
DECLARE_SIMPLE_BINARY_OP(Add); #undef DECLARE_ELEMENTWISE_OP
DECLARE_SIMPLE_BINARY_OP(Sub);
DECLARE_SIMPLE_BINARY_OP(Mul);
DECLARE_SIMPLE_BINARY_OP(Div);
DECLARE_SIMPLE_BINARY_OP(Pow);
DECLARE_SIMPLE_BINARY_OP(Minimum);
DECLARE_SIMPLE_BINARY_OP(Maximum);
DECLARE_SIMPLE_BINARY_OP(Equal);
DECLARE_SIMPLE_BINARY_OP(NotEqual);
DECLARE_SIMPLE_BINARY_OP(Less);
DECLARE_SIMPLE_BINARY_OP(LessEqual);
DECLARE_SIMPLE_BINARY_OP(Greater);
DECLARE_SIMPLE_BINARY_OP(GreaterEqual);
DECLARE_SIMPLE_BINARY_OP(AddGradient);
DECLARE_SIMPLE_BINARY_OP(SubGradient);
DECLARE_SIMPLE_BINARY_OP(MulGradient);
DECLARE_SIMPLE_BINARY_OP(DivGradient);
DECLARE_SIMPLE_BINARY_OP(PowGradient);
DECLARE_SIMPLE_BINARY_OP(MinimumGradient);
DECLARE_SIMPLE_BINARY_OP(MaximumGradient);
#undef DECLARE_SIMPLE_BINARY_OP
} // namespace dragon } // namespace dragon
......
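Collapsing the unary and binary macros into one DECLARE_ELEMENTWISE_OP works because both families declare an identical interface; only the kernel called inside DoRunWithType differs. For instance, DECLARE_ELEMENTWISE_OP(Abs) expands to roughly the following (SIMPLE_CTOR_DTOR and USE_OPERATOR_FUNCTIONS are the repo's own helper macros):

template <class Context>
class AbsOp final : public Operator<Context> {
 public:
  SIMPLE_CTOR_DTOR(AbsOp);
  USE_OPERATOR_FUNCTIONS;

  void RunOnDevice() override;

  template <typename T>
  void DoRunWithType();
};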
...@@ -13,14 +13,14 @@ void FullyConnectedOp<Context>::DoRunWithType() { ...@@ -13,14 +13,14 @@ void FullyConnectedOp<Context>::DoRunWithType() {
// Determine the number of output channels // Determine the number of output channels
int64_t M = X.count(0, axis), K = X.count(axis), N; int64_t M = X.count(0, axis), K = X.count(axis), N;
if (num_output_ <= 0) { if (out_channels_ <= 0) {
// Infer the "N" from the weights shape // Infer the "N" from the weights shape
N = W.count() / K; N = W.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer the N from " CHECK_GT(N, 0) << "\nFailed to infer the N from "
<< "the weights shape: " << W.DimString(); << "the weights shape: " << W.DimString();
} else { } else {
// Use a fixed "N" from the argument // Use a fixed "N" from the argument
N = num_output_; N = out_channels_;
} }
vec64_t Y_dims(axis + 1); vec64_t Y_dims(axis + 1);
...@@ -82,14 +82,14 @@ void FullyConnectedGradientOp<Context>::DoRunWithType() { ...@@ -82,14 +82,14 @@ void FullyConnectedGradientOp<Context>::DoRunWithType() {
// Determine the number of output channels // Determine the number of output channels
int64_t M = X.count(0, axis), K = X.count(axis), N; int64_t M = X.count(0, axis), K = X.count(axis), N;
if (num_output_ <= 0) { if (out_channels_ <= 0) {
// Infer the "N" from the weights shape // Infer the "N" from the weights shape
N = W.count() / K; N = W.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer the N from " CHECK_GT(N, 0) << "\nFailed to infer the N from "
<< "the weights shape: " << W.DimString(); << "the weights shape: " << W.DimString();
} else { } else {
// Use a fixed "N" from the argument // Use a fixed "N" from the argument
N = num_output_; N = out_channels_;
} }
if (dX->has_name()) { if (dX->has_name()) {
......
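The renamed out_channels_ keeps its inference behavior: with no explicit value, N falls out of the weight element count divided by the flattened input feature size K = X.count(axis), after which the forward pass is the GEMM Y[M,N] = X[M,K] * W^T (W stored as [N, K] under the default transW). A small sketch of the inference step:

#include <cstdint>

// Infer the output channels when the argument is absent: a [N, K]
// (or [K, N]) weight has exactly N * K elements, so N = count / K.
int64_t InferOutChannels(int64_t weight_count, int64_t K, int64_t out_channels) {
  if (out_channels > 0) return out_channels;  // fixed by the argument
  return weight_count / K;                    // inferred from the weight shape
}

For example, a classifier weight of 1000 x 2048 elements with K = 2048 yields N = 1000, and the CHECK_GT guards against a weight too small to contain even one output row.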
...@@ -22,7 +22,7 @@ class FullyConnectedOp final : public Operator<Context> { ...@@ -22,7 +22,7 @@ class FullyConnectedOp final : public Operator<Context> {
public: public:
FullyConnectedOp(const OperatorDef& def, Workspace* ws) FullyConnectedOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
num_output_(OpArg<int64_t>("num_output", 0)), out_channels_(OpArg<int64_t>("out_channels", 0)),
transW_(OpArg<int64_t>("transW", 1)) {} transW_(OpArg<int64_t>("transW", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -32,7 +32,7 @@ class FullyConnectedOp final : public Operator<Context> { ...@@ -32,7 +32,7 @@ class FullyConnectedOp final : public Operator<Context> {
void DoRunWithType(); void DoRunWithType();
protected: protected:
int64_t num_output_, transW_; int64_t out_channels_, transW_;
}; };
template <class Context> template <class Context>
...@@ -40,7 +40,7 @@ class FullyConnectedGradientOp final : public Operator<Context> { ...@@ -40,7 +40,7 @@ class FullyConnectedGradientOp final : public Operator<Context> {
public: public:
FullyConnectedGradientOp(const OperatorDef& def, Workspace* ws) FullyConnectedGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
num_output_(OpArg<int64_t>("num_output", 0)), out_channels_(OpArg<int64_t>("out_channels", 0)),
transW_(OpArg<int64_t>("transW", 1)) {} transW_(OpArg<int64_t>("transW", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -50,7 +50,7 @@ class FullyConnectedGradientOp final : public Operator<Context> { ...@@ -50,7 +50,7 @@ class FullyConnectedGradientOp final : public Operator<Context> {
void DoRunWithType(); void DoRunWithType();
protected: protected:
int64_t num_output_, transW_; int64_t out_channels_, transW_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -5,7 +5,7 @@ namespace dragon { ...@@ -5,7 +5,7 @@ namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void MatmulOp<Context>::DoRunWithType() { void MatMulOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), *Y = Output(0); auto &A = Input(0), &B = Input(1), *Y = Output(0);
CHECK_GE(A.ndim(), 2) << "\nTensor(" << A.name() + ") must be a matrix" CHECK_GE(A.ndim(), 2) << "\nTensor(" << A.name() + ") must be a matrix"
...@@ -51,13 +51,13 @@ void MatmulOp<Context>::DoRunWithType() { ...@@ -51,13 +51,13 @@ void MatmulOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void MatmulOp<Context>::RunOnDevice() { void MatMulOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
} }
template <class Context> template <class Context>
template <typename T> template <typename T>
void MatmulGradientOp<Context>::DoRunWithType() { void MatMulGradientOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), &dY = Input(2); auto &A = Input(0), &B = Input(1), &dY = Input(2);
auto *dA = Output(0), *dB = Output(1); auto *dA = Output(0), *dB = Output(1);
...@@ -154,32 +154,32 @@ void MatmulGradientOp<Context>::DoRunWithType() { ...@@ -154,32 +154,32 @@ void MatmulGradientOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void MatmulGradientOp<Context>::RunOnDevice() { void MatMulGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
} }
DEPLOY_CPU(Matmul); DEPLOY_CPU(MatMul);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA(Matmul); DEPLOY_CUDA(MatMul);
#endif #endif
DEPLOY_CPU(MatmulGradient); DEPLOY_CPU(MatMulGradient);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA(MatmulGradient); DEPLOY_CUDA(MatMulGradient);
#endif #endif
OPERATOR_SCHEMA(Matmul) OPERATOR_SCHEMA(MatMul)
/* A, B */ /* A, B */
.NumInputs(2) .NumInputs(2)
/* Y */ /* Y */
.NumOutputs(1); .NumOutputs(1);
OPERATOR_SCHEMA(MatmulGradient) OPERATOR_SCHEMA(MatMulGradient)
/* A, B, dY */ /* A, B, dY */
.NumInputs(3) .NumInputs(3)
/* dA, dB */ /* dA, dB */
.NumOutputs(2); .NumOutputs(2);
REGISTER_GRADIENT(Matmul, GenericGradientMaker); REGISTER_GRADIENT(MatMul, GenericGradientMaker);
} // namespace dragon } // namespace dragon
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class MatmulOp final : public Operator<Context> { class MatMulOp final : public Operator<Context> {
public: public:
MatmulOp(const OperatorDef& def, Workspace* ws) MatMulOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
transA_(OpArg<int64_t>("transA", 0)), transA_(OpArg<int64_t>("transA", 0)),
transB_(OpArg<int64_t>("transB", 0)) {} transB_(OpArg<int64_t>("transB", 0)) {}
...@@ -36,9 +36,9 @@ class MatmulOp final : public Operator<Context> { ...@@ -36,9 +36,9 @@ class MatmulOp final : public Operator<Context> {
}; };
template <class Context> template <class Context>
class MatmulGradientOp final : public Operator<Context> { class MatMulGradientOp final : public Operator<Context> {
public: public:
MatmulGradientOp(const OperatorDef& def, Workspace* ws) MatMulGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
transA_(OpArg<int64_t>("transA", 0)), transA_(OpArg<int64_t>("transA", 0)),
transB_(OpArg<int64_t>("transB", 0)) {} transB_(OpArg<int64_t>("transB", 0)) {}
......
...@@ -94,8 +94,8 @@ void GroupNormGradientOp<Context>::DoRunWithType() { ...@@ -94,8 +94,8 @@ void GroupNormGradientOp<Context>::DoRunWithType() {
template <class Context> template <class Context>
void GroupNormGradientOp<Context>::RunOnDevice() { void GroupNormGradientOp<Context>::RunOnDevice() {
DetermineBaseArguments(); DetermineBaseArguments();
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (XIsType(Input(0), float)) { if (XIsType(Input(0), float)) {
DoRunWithType<float, float>(); DoRunWithType<float, float>();
} else if (XIsType(Input(0), float16)) { } else if (XIsType(Input(0), float16)) {
......
...@@ -42,9 +42,6 @@ class GroupNormOpBase : public Operator<Context> { ...@@ -42,9 +42,6 @@ class GroupNormOpBase : public Operator<Context> {
// Check the channels and groups // Check the channels and groups
CHECK_EQ(C_ % G_, 0) << "\nThe " << C_ << " channels " CHECK_EQ(C_ % G_, 0) << "\nThe " << C_ << " channels "
<< "can not be split into " << G_ << " groups."; << "can not be split into " << G_ << " groups.";
if (G_ == C_ && X.ndim() == 2) {
LOG(WARNING) << "The 2d input will output all zeros.";
}
} }
protected: protected:
......
#include "dragon/operators/training/adam_update_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/operators/training/update_ops.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
void AdamUpdateOp<Context>::Compute(Tensor* dX) { void AdamUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
auto* m = ws()->CreateTensor("/mnt/" + slot() + "/m")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
auto* v = ws()->CreateTensor("/mnt/" + slot() + "/v")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
t_++; t_++;
auto beta1 = param("beta1"); auto beta1 = Parameter("beta1"), beta2 = Parameter("beta2");
auto beta2 = param("beta2");
auto coef = sqrt(1.f - pow(beta2, t_)) / (1.f - pow(beta1, t_)); auto coef = sqrt(1.f - pow(beta2, t_)) / (1.f - pow(beta1, t_));
kernel::AdamUpdate( kernel::AdamUpdate(
dX->count(), dX->count(),
param("base_lr") * coef * lr_mult(), Parameter("base_lr") * coef * this->lr_mult_,
beta1, beta1,
beta2, beta2,
param("eps"), Parameter("eps"),
dX->template mutable_data<float, Context>(), dX->template mutable_data<float, Context>(),
m, Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
v, Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx()); ctx());
} }
......
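The coef line folds both Adam bias corrections into the step size once per iteration: lr_t = base_lr * sqrt(1 - beta2^t) / (1 - beta1^t), which lets the "m" and "v" slots store uncorrected moments. A scalar sketch of the rule kernel::AdamUpdate presumably applies, writing the update back into the gradient buffer (the in-place convention ApplyUpdate then subtracts):

#include <cmath>

// One Adam step over a flat buffer: g holds the gradient on entry and
// the update (to be subtracted from the weights) on exit.
void AdamStep(int count, int t, float lr, float beta1, float beta2, float eps,
              float* g, float* m, float* v) {
  const float coef =
      std::sqrt(1.f - std::pow(beta2, t)) / (1.f - std::pow(beta1, t));
  for (int i = 0; i < count; ++i) {
    m[i] = beta1 * m[i] + (1.f - beta1) * g[i];        // first moment
    v[i] = beta2 * v[i] + (1.f - beta2) * g[i] * g[i]; // second moment
    g[i] = lr * coef * m[i] / (std::sqrt(v[i]) + eps);
  }
}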
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_ADAM_UPDATE_OP_H_
#define DRAGON_OPERATORS_TRAINING_ADAM_UPDATE_OP_H_
#include "dragon/operators/training/update_op_base.h"
namespace dragon {
template <class Context>
class AdamUpdateOp final : public UpdateOpBase<Context> {
public:
AdamUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws), t_(0) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void Compute(Tensor* dX) override;
protected:
int t_;
// float lr_, beta1_, beta2_, eps_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_ADAM_UPDATE_OP_H_
#include "dragon/operators/training/nesterov_update_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/operators/training/update_ops.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
void NesterovUpdateOp<Context>::Compute(Tensor* dX) { void NesterovUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
auto* m = ws()->CreateTensor("/mnt/" + slot() + "/m")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
kernel::NesterovUpdate( kernel::NesterovUpdate(
dX->count(), dX->count(),
param("base_lr") * lr_mult(), Parameter("base_lr") * this->lr_mult_,
param("momentum"), Parameter("momentum"),
dX->template mutable_data<float, Context>(), dX->template mutable_data<float, Context>(),
m, Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx()); ctx());
} }
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_NESTEROV_UPDATE_OP_H_
#define DRAGON_OPERATORS_TRAINING_NESTEROV_UPDATE_OP_H_
#include "dragon/operators/training/update_op_base.h"
namespace dragon {
template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> {
public:
NesterovUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void Compute(Tensor* dX) override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_NESTEROV_UPDATE_OP_H_
#include "dragon/operators/training/rmsprop_update_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/operators/training/update_ops.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
void RMSPropUpdateOp<Context>::Compute(Tensor* dX) { void RMSpropUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
auto* m = ws()->CreateTensor("/mnt/" + slot() + "/m")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
auto* v = ws()->CreateTensor("/mnt/" + slot() + "/v")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
kernel::RMSPropUpdate( kernel::RMSPropUpdate(
dX->count(), dX->count(),
param("base_lr") * lr_mult(), Parameter("base_lr") * this->lr_mult_,
param("momentum"), Parameter("momentum"),
param("decay"), Parameter("decay"),
param("eps"), Parameter("eps"),
dX->template mutable_data<float, Context>(), dX->template mutable_data<float, Context>(),
m, Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
v, Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx()); ctx());
} }
DEPLOY_CPU(RMSPropUpdate); DEPLOY_CPU(RMSpropUpdate);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA(RMSPropUpdate); DEPLOY_CUDA(RMSpropUpdate);
#endif #endif
OPERATOR_SCHEMA(RMSPropUpdate) OPERATOR_SCHEMA(RMSpropUpdate)
/* dX */ /* dX */
.NumInputs(1) .NumInputs(1)
/* X */ /* X */
.NumOutputs(1); .NumOutputs(1);
NO_GRADIENT(RMSPropUpdate); NO_GRADIENT(RMSpropUpdate);
} // namespace dragon } // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_RMSPROP_UPDATE_OP_H_
#define DRAGON_OPERATORS_TRAINING_RMSPROP_UPDATE_OP_H_
#include "dragon/operators/training/update_op_base.h"
namespace dragon {
template <class Context>
class RMSPropUpdateOp final : public UpdateOpBase<Context> {
public:
RMSPropUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void Compute(Tensor* dX) override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_RMSPROP_UPDATE_OP_H_
#include "dragon/operators/training/sgd_update_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/operators/training/update_ops.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
void SGDUpdateOp<Context>::Compute(Tensor* dX) { void SGDUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
auto* m = ws()->CreateTensor("/mnt/" + slot() + "/m")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
// Momentum Correction, See arXiv:1706.02677 // Momentum Correction, See arXiv:1706.02677
auto lr = param("base_lr") * lr_mult(); auto lr = Parameter("base_lr") * this->lr_mult_;
if (last_lr_ > 0) correction_ = lr / last_lr_; if (last_lr_ > 0) correction_ = lr / last_lr_;
last_lr_ = lr; // Record the last value last_lr_ = lr; // Record the last value
kernel::SGDUpdate( kernel::SGDUpdate(
dX->count(), dX->count(),
lr, lr,
param("momentum") * correction_, Parameter("momentum") * correction_,
dX->template mutable_data<float, Context>(), dX->template mutable_data<float, Context>(),
m, Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx()); ctx());
} }
......
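The correction factor implements the momentum correction of arXiv:1706.02677: the momentum buffer stores learning-rate-scaled history, so when the schedule changes the rate, multiplying the momentum term by lr / last_lr keeps the effective velocity consistent instead of letting stale-lr history dominate. A scalar sketch under the same in-place convention (our reading of kernel::SGDUpdate, not its verbatim body):

// One corrected momentum-SGD step: g holds the gradient on entry and
// the update on exit; m is the persistent momentum slot.
void SgdStep(int count, float lr, float momentum, float correction,
             float* g, float* m) {
  for (int i = 0; i < count; ++i) {
    m[i] = (momentum * correction) * m[i] + lr * g[i];
    g[i] = m[i];
  }
}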
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_SGD_UPDATE_OP_H_
#define DRAGON_OPERATORS_TRAINING_SGD_UPDATE_OP_H_
#include "dragon/operators/training/update_op_base.h"
namespace dragon {
template <class Context>
class SGDUpdateOp final : public UpdateOpBase<Context> {
public:
SGDUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws), last_lr_(-1.f), correction_(1.f) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void Compute(Tensor* dX) override;
protected:
float last_lr_, correction_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_SGD_UPDATE_OP_H_
#include "dragon/operators/training/update_op_base.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/utils/cast.h" #include "dragon/operators/training/update_ops.h"
#include "dragon/utils/math_functions.h" #include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
float UpdateOpBase<Context>::param(const string& name) const { Tensor* UpdateOpBase<Context>::Slot(const string& name) {
return ws() return Buffer(Output(0)->name() + "/" + name);
->GetTensor(slot_ + "/" + name) }
->template mutable_data<float, CPUContext>()[0];
template <class Context>
float UpdateOpBase<Context>::Parameter(const string& name) const {
auto* P = ws()->GetTensor("/share/hyper/" + handle() + "/" + name);
return P->template mutable_data<float, CPUContext>()[0];
} }
template <class Context> template <class Context>
template <typename T> template <typename T>
void UpdateOpBase<Context>::Process(Tensor* dX, Tensor* X) { void UpdateOpBase<Context>::AdjustGradient(Tensor* dX, Tensor* X) {
// Scale // Scale
auto scale_factor = param("scale_gradient"); auto scale = Parameter("scale");
if (scale_factor != 1.f) { if (scale != 1.f) {
auto* dx = dX->template mutable_data<T, Context>(); auto* dx = dX->template mutable_data<T, Context>();
math::Scale(dX->count(), scale_factor, dx, dx, ctx()); math::Scale(dX->count(), scale, dx, dx, ctx());
} }
// Clip // Clip
auto clip_thresh = param("clip_gradient"); auto clip_norm = Parameter("clip_norm");
if (clip_thresh > 0.f) { if (clip_norm > 0.f) {
T sumsq_grad;
auto* dx = dX->template mutable_data<T, Context>(); auto* dx = dX->template mutable_data<T, Context>();
math::Dot(dX->count(), dx, dx, &sumsq_grad, ctx()); auto grad_norm = std::sqrt(math::Dot(dX->count(), dx, dx, ctx()));
auto l2_norm = sqrt(cast::to<float>(sumsq_grad)); if (grad_norm > clip_norm) {
if (l2_norm > clip_thresh) { math::Scale(dX->count(), clip_norm / grad_norm, dx, dx, ctx());
math::Scale(dX->count(), clip_thresh / l2_norm, dx, dx, ctx());
} }
} }
// L2 Decay // Penalty
auto l2_decay = param("l2_decay") * decay_mult_; auto weight_decay = Parameter("weight_decay");
if (l2_decay > 0) { if (weight_decay > 0.f) {
if (XIsType((*X), float16)) { if (XIsType((*X), float16)) {
kernel::MixedPrecL2Decay( kernel::MixedPrecL2Penalty(
X->count(), X->count(),
l2_decay, weight_decay * decay_mult_,
X->template data<float16, Context>(), X->template data<float16, Context>(),
dX->template mutable_data<float, Context>(), dX->template mutable_data<float, Context>(),
ctx()); ctx());
} else { } else {
math::Axpy( math::Axpy(
X->count(), X->count(),
l2_decay, weight_decay * decay_mult_,
X->template data<T, Context>(), X->template data<T, Context>(),
dX->template mutable_data<T, Context>(), dX->template mutable_data<T, Context>(),
ctx()); ctx());
...@@ -56,7 +57,7 @@ void UpdateOpBase<Context>::Process(Tensor* dX, Tensor* X) { ...@@ -56,7 +57,7 @@ void UpdateOpBase<Context>::Process(Tensor* dX, Tensor* X) {
template <class Context> template <class Context>
template <typename T> template <typename T>
void UpdateOpBase<Context>::Apply(Tensor* dX, Tensor* X) { void UpdateOpBase<Context>::ApplyUpdate(Tensor* dX, Tensor* X) {
if (XIsType((*X), float16)) { if (XIsType((*X), float16)) {
kernel::MixedPrecUpdate( kernel::MixedPrecUpdate(
X->count(), X->count(),
...@@ -64,9 +65,9 @@ void UpdateOpBase<Context>::Apply(Tensor* dX, Tensor* X) { ...@@ -64,9 +65,9 @@ void UpdateOpBase<Context>::Apply(Tensor* dX, Tensor* X) {
X->template mutable_data<float16, Context>(), X->template mutable_data<float16, Context>(),
ctx()); ctx());
} else { } else {
math::Axpy( math::Sub(
X->count(), X->count(),
-1.f, X->template data<T, Context>(),
dX->template data<T, Context>(), dX->template data<T, Context>(),
X->template mutable_data<T, Context>(), X->template mutable_data<T, Context>(),
ctx()); ctx());
...@@ -85,19 +86,19 @@ void UpdateOpBase<Context>::RunOnDevice() { ...@@ -85,19 +86,19 @@ void UpdateOpBase<Context>::RunOnDevice() {
<< "\nGot" << X->DimString() << " and " << dX.DimString(); << "\nGot" << X->DimString() << " and " << dX.DimString();
if (XIsType(dX, float)) { if (XIsType(dX, float)) {
Process<float>(&dX, X); AdjustGradient<float>(&dX, X);
Compute(&dX); ComputeUpdate(&dX);
Apply<float>(&dX, X); ApplyUpdate<float>(&dX, X);
} else if (XIsType(dX, float16)) { } else if (XIsType(dX, float16)) {
auto* dX_fp32 = ws()->CreateTensor(dX.name() + "/fp32"); auto* dX_cast = ws()->CreateTensor(dX.name() + "[float32]");
kernel::Cast( kernel::Cast(
dX.count(), dX.count(),
dX.template data<float16, Context>(), dX.template data<float16, Context>(),
dX_fp32->ReshapeLike(dX)->template mutable_data<float, Context>(), dX_cast->ReshapeLike(dX)->template mutable_data<float, Context>(),
ctx()); ctx());
Process<float>(dX_fp32, X); AdjustGradient<float>(dX_cast, X);
Compute(dX_fp32); ComputeUpdate(dX_cast);
Apply<float>(dX_fp32, X); ApplyUpdate<float>(dX_cast, X);
} else { } else {
LOG(FATAL) << TypeString(dX, {"float16", "float32"}); LOG(FATAL) << TypeString(dX, {"float16", "float32"});
} }
......
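AdjustGradient applies the renamed hyperparameters in a fixed order: scale, clip by the gradient's global L2 norm (now computed via math::Dot directly instead of a host-side cast), then the L2 penalty weight_decay * decay_mult added into the gradient. A serial sketch of the float path:

#include <cmath>

// Pre-update gradient conditioning: rescale, clip-by-norm, L2 penalty.
void AdjustGradient(int count, float scale, float clip_norm, float decay,
                    const float* x, float* dx) {
  if (scale != 1.f) {
    for (int i = 0; i < count; ++i) dx[i] *= scale;
  }
  if (clip_norm > 0.f) {
    float sumsq = 0.f;
    for (int i = 0; i < count; ++i) sumsq += dx[i] * dx[i];
    const float norm = std::sqrt(sumsq);
    if (norm > clip_norm) {
      const float ratio = clip_norm / norm;
      for (int i = 0; i < count; ++i) dx[i] *= ratio;
    }
  }
  if (decay > 0.f) {
    for (int i = 0; i < count; ++i) dx[i] += decay * x[i];  // the Axpy call
  }
}

Note also the ApplyUpdate change from math::Axpy with alpha = -1 to math::Sub: both compute X = X - dX, but Sub states the intent directly.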
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_UPDATE_OP_BASE_H_
#define DRAGON_OPERATORS_TRAINING_UPDATE_OP_BASE_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class UpdateOpBase : public Operator<Context> {
public:
UpdateOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
lr_mult_(OpArg<float>("lr_mult", 1.f)),
decay_mult_(OpArg<float>("decay_mult", 1.f)),
slot_(OpArg<string>("slot", "")) {
CHECK(!slot_.empty()) << "\nRequired a non-empty slot";
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
virtual void Compute(Tensor* dX) = 0;
template <typename T>
void Process(Tensor* dX, Tensor* X);
template <typename T>
void Apply(Tensor* dX, Tensor* X);
string slot() {
return slot_ + "/" + Output(0)->name();
}
float param(const string& name) const;
float lr_mult() const {
return lr_mult_;
}
protected:
string slot_;
float lr_mult_, decay_mult_;
};
#define USE_PARAM_UPDATE_FUNCTIONS \
using UpdateOpBase<Context>::slot; \
using UpdateOpBase<Context>::param; \
using UpdateOpBase<Context>::lr_mult
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_UPDATE_OP_BASE_H_
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_UPDATE_OPS_H_
#define DRAGON_OPERATORS_TRAINING_UPDATE_OPS_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class UpdateOpBase : public Operator<Context> {
public:
UpdateOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
lr_mult_(OpArg<float>("lr_mult", 1.f)),
decay_mult_(OpArg<float>("decay_mult", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
virtual void ComputeUpdate(Tensor* dX) = 0;
template <typename T>
void AdjustGradient(Tensor* dX, Tensor* X);
template <typename T>
void ApplyUpdate(Tensor* dX, Tensor* X);
Tensor* Slot(const string& name);
float Parameter(const string& name) const;
protected:
float lr_mult_, decay_mult_;
};
#define USE_PARAM_UPDATE_FUNCTIONS \
using UpdateOpBase<Context>::Slot; \
using UpdateOpBase<Context>::Parameter
template <class Context>
class SGDUpdateOp final : public UpdateOpBase<Context> {
public:
SGDUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws), last_lr_(-1.f), correction_(1.f) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void ComputeUpdate(Tensor* dX) override;
protected:
float last_lr_, correction_;
};
template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> {
public:
NesterovUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void ComputeUpdate(Tensor* dX) override;
};
template <class Context>
class RMSpropUpdateOp final : public UpdateOpBase<Context> {
public:
RMSpropUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void ComputeUpdate(Tensor* dX) override;
};
template <class Context>
class AdamUpdateOp final : public UpdateOpBase<Context> {
public:
AdamUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws), t_(0) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void ComputeUpdate(Tensor* dX) override;
protected:
int t_;
};
#undef USE_PARAM_UPDATE_FUNCTIONS
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_UPDATE_OPS_H_
...@@ -59,9 +59,9 @@ void BiasAddGradientOp<Context>::DoRunWithType() { ...@@ -59,9 +59,9 @@ void BiasAddGradientOp<Context>::DoRunWithType() {
dB->Reshape({dY.dim(-1)}); dB->Reshape({dY.dim(-1)});
} }
math::ReduceSum( math::ReduceSum(
3, dims.size(),
dims.data(), dims.data(),
2, axes.size(),
axes.data(), axes.data(),
1.f, 1.f,
dY.template data<T, Context>(), dY.template data<T, Context>(),
......
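The fix in this hunk stops hard-coding the reduction rank (3) and axis count (2) and passes dims.size()/axes.size() instead. Semantically the bias gradient is a sum of dY over every axis except the channel axis; a NumPy equivalent of the channel-last branch shown above, as a sketch:

import numpy as np

# dY with layout (N, H, W, C): reduce over all axes except channels.
dY = np.random.randn(2, 4, 4, 8).astype('float32')
dB = dY.sum(axis=(0, 1, 2))
assert dB.shape == (8,)  # matches dB->Reshape({dY.dim(-1)})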
...@@ -16,7 +16,7 @@ void Conv2dOp<Context>::DoRunWithType() { ...@@ -16,7 +16,7 @@ void Conv2dOp<Context>::DoRunWithType() {
auto* y = Y->template mutable_data<T, Context>(); auto* y = Y->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) { for (int i = 0; i < X.dim(0); ++i) {
Wx(x + i * x_ofs_, w, y + i * y_ofs_); Wx(x + i * x_offset_, w, y + i * y_offset_);
} }
if (HasBias()) { if (HasBias()) {
...@@ -46,7 +46,7 @@ void Conv2dGradientOp<Context>::DoRunWithType() { ...@@ -46,7 +46,7 @@ void Conv2dGradientOp<Context>::DoRunWithType() {
auto* w = W.template data<T, Context>(); auto* w = W.template data<T, Context>();
auto* dx = dX->template mutable_data<T, Context>(); auto* dx = dX->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) { for (int i = 0; i < X.dim(0); ++i) {
Dx(dy + i * y_ofs_, w, dx + i * x_ofs_); Dx(dy + i * y_offset_, w, dx + i * x_offset_);
} }
} }
...@@ -55,7 +55,7 @@ void Conv2dGradientOp<Context>::DoRunWithType() { ...@@ -55,7 +55,7 @@ void Conv2dGradientOp<Context>::DoRunWithType() {
auto* x = X.template data<T, Context>(); auto* x = X.template data<T, Context>();
auto* dw = dW->template mutable_data<T, Context>(); auto* dw = dW->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) { for (int i = 0; i < X.dim(0); ++i) {
Dw(dy + i * y_ofs_, x + i * x_ofs_, dw, i > 0); Dw(dy + i * y_offset_, x + i * x_offset_, dw, i > 0);
} }
} }
......
...@@ -73,8 +73,8 @@ void CuDNNConv2dOp<Context>::ResetDesc() { ...@@ -73,8 +73,8 @@ void CuDNNConv2dOp<Context>::ResetDesc() {
filter_desc_, filter_desc_,
CuDNNType<T>::type, CuDNNType<T>::type,
format_, format_,
num_output_ / cudnn_group_, out_channels_ / cudnn_group_,
channels_ / group_, in_channels_ / group_,
kshape_[0], kshape_[0],
kshape_[1])); kshape_[1]));
#else #else
...@@ -82,14 +82,15 @@ void CuDNNConv2dOp<Context>::ResetDesc() { ...@@ -82,14 +82,15 @@ void CuDNNConv2dOp<Context>::ResetDesc() {
filter_desc_, filter_desc_,
CuDNNType<T>::type, CuDNNType<T>::type,
format_, format_,
num_output_ / cudnn_group_, out_channels_ / cudnn_group_,
channels_ / group_, in_channels_ / group_,
kshape_[0], kshape_[0],
kshape_[1])); kshape_[1]));
#endif #endif
// Determine the bias shape // Determine the bias shape
if (HasBias()) { if (HasBias()) {
CuDNNSetBiasDesc<T>(&bias_desc_, X.ndim(), num_output_, data_format()); CuDNNSetBiasDesc<T>(
&bias_desc_, X.ndim(), out_channels_, data_format());
} }
} }
// Set the conv configuration // Set the conv configuration
...@@ -179,16 +180,16 @@ void CuDNNConv2dOp<Context>::DoRunWithType() { ...@@ -179,16 +180,16 @@ void CuDNNConv2dOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
input_desc_, input_desc_,
x + x_ofs_ * g, x + x_offset_ * g,
filter_desc_, filter_desc_,
w + w_ofs_ * g, w + w_offset_ * g,
conv_desc_, conv_desc_,
fwd_algo_, fwd_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_nbytes_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
output_desc_, output_desc_,
y + y_ofs_ * g)); y + y_offset_ * g));
} }
if (HasBias()) { if (HasBias()) {
...@@ -217,11 +218,11 @@ void CuDNNConv2dOp<Context>::RunOnDevice() { ...@@ -217,11 +218,11 @@ void CuDNNConv2dOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape(); ConvOpBase<Context>::Reshape();
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
x_ofs_ = Input(0).stride(0) / cudnn_group_; x_offset_ = Input(0).stride(0) / cudnn_group_;
y_ofs_ = Output(0)->stride(0) / cudnn_group_; y_offset_ = Output(0)->stride(0) / cudnn_group_;
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
x_ofs_ = Input(0).dim(-1) / cudnn_group_; x_offset_ = Input(0).dim(-1) / cudnn_group_;
y_ofs_ = Output(0)->dim(-1) / cudnn_group_; y_offset_ = Output(0)->dim(-1) / cudnn_group_;
} }
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
...@@ -294,8 +295,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() { ...@@ -294,8 +295,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
filter_desc_, filter_desc_,
CuDNNType<T>::type, CuDNNType<T>::type,
format_, format_,
num_output_ / cudnn_group_, out_channels_ / cudnn_group_,
channels_ / group_, in_channels_ / group_,
kshape_[0], kshape_[0],
kshape_[1])); kshape_[1]));
#else #else
...@@ -303,14 +304,15 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() { ...@@ -303,14 +304,15 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
filter_desc_, filter_desc_,
CuDNNType<T>::type, CuDNNType<T>::type,
format_, format_,
num_output_ / cudnn_group_, out_channels_ / cudnn_group_,
channels_ / group_, in_channels_ / group_,
kshape_[0], kshape_[0],
kshape_[1])); kshape_[1]));
#endif #endif
// Determine the bias shape // Determine the bias shape
if (HasBias()) { if (HasBias()) {
CuDNNSetBiasDesc<T>(&bias_desc_, X.ndim(), num_output_, data_format()); CuDNNSetBiasDesc<T>(
&bias_desc_, X.ndim(), out_channels_, data_format());
} }
} }
// Set the conv configuration // Set the conv configuration
...@@ -470,16 +472,16 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() { ...@@ -470,16 +472,16 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
output_desc_, output_desc_,
x + x_ofs_ * g, x + x_offset_ * g,
input_desc_, input_desc_,
dy + y_ofs_ * g, dy + y_offset_ * g,
conv_desc_, conv_desc_,
bwd_filter_algo_, bwd_filter_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_nbytes_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
filter_desc_, filter_desc_,
dw + w_ofs_ * g)); dw + w_offset_ * g));
} }
} }
...@@ -491,16 +493,16 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() { ...@@ -491,16 +493,16 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
filter_desc_, filter_desc_,
w + w_ofs_ * g, w + w_offset_ * g,
input_desc_, input_desc_,
dy + y_ofs_ * g, dy + y_offset_ * g,
conv_desc_, conv_desc_,
bwd_data_algo_, bwd_data_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_nbytes_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
output_desc_, output_desc_,
dx + x_ofs_ * g)); dx + x_offset_ * g));
} }
} }
} }
...@@ -518,11 +520,11 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() { ...@@ -518,11 +520,11 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape(true); ConvOpBase<Context>::Reshape(true);
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
x_ofs_ = Input(0).stride(0) / cudnn_group_; x_offset_ = Input(0).stride(0) / cudnn_group_;
y_ofs_ = Input(-1).stride(0) / cudnn_group_; y_offset_ = Input(-1).stride(0) / cudnn_group_;
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
x_ofs_ = Input(0).dim(-1) / cudnn_group_; x_offset_ = Input(0).dim(-1) / cudnn_group_;
y_ofs_ = Input(-1).dim(-1) / cudnn_group_; y_offset_ = Input(-1).dim(-1) / cudnn_group_;
} }
DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1));
......
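The x_offset_/y_offset_ values set in RunOnDevice are per-group pointer strides within one sample: for NCHW they divide the whole sample stride, for NHWC only the innermost channel dimension. A quick sketch of the arithmetic, assuming contiguous tensors:

def group_offset(dims, data_format, group):
    """Per-group element offset into one sample, as in RunOnDevice."""
    if data_format == 'NCHW':
        n, c, h, w = dims
        return (c * h * w) // group   # Tensor.stride(0) / group
    elif data_format == 'NHWC':
        n, h, w, c = dims
        return c // group             # channels are innermost

print(group_offset((8, 64, 32, 32), 'NCHW', 2))  # 32768
print(group_offset((8, 32, 32, 64), 'NHWC', 2))  # 32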
...@@ -8,12 +8,7 @@ template <class Context> ...@@ -8,12 +8,7 @@ template <class Context>
template <typename T> template <typename T>
void ConvTranspose2dOp<Context>::DoRunWithType() { void ConvTranspose2dOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0); auto &X = Input(0), &W = Input(1), *Y = Output(0);
ConvOpBase<Context>::Reshape(); ConvOpBase<Context>::Reshape();
// Fix the output shape for im2col/col2im
for (int i = 0; i < num_axes_; i++) {
out_shape_[i] = X.dim(axis_ + i);
}
TENSOR_FILL(W, w_shape_); TENSOR_FILL(W, w_shape_);
auto* x = X.template data<T, Context>(); auto* x = X.template data<T, Context>();
...@@ -21,7 +16,7 @@ void ConvTranspose2dOp<Context>::DoRunWithType() { ...@@ -21,7 +16,7 @@ void ConvTranspose2dOp<Context>::DoRunWithType() {
auto* y = Y->template mutable_data<T, Context>(); auto* y = Y->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) { for (int i = 0; i < X.dim(0); ++i) {
Dx(x + i * x_ofs_, w, y + i * y_ofs_); Dx(x + i * x_offset_, w, y + i * y_offset_);
} }
if (HasBias()) { if (HasBias()) {
...@@ -44,19 +39,14 @@ template <typename T> ...@@ -44,19 +39,14 @@ template <typename T>
void ConvTranspose2dGradientOp<Context>::DoRunWithType() { void ConvTranspose2dGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2); auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
ConvOpBase<Context>::Reshape(true); ConvOpBase<Context>::Reshape(true);
// Fix the output shape for im2col/col2im
for (int i = 0; i < num_axes_; i++) {
out_shape_[i] = X.dim(axis_ + i);
}
if (dX->has_name()) { if (dX->has_name()) {
auto* dy = dY.template data<T, Context>(); auto* dy = dY.template data<T, Context>();
auto* w = W.template data<T, Context>(); auto* w = W.template data<T, Context>();
auto* dx = dX->template mutable_data<T, Context>(); auto* dx = dX->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) { for (int i = 0; i < X.dim(0); ++i) {
Wx(dy + i * y_ofs_, w, dx + i * x_ofs_); Wx(dy + i * y_offset_, w, dx + i * x_offset_);
} }
} }
...@@ -65,7 +55,7 @@ void ConvTranspose2dGradientOp<Context>::DoRunWithType() { ...@@ -65,7 +55,7 @@ void ConvTranspose2dGradientOp<Context>::DoRunWithType() {
auto* dy = dY.template data<T, Context>(); auto* dy = dY.template data<T, Context>();
auto* dw = dW->template mutable_data<T, Context>(); auto* dw = dW->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) { for (int i = 0; i < X.dim(0); ++i) {
Dw(x + i * x_ofs_, dy + i * y_ofs_, dw, i > 0); Dw(x + i * x_offset_, dy + i * y_offset_, dw, i > 0);
} }
} }
......
...@@ -71,8 +71,8 @@ void CuDNNConvTranspose2dOp<Context>::ResetDesc() { ...@@ -71,8 +71,8 @@ void CuDNNConvTranspose2dOp<Context>::ResetDesc() {
filter_desc_, filter_desc_,
CuDNNType<T>::type, CuDNNType<T>::type,
format_, format_,
channels_ / cudnn_group_, in_channels_ / cudnn_group_,
num_output_ / group_, out_channels_ / group_,
kshape_[0], kshape_[0],
kshape_[1])); kshape_[1]));
#else #else
...@@ -80,14 +80,15 @@ void CuDNNConvTranspose2dOp<Context>::ResetDesc() { ...@@ -80,14 +80,15 @@ void CuDNNConvTranspose2dOp<Context>::ResetDesc() {
filter_desc_, filter_desc_,
CuDNNType<T>::type, CuDNNType<T>::type,
format_, format_,
channels_ / cudnn_group_, in_channels_ / cudnn_group_,
num_output_ / group_, out_channels_ / group_,
kshape_[0], kshape_[0],
kshape_[1])); kshape_[1]));
#endif #endif
// Determine the bias shape // Determine the bias shape
if (HasBias()) { if (HasBias()) {
CuDNNSetBiasDesc<T>(&bias_desc_, X.ndim(), num_output_, data_format()); CuDNNSetBiasDesc<T>(
&bias_desc_, X.ndim(), out_channels_, data_format());
} }
} }
// Set the conv configuration // Set the conv configuration
...@@ -180,16 +181,16 @@ void CuDNNConvTranspose2dOp<Context>::DoRunWithType() { ...@@ -180,16 +181,16 @@ void CuDNNConvTranspose2dOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
filter_desc_, filter_desc_,
w + w_ofs_ * g, w + w_offset_ * g,
input_desc_, input_desc_,
x + x_ofs_ * g, x + x_offset_ * g,
conv_desc_, conv_desc_,
fwd_algo_, fwd_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_nbytes_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
output_desc_, output_desc_,
y + y_ofs_ * g)); y + y_offset_ * g));
} }
if (HasBias()) { if (HasBias()) {
...@@ -218,11 +219,11 @@ void CuDNNConvTranspose2dOp<Context>::RunOnDevice() { ...@@ -218,11 +219,11 @@ void CuDNNConvTranspose2dOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape(); ConvOpBase<Context>::Reshape();
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
x_ofs_ = Input(0).stride(0) / cudnn_group_; x_offset_ = Input(0).stride(0) / cudnn_group_;
y_ofs_ = Output(0)->stride(0) / cudnn_group_; y_offset_ = Output(0)->stride(0) / cudnn_group_;
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
x_ofs_ = Input(0).dim(-1) / cudnn_group_; x_offset_ = Input(0).dim(-1) / cudnn_group_;
y_ofs_ = Output(0)->dim(-1) / cudnn_group_; y_offset_ = Output(0)->dim(-1) / cudnn_group_;
} }
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
...@@ -293,8 +294,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { ...@@ -293,8 +294,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
filter_desc_, filter_desc_,
CuDNNType<T>::type, CuDNNType<T>::type,
format_, format_,
channels_ / cudnn_group_, in_channels_ / cudnn_group_,
num_output_ / group_, out_channels_ / group_,
kshape_[0], kshape_[0],
kshape_[1])); kshape_[1]));
#else #else
...@@ -302,14 +303,15 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { ...@@ -302,14 +303,15 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
filter_desc_, filter_desc_,
CuDNNType<T>::type, CuDNNType<T>::type,
format_, format_,
channels_ / cudnn_group_, in_channels_ / cudnn_group_,
num_output_ / group_, out_channels_ / group_,
kshape_[0], kshape_[0],
kshape_[1])); kshape_[1]));
#endif #endif
// Determine the bias shape // Determine the bias shape
if (HasBias()) { if (HasBias()) {
CuDNNSetBiasDesc<T>(&bias_desc_, X.ndim(), num_output_, data_format()); CuDNNSetBiasDesc<T>(
&bias_desc_, X.ndim(), out_channels_, data_format());
} }
} }
// Set the conv configuration // Set the conv configuration
...@@ -466,16 +468,16 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() { ...@@ -466,16 +468,16 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
input_desc_, input_desc_,
dy + y_ofs_ * g, dy + y_offset_ * g,
output_desc_, output_desc_,
x + x_ofs_ * g, x + x_offset_ * g,
conv_desc_, conv_desc_,
bwd_filter_algo_, bwd_filter_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_nbytes_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
filter_desc_, filter_desc_,
dw + w_ofs_ * g)); dw + w_offset_ * g));
} }
} }
...@@ -487,16 +489,16 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() { ...@@ -487,16 +489,16 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
input_desc_, input_desc_,
dy + y_ofs_ * g, dy + y_offset_ * g,
filter_desc_, filter_desc_,
w + w_ofs_ * g, w + w_offset_ * g,
conv_desc_, conv_desc_,
bwd_data_algo_, bwd_data_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_nbytes_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
output_desc_, output_desc_,
dx + x_ofs_ * g)); dx + x_offset_ * g));
} }
} }
} }
...@@ -514,11 +516,11 @@ void CuDNNConvTranspose2dGradientOp<Context>::RunOnDevice() { ...@@ -514,11 +516,11 @@ void CuDNNConvTranspose2dGradientOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape(true); ConvOpBase<Context>::Reshape(true);
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
x_ofs_ = Input(0).stride(0) / cudnn_group_; x_offset_ = Input(0).stride(0) / cudnn_group_;
y_ofs_ = Input(-1).stride(0) / cudnn_group_; y_offset_ = Input(-1).stride(0) / cudnn_group_;
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
x_ofs_ = Input(0).dim(-1) / cudnn_group_; x_offset_ = Input(0).dim(-1) / cudnn_group_;
y_ofs_ = Input(-1).dim(-1) / cudnn_group_; y_offset_ = Input(-1).dim(-1) / cudnn_group_;
} }
DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1));
......
...@@ -10,10 +10,11 @@ namespace dragon { ...@@ -10,10 +10,11 @@ namespace dragon {
template <class Context> template <class Context>
void ConvOpBase<Context>::ComputeOutShape() { void ConvOpBase<Context>::ComputeOutShape() {
auto X_dims = Input(0).dims();
out_shape_.clear(); out_shape_.clear();
for (int i = 0; i < num_axes_; i++) { for (int i = 0; i < num_axes_; i++) {
if (!Transposed()) { if (!Transposed()) {
auto idm = x_shape_[axis_ + i]; auto idm = X_dims[axis_ + i];
auto dk = dilation_[i] * (kshape_[i] - 1) + 1; auto dk = dilation_[i] * (kshape_[i] - 1) + 1;
if (!str::find(padding_, "SAME")) { if (!str::find(padding_, "SAME")) {
// Explicit pads // Explicit pads
...@@ -32,7 +33,7 @@ void ConvOpBase<Context>::ComputeOutShape() { ...@@ -32,7 +33,7 @@ void ConvOpBase<Context>::ComputeOutShape() {
} // SAME_LOWER or SAME } // SAME_LOWER or SAME
} }
} else { } else {
auto idm = x_shape_[axis_ + i]; auto idm = X_dims[axis_ + i];
auto dk = dilation_[i] * (kshape_[i] - 1) + 1; auto dk = dilation_[i] * (kshape_[i] - 1) + 1;
if (!str::find(padding_, "SAME")) { if (!str::find(padding_, "SAME")) {
// Explicit pads // Explicit pads
...@@ -79,13 +80,11 @@ template <class Context> ...@@ -79,13 +80,11 @@ template <class Context>
template <typename T> template <typename T>
void ConvOpBase<Context>::Wx(const T* x, const T* w, T* y, bool skip) { void ConvOpBase<Context>::Wx(const T* x, const T* w, T* y, bool skip) {
auto* col = x; auto* col = x;
if (!is_1x1_) { if (!is_1x1_) {
auto* scratch = ws()->template data<T, Context>({col_dim_})[0]; auto* scratch = ws()->template data<T, Context>({col_dim_})[0];
if (!skip) Im2Col(x, scratch); if (!skip) Im2Col(x, scratch);
col = scratch; col = scratch;
} }
for (int g = 0; g < group_; g++) { for (int g = 0; g < group_; g++) {
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
math::Gemm( math::Gemm(
...@@ -95,10 +94,10 @@ void ConvOpBase<Context>::Wx(const T* x, const T* w, T* y, bool skip) { ...@@ -95,10 +94,10 @@ void ConvOpBase<Context>::Wx(const T* x, const T* w, T* y, bool skip) {
conv_out_dim_, conv_out_dim_,
kernel_dim_, kernel_dim_,
1.f, 1.f,
w + w_ofs_ * g, w + w_offset_ * g,
col + col_ofs_ * g, col + col_offset_ * g,
0.f, 0.f,
y + output_ofs_ * g, y + out_offset_ * g,
ctx()); ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
math::Gemm( math::Gemm(
...@@ -121,10 +120,11 @@ template <class Context> ...@@ -121,10 +120,11 @@ template <class Context>
template <typename T> template <typename T>
void ConvOpBase<Context>::Pb(const T* bias, T* y) { void ConvOpBase<Context>::Pb(const T* bias, T* y) {
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
kernel::BiasAdd(Input(0).dim(0), num_output_, out_dim_, y, bias, y, ctx()); kernel::BiasAdd(
Input(0).dim(0), out_channels_, out_dim_, y, bias, y, ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
kernel::BiasAdd( kernel::BiasAdd(
Input(0).dim(0) * out_dim_, num_output_, 1, y, bias, y, ctx()); Input(0).dim(0) * out_dim_, out_channels_, 1, y, bias, y, ctx());
} }
} }
...@@ -141,10 +141,10 @@ void ConvOpBase<Context>::Dx(const T* dy, const T* w, T* dx) { ...@@ -141,10 +141,10 @@ void ConvOpBase<Context>::Dx(const T* dy, const T* w, T* dx) {
conv_out_dim_, conv_out_dim_,
conv_out_channels_ / group_, conv_out_channels_ / group_,
1.f, 1.f,
w + w_ofs_ * g, w + w_offset_ * g,
dy + output_ofs_ * g, dy + out_offset_ * g,
0.f, 0.f,
col + col_ofs_ * g, col + col_offset_ * g,
ctx()); ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
math::Gemm( math::Gemm(
...@@ -168,13 +168,11 @@ template <class Context> ...@@ -168,13 +168,11 @@ template <class Context>
template <typename T> template <typename T>
void ConvOpBase<Context>::Dw(const T* dy, const T* x, T* dw, bool accum) { void ConvOpBase<Context>::Dw(const T* dy, const T* x, T* dw, bool accum) {
auto* col = x; auto* col = x;
if (!is_1x1_) { if (!is_1x1_) {
auto* scratch = ws()->template data<T, Context>({col_dim_})[0]; auto* scratch = ws()->template data<T, Context>({col_dim_})[0];
Im2Col(x, scratch); Im2Col(x, scratch);
col = scratch; col = scratch;
} }
for (int g = 0; g < group_; g++) { for (int g = 0; g < group_; g++) {
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
math::Gemm( math::Gemm(
...@@ -184,10 +182,10 @@ void ConvOpBase<Context>::Dw(const T* dy, const T* x, T* dw, bool accum) { ...@@ -184,10 +182,10 @@ void ConvOpBase<Context>::Dw(const T* dy, const T* x, T* dw, bool accum) {
kernel_dim_, kernel_dim_,
conv_out_dim_, conv_out_dim_,
1.f, 1.f,
dy + output_ofs_ * g, dy + out_offset_ * g,
col + col_ofs_ * g, col + col_offset_ * g,
accum ? 1.f : 0.f, accum ? 1.f : 0.f,
dw + w_ofs_ * g, dw + w_offset_ * g,
ctx()); ctx());
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
math::Gemm( math::Gemm(
...@@ -211,10 +209,10 @@ template <typename T> ...@@ -211,10 +209,10 @@ template <typename T>
void ConvOpBase<Context>::Db(const T* dy, T* db) { void ConvOpBase<Context>::Db(const T* dy, T* db) {
vec32_t dims, axes; vec32_t dims, axes;
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
dims = {(int)Input(0).dim(0), (int)num_output_, (int)out_dim_}; dims = {(int)Input(0).dim(0), (int)out_channels_, (int)out_dim_};
axes = {0, 2}; axes = {0, 2};
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
dims = {(int)Input(0).dim(0), (int)out_dim_, (int)num_output_}; dims = {(int)Input(0).dim(0), (int)out_dim_, (int)out_channels_};
axes = {0, 1}; axes = {0, 1};
} }
math::ReduceSum(3, dims.data(), 2, axes.data(), 1.f, dy, db, ctx()); math::ReduceSum(3, dims.data(), 2, axes.data(), 1.f, dy, db, ctx());
...@@ -223,16 +221,15 @@ void ConvOpBase<Context>::Db(const T* dy, T* db) { ...@@ -223,16 +221,15 @@ void ConvOpBase<Context>::Db(const T* dy, T* db) {
template <class Context> template <class Context>
void ConvOpBase<Context>::Setup(int num_axes) { void ConvOpBase<Context>::Setup(int num_axes) {
num_axes_ = num_axes; num_axes_ = num_axes;
auto at = [&](const vec64_t& vec, int i) {
return i < vec.size() ? vec[i] : vec[0];
};
auto pads = OpArgs<int64_t>("pads"); auto pads = OpArgs<int64_t>("pads");
auto strides = OpArgs<int64_t>("strides"); auto strides = OpArgs<int64_t>("strides");
auto kshape = OpArgs<int64_t>("kernel_shape"); auto kshape = OpArgs<int64_t>("kernel_shape");
auto dilations = OpArgs<int64_t>("dilations"); auto dilations = OpArgs<int64_t>("dilations");
auto at = [&](const vec64_t& vec, int i) {
return i < vec.size() ? vec[i] : vec[0];
};
for (int i = 0; i < num_axes; i++) { for (int i = 0; i < num_axes; i++) {
pad_l_.push_back(at(pads, i)); pad_l_.push_back(at(pads, i));
stride_.push_back(at(strides, i)); stride_.push_back(at(strides, i));
...@@ -241,8 +238,9 @@ void ConvOpBase<Context>::Setup(int num_axes) { ...@@ -241,8 +238,9 @@ void ConvOpBase<Context>::Setup(int num_axes) {
} }
if ((int64_t)pads.size() == (num_axes * 2)) { if ((int64_t)pads.size() == (num_axes * 2)) {
for (int i = 0; i < num_axes; i++) for (int i = 0; i < num_axes; i++) {
pad_r_.push_back(pads[num_axes + i]); pad_r_.push_back(pads[num_axes + i]);
}
} else { } else {
pad_r_.assign(pad_l_.begin(), pad_l_.end()); pad_r_.assign(pad_l_.begin(), pad_l_.end());
} }
...@@ -264,63 +262,56 @@ void ConvOpBase<Context>::Reshape(bool backward) { ...@@ -264,63 +262,56 @@ void ConvOpBase<Context>::Reshape(bool backward) {
auto* Y_ref = backward ? &Input(-1) : Output(0); auto* Y_ref = backward ? &Input(-1) : Output(0);
// Determine the in/out channels // Determine the in/out channels
channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1); in_channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
if (num_output_ <= 0) { if (out_channels_ <= 0) {
// Infer the out channels from the weights shape // Infer the out channels from the weights shape
num_output_ = W.count() / channels_; out_channels_ = W.count() / (in_channels_ / group_);
for (int i = 0; i < num_axes_; i++) for (int i = 0; i < num_axes_; i++) {
num_output_ /= kshape_[i]; out_channels_ /= kshape_[i];
CHECK_GT(num_output_, 0) << "\nFailed to infer the out channels " }
CHECK_GT(out_channels_, 0) << "\nFailed to infer the out channels "
<< "from weights: " << W.DimString(); << "from weights: " << W.DimString();
} }
if (Transposed()) { if (Transposed()) {
conv_out_channels_ = channels_; conv_out_channels_ = in_channels_;
conv_in_channels_ = num_output_; conv_in_channels_ = out_channels_;
} else { } else {
conv_out_channels_ = num_output_; conv_out_channels_ = out_channels_;
conv_in_channels_ = channels_; conv_in_channels_ = in_channels_;
} }
// Determine the weight and bias shape // Determine the weight and bias shape
// The weight shape is always assumed as NCHW format // The weight shape is always assumed as NCHW format
// in order to compute the fans correctly // in order to compute the fans correctly
w_shape_ = {conv_out_channels_, conv_in_channels_ / group_}; w_shape_ = {conv_out_channels_, conv_in_channels_ / group_};
for (int i = 0; i < num_axes_; i++) for (int i = 0; i < num_axes_; i++) {
w_shape_.push_back(kshape_[i]); w_shape_.push_back(kshape_[i]);
b_shape_ = {num_output_}; }
b_shape_ = {out_channels_};
// Determine the Y shape // Determine the output shape
x_shape_ = X.dims();
ComputeOutShape(); ComputeOutShape();
if (backward) { if (backward) {
if (Output(0)->has_name()) Output(0)->ReshapeLike(X); if (Output(0)->has_name()) Output(0)->ReshapeLike(X);
if (Output(1)->has_name()) Output(1)->ReshapeLike(W); if (Output(1)->has_name()) Output(1)->ReshapeLike(W);
if (Output(2)->has_name()) Output(2)->Reshape({num_output_}); if (Output(2)->has_name()) Output(2)->Reshape({out_channels_});
} else { } else {
vec64_t Y_dims{X.dim(0)};
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
y_shape_ = {X.dim(0), num_output_}; Y_dims.push_back(out_channels_);
for (int i = 0; i < num_axes_; i++) for (int i = 0; i < num_axes_; i++) {
y_shape_.push_back(out_shape_[i]); Y_dims.push_back(out_shape_[i]);
} else if (data_format() == "NHWC") {
y_shape_ = {X.dim(0)};
for (int i = 0; i < num_axes_; i++)
y_shape_.push_back(out_shape_[i]);
y_shape_.push_back(num_output_);
}
Output(0)->Reshape(y_shape_);
} }
} else if (data_format() == "NHWC") {
// Determine the input shape for im2col/col2im
in_shape_.clear();
for (int i = 0; i < num_axes_; i++) { for (int i = 0; i < num_axes_; i++) {
if (Transposed()) { Y_dims.push_back(out_shape_[i]);
in_shape_.push_back(Y_ref->dim(axis_ + i));
} else {
in_shape_.push_back(X.dim(axis_ + i));
} }
Y_dims.push_back(out_channels_);
}
Output(0)->Reshape(Y_dims);
} }
// Determine the out spatial dim // Determine the output dim
auto end_axis = X.ndim() - 1; auto end_axis = X.ndim() - 1;
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
if (Transposed()) { if (Transposed()) {
...@@ -338,24 +329,30 @@ void ConvOpBase<Context>::Reshape(bool backward) { ...@@ -338,24 +329,30 @@ void ConvOpBase<Context>::Reshape(bool backward) {
out_dim_ = Y_ref->count(axis_, end_axis); out_dim_ = Y_ref->count(axis_, end_axis);
} }
// Determine the misc // Compute the miscellaneous
x_ofs_ = X.stride(0); x_offset_ = X.stride(0);
y_ofs_ = Y_ref->stride(0); y_offset_ = Y_ref->stride(0);
kernel_dim_ = conv_in_channels_ / group_; kernel_dim_ = conv_in_channels_ / group_;
for (int i = 0; i < num_axes_; i++) for (int i = 0; i < num_axes_; i++) {
kernel_dim_ *= kshape_[i]; kernel_dim_ *= kshape_[i];
col_ofs_ = kernel_dim_ * conv_out_dim_; }
w_ofs_ = conv_out_channels_ * kernel_dim_ / group_; col_offset_ = kernel_dim_ * conv_out_dim_;
output_ofs_ = conv_out_channels_ * conv_out_dim_ / group_; w_offset_ = conv_out_channels_ * kernel_dim_ / group_;
out_offset_ = conv_out_channels_ * conv_out_dim_ / group_;
// Determine the workspace size for col buffer // Compute the arguments for im2col/col2im
col_dim_ = kernel_dim_ * group_; in_shape_.clear();
for (int i = 0; i < num_axes_; i++) { for (int i = 0; i < num_axes_; i++) {
if (Transposed()) { if (Transposed()) {
col_dim_ *= x_shape_[axis_ + i]; in_shape_.push_back(Y_ref->dim(axis_ + i));
out_shape_[i] = X.dim(axis_ + i);
} else { } else {
col_dim_ *= out_shape_[i]; in_shape_.push_back(X.dim(axis_ + i));
}
} }
col_dim_ = kernel_dim_ * group_;
for (int i = 0; i < num_axes_; i++) {
col_dim_ *= out_shape_[i];
} }
} }
......
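ComputeOutShape now reads the spatial sizes straight from Input(0) instead of the removed x_shape_ cache, and Reshape() folds in the im2col bookkeeping that the transposed ops previously patched by hand. The shapes themselves follow the usual convolution arithmetic; a sketch in Python for the explicit-padding (non-SAME) branch:

def conv_out_dim(idm, k, stride, dilation, pad_l, pad_r):
    """Forward convolution output size with explicit pads."""
    dk = dilation * (k - 1) + 1  # dilated kernel extent
    return (idm + pad_l + pad_r - dk) // stride + 1

def conv_transpose_out_dim(idm, k, stride, dilation, pad_l, pad_r):
    """Transposed convolution output size with explicit pads."""
    dk = dilation * (k - 1) + 1
    return stride * (idm - 1) + dk - pad_l - pad_r

print(conv_out_dim(32, 3, 1, 1, 1, 1))            # 32
print(conv_transpose_out_dim(32, 3, 2, 1, 1, 1))  # 63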
...@@ -25,7 +25,7 @@ class ConvOpBase : public Operator<Context> { ...@@ -25,7 +25,7 @@ class ConvOpBase : public Operator<Context> {
ConvOpBase(const OperatorDef& def, Workspace* ws) ConvOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
padding_(OpArg<string>("padding", "VALID")), padding_(OpArg<string>("padding", "VALID")),
num_output_(OpArg<int64_t>("num_output", 0)), out_channels_(OpArg<int64_t>("out_channels", 0)),
group_(OpArg<int64_t>("group", 1)) { group_(OpArg<int64_t>("group", 1)) {
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
axis_ = 2; axis_ = 2;
...@@ -42,18 +42,13 @@ class ConvOpBase : public Operator<Context> { ...@@ -42,18 +42,13 @@ class ConvOpBase : public Operator<Context> {
vec64_t kshape_, stride_; vec64_t kshape_, stride_;
vec64_t pad_l_, pad_r_, dilation_; vec64_t pad_l_, pad_r_, dilation_;
vec64_t in_shape_, out_shape_; vec64_t in_shape_, w_shape_, b_shape_, out_shape_;
vec64_t x_shape_, y_shape_;
vec64_t w_shape_, b_shape_;
string padding_; string padding_;
int64_t is_1x1_, num_output_, group_; int64_t group_;
int64_t axis_, num_axes_; int64_t axis_, num_axes_;
int64_t channels_, out_dim_; int64_t in_channels_, out_channels_, out_dim_;
int64_t conv_in_channels_, conv_out_channels_; int64_t x_offset_, w_offset_, y_offset_;
int64_t conv_out_dim_, kernel_dim_, col_dim_;
int64_t col_ofs_, output_ofs_;
int64_t w_ofs_, x_ofs_, y_ofs_;
DECLARE_ARGS_WITH_DESC(int64_t, output_shape); DECLARE_ARGS_WITH_DESC(int64_t, output_shape);
DECLARE_ARGS_WITH_DESC(int64_t, output_padding); DECLARE_ARGS_WITH_DESC(int64_t, output_padding);
...@@ -133,6 +128,11 @@ class ConvOpBase : public Operator<Context> { ...@@ -133,6 +128,11 @@ class ConvOpBase : public Operator<Context> {
LOG(FATAL) << "ConvNd has not been implemented."; LOG(FATAL) << "ConvNd has not been implemented.";
} }
} }
int64_t is_1x1_;
int64_t kernel_dim_, col_dim_;
int64_t col_offset_, out_offset_;
int64_t conv_in_channels_, conv_out_channels_, conv_out_dim_;
}; };
DEFINE_ARGS_WITH_DESC(int64_t, ConvOpBase, output_shape); DEFINE_ARGS_WITH_DESC(int64_t, ConvOpBase, output_shape);
...@@ -154,16 +154,16 @@ DEFINE_ARGS_WITH_DESC(int64_t, ConvOpBase, output_padding); ...@@ -154,16 +154,16 @@ DEFINE_ARGS_WITH_DESC(int64_t, ConvOpBase, output_padding);
using ConvOpBase<Context>::pad_r_; \ using ConvOpBase<Context>::pad_r_; \
using ConvOpBase<Context>::dilation_; \ using ConvOpBase<Context>::dilation_; \
using ConvOpBase<Context>::group_; \ using ConvOpBase<Context>::group_; \
using ConvOpBase<Context>::channels_; \ using ConvOpBase<Context>::in_channels_; \
using ConvOpBase<Context>::num_output_; \ using ConvOpBase<Context>::out_channels_; \
using ConvOpBase<Context>::axis_; \ using ConvOpBase<Context>::axis_; \
using ConvOpBase<Context>::num_axes_; \ using ConvOpBase<Context>::num_axes_; \
using ConvOpBase<Context>::x_ofs_; \ using ConvOpBase<Context>::x_offset_; \
using ConvOpBase<Context>::y_ofs_; \ using ConvOpBase<Context>::w_offset_; \
using ConvOpBase<Context>::w_ofs_; \ using ConvOpBase<Context>::y_offset_; \
using ConvOpBase<Context>::in_shape_; \
using ConvOpBase<Context>::w_shape_; \ using ConvOpBase<Context>::w_shape_; \
using ConvOpBase<Context>::b_shape_; \ using ConvOpBase<Context>::b_shape_; \
using ConvOpBase<Context>::in_shape_; \
using ConvOpBase<Context>::out_shape_ using ConvOpBase<Context>::out_shape_
} // namespace dragon } // namespace dragon
......
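The offsets moved into the protected block above encode how a grouped convolution decomposes into one GEMM per group over a flat im2col buffer: w_offset_, col_offset_ and out_offset_ are the per-group strides into the weight, column and output buffers. A shape-only NumPy sketch of the NCHW forward product (Wx), assuming w and col are flat arrays and the column buffer is already built:

import numpy as np

def grouped_wx(w, col, group, conv_out_channels, kernel_dim, conv_out_dim):
    """y[g] = w[g] @ col[g] per group, mirroring the Wx GEMM loop."""
    w_offset = conv_out_channels * kernel_dim // group
    col_offset = kernel_dim * conv_out_dim
    out_offset = conv_out_channels * conv_out_dim // group
    y = np.empty(conv_out_channels * conv_out_dim, dtype=col.dtype)
    for g in range(group):
        wg = w[w_offset * g:w_offset * (g + 1)]
        wg = wg.reshape(conv_out_channels // group, kernel_dim)
        colg = col[col_offset * g:col_offset * (g + 1)]
        colg = colg.reshape(kernel_dim, conv_out_dim)
        y[out_offset * g:out_offset * (g + 1)] = (wg @ colg).ravel()
    return y.reshape(conv_out_channels, conv_out_dim)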
...@@ -10,14 +10,15 @@ template <typename T> ...@@ -10,14 +10,15 @@ template <typename T>
void DepthwiseConv2dOp<Context>::DoRunWithType() { void DepthwiseConv2dOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0); auto &X = Input(0), &W = Input(1), *Y = Output(0);
group_ = channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1); group_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
CHECK_EQ(channels_, num_output_) << "\nExpected in/out channels unchanged.";
ConvOpBase<Context>::Reshape(); ConvOpBase<Context>::Reshape();
CHECK_EQ(in_channels_, out_channels_)
<< "\nExcepted in/out channels to be same.";
TENSOR_FILL(W, w_shape_); TENSOR_FILL(W, w_shape_);
kernel::DepthwiseConv2d( kernel::DepthwiseConv2d(
Input(0).dim(0), Input(0).dim(0),
channels_, in_channels_,
in_shape_[0], in_shape_[0],
in_shape_[1], in_shape_[1],
out_shape_[0], out_shape_[0],
...@@ -54,13 +55,13 @@ void DepthwiseConv2dGradientOp<Context>::DoRunWithType() { ...@@ -54,13 +55,13 @@ void DepthwiseConv2dGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2); auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
group_ = channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1); group_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
ConvOpBase<Context>::Reshape(true); ConvOpBase<Context>::Reshape(true);
if (dX->has_name()) { if (dX->has_name()) {
kernel::DepthwiseConv2dGrad( kernel::DepthwiseConv2dGrad(
X.dim(0), X.dim(0),
channels_, in_channels_,
in_shape_[0], in_shape_[0],
in_shape_[1], in_shape_[1],
out_shape_[0], out_shape_[0],
...@@ -83,7 +84,7 @@ void DepthwiseConv2dGradientOp<Context>::DoRunWithType() { ...@@ -83,7 +84,7 @@ void DepthwiseConv2dGradientOp<Context>::DoRunWithType() {
if (dW->has_name()) { if (dW->has_name()) {
kernel::DepthwiseConv2dWGrad( kernel::DepthwiseConv2dWGrad(
X.dim(0), X.dim(0),
channels_, in_channels_,
in_shape_[0], in_shape_[0],
in_shape_[1], in_shape_[1],
out_shape_[0], out_shape_[0],
......
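DepthwiseConv2d now only forces group_ to the input channels and lets Reshape() derive in_channels_/out_channels_ before the equality check. Semantically a depthwise convolution is an independent 2-D filter per channel (group == channels); a naive NumPy sketch for NCHW, stride 1 and no padding, just to pin the semantics down:

import numpy as np

def depthwise_conv2d(x, w):
    """x: (N, C, H, W), w: (C, 1, KH, KW) -> (N, C, H-KH+1, W-KW+1)."""
    kh, kw = w.shape[2], w.shape[3]
    oh, ow = x.shape[2] - kh + 1, x.shape[3] - kw + 1
    y = np.zeros((x.shape[0], x.shape[1], oh, ow), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            # Each channel sees only its own filter tap.
            y += w[None, :, 0, i, j, None, None] * x[:, :, i:i + oh, j:j + ow]
    return y

x = np.random.randn(1, 3, 5, 5).astype('float32')
w = np.random.randn(3, 1, 3, 3).astype('float32')
print(depthwise_conv2d(x, w).shape)  # (1, 3, 3, 3)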
...@@ -12,14 +12,15 @@ template <typename T> ...@@ -12,14 +12,15 @@ template <typename T>
void CuDNNDepthwiseConv2dOp<Context>::DoRunWithType() { void CuDNNDepthwiseConv2dOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0); auto &X = Input(0), &W = Input(1), *Y = Output(0);
group_ = channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1); group_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
CHECK_EQ(channels_, num_output_) << "\nExpected in/out channels unchanged.";
ConvOpBase<Context>::Reshape(); ConvOpBase<Context>::Reshape();
CHECK_EQ(in_channels_, out_channels_)
<< "\nExcepted in/out channels to be same.";
TENSOR_FILL(W, w_shape_); TENSOR_FILL(W, w_shape_);
kernel::DepthwiseConv2d( kernel::DepthwiseConv2d(
X.dim(0), X.dim(0),
channels_, in_channels_,
in_shape_[0], in_shape_[0],
in_shape_[1], in_shape_[1],
out_shape_[0], out_shape_[0],
...@@ -40,7 +41,7 @@ void CuDNNDepthwiseConv2dOp<Context>::DoRunWithType() { ...@@ -40,7 +41,7 @@ void CuDNNDepthwiseConv2dOp<Context>::DoRunWithType() {
if (HasBias()) { if (HasBias()) {
TENSOR_FILL(Input(2), b_shape_); TENSOR_FILL(Input(2), b_shape_);
CuDNNSetBiasDesc<T>(&bias_desc_, 4, num_output_, data_format()); CuDNNSetBiasDesc<T>(&bias_desc_, 4, out_channels_, data_format());
CuDNNSetTensorDesc<T>(&output_desc_, Y->dims(), data_format()); CuDNNSetTensorDesc<T>(&output_desc_, Y->dims(), data_format());
CUDNN_CHECK(cudnnAddTensor( CUDNN_CHECK(cudnnAddTensor(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
...@@ -64,13 +65,13 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::DoRunWithType() { ...@@ -64,13 +65,13 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2); auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
group_ = channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1); group_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
ConvOpBase<Context>::Reshape(true); ConvOpBase<Context>::Reshape(true);
if (dX->has_name()) { if (dX->has_name()) {
kernel::DepthwiseConv2dGrad( kernel::DepthwiseConv2dGrad(
X.dim(0), X.dim(0),
channels_, in_channels_,
in_shape_[0], in_shape_[0],
in_shape_[1], in_shape_[1],
out_shape_[0], out_shape_[0],
...@@ -93,7 +94,7 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::DoRunWithType() { ...@@ -93,7 +94,7 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::DoRunWithType() {
if (dW->has_name()) { if (dW->has_name()) {
kernel::DepthwiseConv2dWGrad( kernel::DepthwiseConv2dWGrad(
X.dim(0), X.dim(0),
channels_, in_channels_,
in_shape_[0], in_shape_[0],
in_shape_[1], in_shape_[1],
out_shape_[0], out_shape_[0],
...@@ -115,7 +116,7 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::DoRunWithType() { ...@@ -115,7 +116,7 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::DoRunWithType() {
if (dB->has_name()) { if (dB->has_name()) {
CuDNNSetTensorDesc<T>(&input_desc_, Input(-1).dims(), data_format()); CuDNNSetTensorDesc<T>(&input_desc_, Input(-1).dims(), data_format());
CuDNNSetBiasDesc<T>(&bias_desc_, 4, num_output_, data_format()); CuDNNSetBiasDesc<T>(&bias_desc_, 4, out_channels_, data_format());
CUDNN_CHECK(cudnnConvolutionBackwardBias( CUDNN_CHECK(cudnnConvolutionBackwardBias(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
......
...@@ -50,7 +50,7 @@ void SpaceToDepthOp<Context>::DoRunWithType() { ...@@ -50,7 +50,7 @@ void SpaceToDepthOp<Context>::DoRunWithType() {
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
for (int i = 0; i < num_axes; i++) { for (int i = 0; i < num_axes; i++) {
perm.insert(perm.begin() + 1, perm.back()); perm.insert(perm.begin() + 1, perm.back());
perm.pop_back(); // CRD mode perm.pop_back(); // DCR mode
} }
} }
......
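The comment fix above concerns the channel ordering of the blocked output. The loop itself is purely mechanical: it rotates the trailing block axes to just after the batch axis, which, given the reshape that precedes it, yields the DCR rather than CRD arrangement according to the corrected comment. A sketch of the permutation being built, for two spatial axes:

perm = [0, 1, 2, 3, 4, 5]
num_axes = 2
for _ in range(num_axes):
    perm.insert(1, perm[-1])  # move the trailing axis to position 1
    perm.pop()
print(perm)  # [0, 4, 5, 1, 2, 3]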
...@@ -63,7 +63,11 @@ message TensorFillerProto { ...@@ -63,7 +63,11 @@ message TensorFillerProto {
optional float mean = 6 [default = 0]; optional float mean = 6 [default = 0];
optional float std = 7 [default = 1]; optional float std = 7 [default = 1];
optional float scale = 8 [default = 3]; optional float scale = 8 [default = 3];
enum VarianceNorm { FAN_IN = 0; FAN_OUT = 1; FAN_AVG=2; } enum VarianceNorm {
FAN_IN = 0;
FAN_OUT = 1;
FAN_AVG = 2;
}
optional VarianceNorm variance_norm = 9 [default = FAN_IN]; optional VarianceNorm variance_norm = 9 [default = FAN_IN];
} }
......
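The reflowed enum is the knob behind fan-based fillers: the sampling variance is normalized by fan-in, fan-out, or their average. A sketch of the usual computation for a gaussian filler, assuming an NCHW weight of shape (out_channels, in_channels, kh, kw); scale=2 gives the MSRA-style std:

import numpy as np

def filler_std(w_shape, variance_norm='FAN_IN', scale=2.0):
    """Std-dev of a fan-normalized gaussian filler (sketch)."""
    receptive = int(np.prod(w_shape[2:]))
    fan_in = w_shape[1] * receptive
    fan_out = w_shape[0] * receptive
    fan = {'FAN_IN': fan_in,
           'FAN_OUT': fan_out,
           'FAN_AVG': (fan_in + fan_out) / 2.0}[variance_norm]
    return np.sqrt(scale / fan)

print(filler_std((64, 32, 3, 3)))  # sqrt(2 / (32 * 9)) ~= 0.0833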
...@@ -28,10 +28,10 @@ from dragon._api import losses ...@@ -28,10 +28,10 @@ from dragon._api import losses
from dragon._api import math from dragon._api import math
from dragon._api import metrics from dragon._api import metrics
from dragon._api import nn from dragon._api import nn
from dragon._api import optimizers
from dragon._api import random from dragon._api import random
from dragon._api import updaters
from dragon._api import vision
from dragon._api import workspace from dragon._api import workspace
from dragon._api import vision
# Virtual API # Virtual API
from dragon import vm from dragon import vm
...@@ -56,7 +56,7 @@ from dragon.core.framework.context import name_scope ...@@ -56,7 +56,7 @@ from dragon.core.framework.context import name_scope
from dragon.core.framework.workspace import get_workspace from dragon.core.framework.workspace import get_workspace
from dragon.core.framework.workspace import reset_workspace from dragon.core.framework.workspace import reset_workspace
from dragon.core.ops import tensorbind_eager as _ from dragon.core.ops import tensorbind_eager as _
from dragon.core.ops import tensorbind_symbolic as _ from dragon.core.ops import tensorbind_symbol as _
from dragon.core.ops.array_ops import arange from dragon.core.ops.array_ops import arange
from dragon.core.ops.array_ops import broadcast_to from dragon.core.ops.array_ops import broadcast_to
from dragon.core.ops.array_ops import cast from dragon.core.ops.array_ops import cast
......
...@@ -24,9 +24,9 @@ from dragon.core.ops.array_ops import min ...@@ -24,9 +24,9 @@ from dragon.core.ops.array_ops import min
from dragon.core.ops.array_ops import moments from dragon.core.ops.array_ops import moments
from dragon.core.ops.array_ops import sum from dragon.core.ops.array_ops import sum
from dragon.core.ops.math_ops import abs from dragon.core.ops.math_ops import abs
from dragon.core.ops.math_ops import accumulate
from dragon.core.ops.math_ops import add from dragon.core.ops.math_ops import add
from dragon.core.ops.math_ops import affine from dragon.core.ops.math_ops import affine
from dragon.core.ops.math_ops import axpby
from dragon.core.ops.math_ops import ceil from dragon.core.ops.math_ops import ceil
from dragon.core.ops.math_ops import clip from dragon.core.ops.math_ops import clip
from dragon.core.ops.math_ops import cos from dragon.core.ops.math_ops import cos
...@@ -45,7 +45,6 @@ from dragon.core.ops.math_ops import log ...@@ -45,7 +45,6 @@ from dragon.core.ops.math_ops import log
from dragon.core.ops.math_ops import matmul from dragon.core.ops.math_ops import matmul
from dragon.core.ops.math_ops import maximum from dragon.core.ops.math_ops import maximum
from dragon.core.ops.math_ops import minimum from dragon.core.ops.math_ops import minimum
from dragon.core.ops.math_ops import moving_average
from dragon.core.ops.math_ops import mul from dragon.core.ops.math_ops import mul
from dragon.core.ops.math_ops import negative from dragon.core.ops.math_ops import negative
from dragon.core.ops.math_ops import not_equal from dragon.core.ops.math_ops import not_equal
......
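The renames in this hunk (accumulate becomes axpby, moving_average is dropped) converge on the BLAS vocabulary: axpby computes y = alpha * x + beta * y, and a moving average is just the special case alpha + beta = 1. A one-liner sketch:

import numpy as np

def axpby(x, y, alpha=1.0, beta=0.0):
    """y = alpha * x + beta * y, the BLAS-style fused update."""
    return alpha * x + beta * y

x, y = np.ones(4), np.full(4, 10.0)
print(axpby(x, y, alpha=0.1, beta=0.9))  # moving average -> [9.1 ...]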
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.core.training.adam import Adam
from dragon.core.training.optimizer import Optimizer
from dragon.core.training.rmsprop import RMSprop
from dragon.core.training.sgd import Nesterov
from dragon.core.training.sgd import SGD
__all__ = [_s for _s in dir() if not _s.startswith('_')]
...@@ -29,7 +29,7 @@ from dragon.core.eager import context as eager_context ...@@ -29,7 +29,7 @@ from dragon.core.eager import context as eager_context
from dragon.core.eager.tensor import EagerTensor from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import context from dragon.core.framework import context
from dragon.core.framework import workspace from dragon.core.framework import workspace
from dragon.core.training import updater from dragon.core.training import optimizer
from dragon.core.util import decorator from dragon.core.util import decorator
from dragon.core.util import inspect from dragon.core.util import inspect
from dragon.core.util import nest from dragon.core.util import nest
...@@ -265,7 +265,7 @@ class FunctionGuard(object): ...@@ -265,7 +265,7 @@ class FunctionGuard(object):
dummies.append(obj) dummies.append(obj)
executables = [function_lib.create_function(inputs, outputs)] executables = [function_lib.create_function(inputs, outputs)]
for obj in dummies: for obj in dummies:
if isinstance(obj, updater.Updater): if isinstance(obj, optimizer.Optimizer):
executables.append(function_lib.create_function(updater=obj)) executables.append(function_lib.create_function(updater=obj))
self.inputs = inputs self.inputs = inputs
self.outputs = returns self.outputs = returns
......
...@@ -78,22 +78,22 @@ def add_phase(graph_def, targets): ...@@ -78,22 +78,22 @@ def add_phase(graph_def, targets):
graph_def.arg.extend([proto_util.make_argument('phase', phase)]) graph_def.arg.extend([proto_util.make_argument('phase', phase)])
def add_update_ops(graph_def, updater): def add_update_ops(graph_def, optimizer):
"""Add the update operators for graph.""" """Add the update operators for graph."""
if updater is None: if optimizer is None:
return return
grads, update_ops = [], [] grads, update_ops = [], []
extra_arguments = updater._extra_kwargs extra_arguments = optimizer._extra_kwargs
extra_arguments['slot'] = updater._slot extra_arguments['handle'] = optimizer._op_handle
# Generate update operators according to the updater. # Generate update operators according to the optimizer.
for e in updater._param_group: for e in optimizer._param_group:
(param, grad), arguments = e (param, grad), arguments = e
if workspace.has_tensor(grad): if workspace.has_tensor(grad):
grads.append(grad) grads.append(grad)
arguments = dict(arguments, **extra_arguments) arguments = dict(arguments, **extra_arguments)
update_ops.append( update_ops.append(
proto_util.make_operator_def( proto_util.make_operator_def(
op_type=updater._op_type, op_type=optimizer._op_type,
inputs=[grad], inputs=[grad],
outputs=[param], outputs=[param],
name=OpDef.get_name(), name=OpDef.get_name(),
...@@ -102,7 +102,7 @@ def add_update_ops(graph_def, updater): ...@@ -102,7 +102,7 @@ def add_update_ops(graph_def, updater):
else: else:
logging.info('Skip the update of Tensor({}).'.format(param)) logging.info('Skip the update of Tensor({}).'.format(param))
# Insert a reduce op if the process group is found. # Insert a reduce op if the process group is found.
process_group = updater._process_group process_group = optimizer._process_group
if process_group is not None: if process_group is not None:
update_ops.insert( update_ops.insert(
0, proto_util.make_operator_def( 0, proto_util.make_operator_def(
...@@ -139,12 +139,15 @@ class Function(object): ...@@ -139,12 +139,15 @@ class Function(object):
# Collect the forward operators. # Collect the forward operators.
requires_grad = False requires_grad = False
for output in outputs: for i, output in enumerate(outputs):
op_info.merge_from(output) op_info.merge_from(output)
op_info.add_target(output.id) op_info.add_target(output.id)
if output._grad is not None and \ try:
output._grad.required(): grad_info = output._grad
if grad_info and grad_info.required():
requires_grad = True requires_grad = True
except AttributeError:
raise ValueError('Output[%d] is not a symbolic tensor.' % i)
# Handle givens. # Handle givens.
if givens is not None: if givens is not None:
...@@ -169,23 +172,23 @@ class Function(object): ...@@ -169,23 +172,23 @@ class Function(object):
]) ])
del op_def.input[:len(op_def.input) // 2] del op_def.input[:len(op_def.input) // 2]
# Sort out the topology of states. # Sort out the states.
op_defs = sorted(op_info._defs.items(), key=lambda d: d[0]) op_defs = sorted(op_info._defs.items(), key=lambda d: d[0])
forward_ops = copy.deepcopy([v for k, v in op_defs]) forward_ops = copy.deepcopy([v for k, v in op_defs])
# Generate the backward operators. # Generate the backward operators.
if requires_grad: if requires_grad:
input_grads = {} input_grads, grad_targets = {}, []
for output in outputs: for output in outputs:
if hasattr(output, '_grad'):
grad_info = output._grad grad_info = output._grad
if grad_info is not None: if grad_info is not None:
if grad_info.input is not None: if grad_info.input is not None:
input_grads[output.id] = grad_info.input.id input_grads[output.id] = output._grad.input.id
grad_targets.append(output.id)
forward_ops, gradient_ops, _ = \ forward_ops, gradient_ops, _ = \
grad_maker.GradientMaker.make( grad_maker.GradientMaker.make(
forward_ops=forward_ops, forward_ops=forward_ops,
targets=list(op_info._targets), targets=grad_targets,
input_grads=input_grads, input_grads=input_grads,
) )
else: else:
......
...@@ -13,7 +13,7 @@ from __future__ import absolute_import ...@@ -13,7 +13,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.autograph.tensor import RefTensor from dragon.core.autograph.tensor import TensorRef
from dragon.core.eager import context from dragon.core.eager import context
from dragon.core.util import nest from dragon.core.util import nest
...@@ -120,19 +120,12 @@ def gradients(ys, xs, grad_ys=None): ...@@ -120,19 +120,12 @@ def gradients(ys, xs, grad_ys=None):
if grad_ys is not None: if grad_ys is not None:
y._grad.set_input(grad_ys[i]) y._grad.set_input(grad_ys[i])
for x in xs: for x in xs:
if not hasattr(x, '_grad') or \ if not hasattr(x, '_grad') or x._grad is None:
x._grad is None:
x._grad = GradientInfo(x) x._grad = GradientInfo(x)
y._grad.add_wrt(x.id) y._grad.add_wrt(x.id)
x._grad.add_cost(y) x._grad.add_cost(y)
if i == 0: if i == 0:
dxs.append( dxs.append(TensorRef(x.id + '_grad', x.shape, x.dtype))
RefTensor(
name=x.id + '_grad',
shape=x.shape,
dtype=x.dtype,
)
)
# Return the packed gradients. # Return the packed gradients.
return dxs return dxs
...@@ -15,7 +15,7 @@ from __future__ import absolute_import ...@@ -15,7 +15,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.autograph.tensor import RefTensor from dragon.core.autograph.tensor import TensorRef
from dragon.core.autograph import op_spec from dragon.core.autograph import op_spec
from dragon.core.framework import context from dragon.core.framework import context
from dragon.core.framework import proto_util from dragon.core.framework import proto_util
...@@ -76,20 +76,18 @@ class OpDef(object): ...@@ -76,20 +76,18 @@ class OpDef(object):
outputs = [] outputs = []
name_scope = context.get_name_scope() name_scope = context.get_name_scope()
for i in range(num_outputs): for i in range(num_outputs):
outputs.append(RefTensor( outputs.append(TensorRef(
workspace.get_dummy_name( workspace.get_dummy_name(
name_scope + (name if name else op_type), name_scope + (name if name else op_type),
suffix=':{}'.format(i), suffix=':{}'.format(i),
domain='Tensor')) domain='Tensor')))
)
else: else:
outputs = nest.flatten(outputs) outputs = nest.flatten(outputs)
num_outputs = len(outputs) num_outputs = len(outputs)
# Construct Def. # Construct Def.
op_idx, op_name = OpDef.get_index_and_name() op_idx, op_name = OpDef.get_index_and_name()
op_info._defs[op_idx] = \ op_info._defs[op_idx] = proto_util.make_operator_def(
proto_util.make_operator_def(
name=op_name, name=op_name,
op_type=op_type, op_type=op_type,
inputs=[input.id for input in inputs], inputs=[input.id for input in inputs],
......
...@@ -147,7 +147,7 @@ def cast_spec(args, inputs, outputs): ...@@ -147,7 +147,7 @@ def cast_spec(args, inputs, outputs):
outputs[0].dtype = args['dtype'] outputs[0].dtype = args['dtype']
try: try:
outputs[0].shape = inputs[0].shape[:] outputs[0].shape = inputs[0].shape[:]
except TypeError: except (TypeError, IndexError):
pass pass
return outputs return outputs
...@@ -192,7 +192,10 @@ def conv_spec(args, inputs, outputs): ...@@ -192,7 +192,10 @@ def conv_spec(args, inputs, outputs):
out_shape = inputs[0].shape[:] out_shape = inputs[0].shape[:]
channel_axis = 1 if args['data_format'] == 'NCHW' else -1 channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1 spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
out_shape[channel_axis] = args['num_output'] if 'out_channels' in args:
out_shape[channel_axis] = args['out_channels']
else:
out_shape[channel_axis] = inputs[1].shape[0]
for i in range(len(out_shape) - 2): for i in range(len(out_shape) - 2):
input_size = out_shape[i + spatial_axis] input_size = out_shape[i + spatial_axis]
k = args['kernel_shape'][i] k = args['kernel_shape'][i]
...@@ -219,7 +222,10 @@ def conv_transpose_spec(args, inputs, outputs): ...@@ -219,7 +222,10 @@ def conv_transpose_spec(args, inputs, outputs):
out_shape = inputs[0].shape[:] out_shape = inputs[0].shape[:]
channel_axis = 1 if args['data_format'] == 'NCHW' else -1 channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1 spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
out_shape[channel_axis] = args['num_output'] if 'out_channels' in args:
out_shape[channel_axis] = args['out_channels']
else:
out_shape[channel_axis] = inputs[1].shape[1]
for i in range(len(out_shape) - 2): for i in range(len(out_shape) - 2):
k = args['kernel_shape'][i] k = args['kernel_shape'][i]
s = args['strides'][i] s = args['strides'][i]
...@@ -274,20 +280,16 @@ def depth_to_space_spec(args, inputs, outputs): ...@@ -274,20 +280,16 @@ def depth_to_space_spec(args, inputs, outputs):
@register('Dot') @register('Dot')
def dot_spec(args, inputs, outputs): def dot_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
ta, tb = args['transA'], args['transB']
try: try:
if len(inputs[0].shape) == 1: a_shape, b_shape = inputs[0].shape[:], inputs[1].shape[:]
if len(a_shape) == 1 and len(b_shape) == 1:
outputs[0].shape = [] outputs[0].shape = []
return outputs elif len(a_shape) == 2 and len(b_shape) == 2:
except TypeError: outputs[0].shape = [a_shape[0], b_shape[1]]
pass elif len(a_shape) == 0 and len(b_shape) == 0:
try: outputs[0].shape = []
if len(inputs[0].shape) >= 2 and len(inputs[1].shape) in (1, 2): elif len(a_shape) >= 2 and len(b_shape) == 1:
out_shape = inputs[0].shape[1:] if ta else inputs[0].shape[:-1] outputs[0].shape = a_shape[:-1]
if len(inputs[1].shape) == 2:
out_shape.append(inputs[1].shape[0] if tb else inputs[1].shape[1])
outputs[0].shape = out_shape
return outputs
except TypeError: except TypeError:
pass pass
return outputs return outputs
...@@ -298,6 +300,7 @@ def dot_spec(args, inputs, outputs): ...@@ -298,6 +300,7 @@ def dot_spec(args, inputs, outputs):
'L1Loss', 'L1Loss',
'L2Loss', 'L2Loss',
'SigmoidCrossEntropy', 'SigmoidCrossEntropy',
'SigmoidFocalLoss',
'SmoothL1Loss', 'SmoothL1Loss',
]) ])
def eltwise_loss_spec(args, inputs, outputs): def eltwise_loss_spec(args, inputs, outputs):
...@@ -426,22 +429,22 @@ def flatten_spec(args, inputs, outputs): ...@@ -426,22 +429,22 @@ def flatten_spec(args, inputs, outputs):
@register('FullyConnected') @register('FullyConnected')
def fully_connected_spec(args, inputs, outputs): def fully_connected_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
axis, num_output = args['axis'], args['num_output'] axis, out_channels = args['axis'], args.get('out_channels', None)
while axis < 0: while axis < 0:
try: try:
axis += len(inputs[0].shape) axis += len(inputs[0].shape)
except TypeError: except TypeError:
return outputs return outputs
outputs[0].shape = [None] * (axis + 1) outputs[0].shape = [None] * (axis + 1)
if num_output is None: if out_channels is None:
try: try:
if args['transW']: if args['transW']:
num_output = inputs[1].shape[0] out_channels = inputs[1].shape[0]
else: else:
num_output = inputs[1].shape[1] out_channels = inputs[1].shape[1]
except (TypeError, IndexError): except (TypeError, IndexError):
num_output = None out_channels = None
outputs[0].shape[axis] = num_output outputs[0].shape[axis] = out_channels
try: try:
outputs[0].shape[:axis] = inputs[0].shape[:axis] outputs[0].shape[:axis] = inputs[0].shape[:axis]
except TypeError: except TypeError:
...@@ -488,7 +491,7 @@ def index_select_spec(args, inputs, outputs): ...@@ -488,7 +491,7 @@ def index_select_spec(args, inputs, outputs):
return outputs return outputs
@register(['IsInf', 'InNaN']) @register(['IsInf', 'IsNaN'])
def is_spec(args, inputs, outputs): def is_spec(args, inputs, outputs):
_ = locals() _ = locals()
outputs[0].dtype = 'bool' outputs[0].dtype = 'bool'
...@@ -507,7 +510,7 @@ def masked_select_spec(args, inputs, outputs): ...@@ -507,7 +510,7 @@ def masked_select_spec(args, inputs, outputs):
return outputs return outputs
@register('Matmul') @register('MatMul')
def matmul_spec(args, inputs, outputs): def matmul_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
ta, tb = args['transA'], args['transB'] ta, tb = args['transA'], args['transB']
...@@ -758,7 +761,7 @@ def resize_spec(args, inputs, outputs): ...@@ -758,7 +761,7 @@ def resize_spec(args, inputs, outputs):
@register(['RoiPool', 'RoiAlign']) @register(['RoiPool', 'RoiAlign'])
def roi_pool_spec(args, inputs, outputs): def roi_pool_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
pool_h, pool_w = args['pool_h'], args['pool_w'] pool_h, pool_w = args['pooled_h'], args['pooled_w']
out_shape = None out_shape = None
try: try:
out_shape = inputs[0].shape[:] out_shape = inputs[0].shape[:]
@@ -814,7 +817,6 @@ def slice_spec(args, inputs, outputs):
 @register([
     'NLLLoss',
-    'SigmoidFocalLoss',
     'SoftmaxCrossEntropy',
     'SparseSoftmaxCrossEntropy',
 ])
...
@@ -420,7 +420,7 @@ class Tensor(types.TensorMetaclass):
         The constant contains the value.
         """
-        return RefTensor('', dtype=dtype)._from_constant(value, name)
+        return Tensor('', dtype=dtype)._from_constant(value, name)
     def _register_as(self, type, **kwargs):
         """Fill self with the specific type of filler."""
@@ -463,13 +463,12 @@ class Tensor(types.TensorMetaclass):
"""Convert the value to a tensor.""" """Convert the value to a tensor."""
if not isinstance(value, numpy.ndarray): if not isinstance(value, numpy.ndarray):
value = numpy.array(value, self.dtype if self.dtype else 'float32') value = numpy.array(value, self.dtype if self.dtype else 'float32')
return RefTensor( return TensorRef(
name=workspace.get_dummy_name( name=workspace.get_dummy_name(
basename=context.get_name_scope() + basename=context.get_name_scope() +
(name if name else 'Const'), (name if name else 'Const'),
suffix=':0', suffix=':0',
domain='Tensor' domain='Tensor'),
),
shape=list(value.shape), shape=list(value.shape),
dtype=str(value.dtype), dtype=str(value.dtype),
).set_value(value) ).set_value(value)
...@@ -560,8 +559,8 @@ class Tensor(types.TensorMetaclass): ...@@ -560,8 +559,8 @@ class Tensor(types.TensorMetaclass):
return self.__div__(other) return self.__div__(other)
-class RefTensor(object):
-    """Create a reference tensor not involved with name scope."""
+class TensorRef(object):
+    """Create a reference not involved with the name scope."""
     def __new__(cls, name, shape=None, dtype=None):
         tensor = Tensor('', shape=shape, dtype=dtype)
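The rename keeps the old semantics: the given name is used verbatim rather than being scoped and uniquified. A hypothetical illustration (the blob name here is made up):

```python
# Alias an existing blob by its exact name; unlike a plain Tensor,
# no name-scope prefix or dummy ':0'-style name is generated for it.
ref = TensorRef('fc1/weight', shape=[10, 64], dtype='float32')
```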
...
@@ -9,7 +9,7 @@
 #
 # ------------------------------------------------------------
-"""Some useful mappings are defined here."""
+"""Constant mappings."""
 from __future__ import absolute_import
 from __future__ import division
...
@@ -104,7 +104,7 @@ class Operator(object):
     """Generate the OpDef from attributes."""
     attributes = self.attributes()
     self._def = proto_util.make_operator_cdef(
-        name='Generic',
+        name=attributes.get('name', 'GenericOp'),
         cache_key=self._cache_key,
         op_type=attributes['op_type'],
         device_option=proto_util.get_device_option(
...
@@ -134,10 +134,14 @@ def make_operator_cdef(
     op_def = backend.OperatorDef()
     op_def.ParseFrom(
         make_operator_def(
-            op_type, inputs, outputs, name,
-            cache_key, device_option, arg, **kwargs
-        ).SerializeToString()
-    )
+            op_type,
+            inputs,
+            outputs,
+            name,
+            cache_key,
+            device_option,
+            arg,
+            **kwargs).SerializeToString())
     return op_def
...
@@ -9,12 +9,7 @@
 #
 # ------------------------------------------------------------
-"""Wrappers for the Workspace of C++ backend.
-
-Flexible API is provided to manage the global resources
-between the Python threads (quite different from C++).
-"""
+"""Generic interfaces of the current default workspace."""
 from __future__ import absolute_import
 from __future__ import division
...
@@ -268,7 +268,7 @@ def leaky_relu(inputs, alpha=0.2, **kwargs):
 @OpSchema.num_inputs(1)
-def log_softmax(inputs, axis=1, **kwargs):
+def log_softmax(inputs, axis=-1, **kwargs):
     r"""Apply the composite of logarithm and softmax.
     The **LogSoftmax** function is defined as:
@@ -287,7 +287,7 @@ def log_softmax(inputs, axis=1, **kwargs):
     ----------
     inputs : dragon.Tensor
         The input tensor.
-    axis : int, optional, default=1
+    axis : int, optional, default=-1
         The axis to reduce.
     Returns
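The new default axis matches `softmax`. As a reference for what the composite computes, `log_softmax(x) = x - logsumexp(x)` along the axis; a NumPy sketch (not the operator's kernel):

```python
import numpy as np

def log_softmax_ref(x, axis=-1):
    x_max = np.max(x, axis=axis, keepdims=True)  # subtract max for stability
    lse = x_max + np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True))
    return x - lse

print(log_softmax_ref(np.array([1., 2., 3.])))  # [-2.4076, -1.4076, -0.4076]
```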
@@ -351,7 +351,7 @@ def prelu(inputs, channel_shared=False, data_format='NCHW', **kwargs):
     if context.executing_eagerly():
         return op_lib \
             .instantiate(data_format=data_format) \
-            .apply([inputs])
+            .apply(inputs)
     else:
         return op_lib.blend(**args)
@@ -373,7 +373,7 @@ def relu(inputs, **kwargs):
     Examples:
     ```python
-    x = dragon.constant([-1, 0, 1], 'float32')
+    x = dragon.constant([-1., 0., 1.])
     print(dragon.nn.relu(x, inplace=False))
     ```
@@ -561,9 +561,8 @@ def softmax(inputs, axis=-1, **kwargs):
     op_lib = activation_ops_lib.Softmax
     if context.executing_eagerly():
         return op_lib \
-            .instantiate(
-                axis=axis,
-            ).apply([inputs], inplace=inplace)
+            .instantiate(axis=axis) \
+            .apply([inputs], inplace=inplace)
     else:
         return op_lib.blend(**args)
...
@@ -64,11 +64,14 @@ def arange(start, stop=None, step=1, dtype='int64', **kwargs):
     """
     args = parse_args(locals())
     args['dtype'] = args['dtype'].lower()
-    op_lib = array_ops_lib.Arange
     if stop is None:
-        args['slice'] = (start, step)
+        args['slice'] = (float(start), float(step))
     else:
-        args['slice'] = (start, stop, step)
+        args['slice'] = (float(start), float(stop), float(step))
+    args.pop('start')
+    args.pop('stop')
+    args.pop('step')
+    op_lib = array_ops_lib.Arange
     trainable = args.pop('trainable') if 'trainable' in args else False
     if context.executing_eagerly():
         return op_lib.instantiate(
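The change funnels `start`/`stop`/`step` into a single float tuple before dispatch; conceptually (an illustration of the normalization, not the library code):

```python
def normalize_arange_args(start, stop=None, step=1):
    if stop is None:
        return (float(start), float(step))  # arange(5) -> (5.0, 1.0)
    return (float(start), float(stop), float(step))

print(normalize_arange_args(2, 8, 2))  # (2.0, 8.0, 2.0)
```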
@@ -269,6 +272,8 @@ def cast(inputs, dtype, **kwargs):
         .instantiate(dtype=dtype) \
         .apply([inputs], inplace=inplace)
     else:
+        if inputs.dtype == dtype:
+            return inputs
         if inplace:
             args['inputs'], args['outputs'] = [], [inputs]
         return op_lib.blend(**args)
@@ -627,16 +632,14 @@ def index_select(inputs, indices, axis=0, **kwargs):
     return op_lib.blend(**args)
-@OpSchema.num_inputs(1)
-def masked_select(inputs, mask, **kwargs):
+@OpSchema.num_inputs(2)
+def masked_select(inputs, **kwargs):
     """Select the elements where the given mask is **1**.
     Parameters
     ----------
-    inputs : dragon.Tensor
-        The input tensor.
-    mask : dragon.Tensor
-        The mask, with the same size as ``inputs``.
+    inputs : Sequence[dragon.Tensor]
+        The input and mask tensors.
     Returns
     -------
@@ -647,9 +650,8 @@ def masked_select(inputs, **kwargs):
     args = parse_args(locals())
     op_lib = array_ops_lib.MaskedSelect
     if context.executing_eagerly():
-        return op_lib.instantiate().apply([inputs, mask])
+        return op_lib.instantiate().apply(inputs)
     else:
-        args['inputs'] = [args['inputs'], args.pop('mask')]
         return op_lib.blend(**args)
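Callers now pass the input and mask together as one sequence. A sketch of the updated call, assuming eager mode and that a boolean mask tensor can be built with `dragon.constant`:

```python
x = dragon.constant([1, 2, 3, 4], 'int64')
mask = dragon.constant([0, 0, 1, 1], 'bool')  # hypothetical boolean mask
y = dragon.masked_select([x, mask])           # -> [3, 4]
```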
@@ -1047,7 +1049,7 @@ def pad(inputs, pads, mode='constant', value=0, **kwargs):
         .instantiate(
             ndim=len(pads_begin),
             value=args['value'],
-            mode=mode,
+            mode=args['mode'],
         ).apply([inputs], args['pads'])
     else:
         return op_lib.blend(**args)
@@ -1278,7 +1280,9 @@ def split(
     size_splits = None
     if slice_points is not None:
         if len(slice_points) + 1 != num_splits:
-            raise ValueError('Excepted %d values for <slice_points>.')
+            raise ValueError(
+                'Expected %d values for <slice_points>.'
+                % (num_splits - 1))
     if context.executing_eagerly():
         return op_lib \
             .instantiate(
...
@@ -61,38 +61,36 @@ def assign(inputs, starts=None, sizes=None, **kwargs):
 @OpSchema.num_inputs(1, 2)
 def copy(inputs, **kwargs):
-    r"""Copy the value to ref.
-    .. math:: \text{Ref}[:] = \text{Value}[:]
+    """Copy the input.
     Examples:
     ```python
-    # Copy the content from ``x`` to ``xx``
+    # Copy ``x`` to ``y``
     x = dragon.ones(shape=(2, 3))
-    xx = dragon.zeros(shape=(2, 4))
-    dragon.copy([xx, x])
+    y = dragon.zeros(shape=(2, 4))
+    dragon.copy([x, y])
-    # Create a new tensor initialized from ``x``
-    xxx = dragon.copy(x)
+    # Copy to a new tensor from ``x``
+    y = dragon.copy(x)
     ```
     Parameters
     ----------
-    inputs : Sequence[dragon.Tensor]
-        The **ref** and **value**.
+    inputs : Union[dragon.Tensor, Sequence[dragon.Tensor]]
+        The input tensor.
     Returns
     -------
     dragon.Tensor
-        The **ref**.
+        The output tensor.
     """
     args = parse_args(locals())
-    inputs = nest.flatten(inputs)
-    if len(inputs) == 2:
-        args['inputs'] = [inputs[1]]
-        args['outputs'] = [inputs[0]]
+    args['inputs'] = nest.flatten(inputs)
+    if len(args['inputs']) == 2:
+        args['outputs'] = [args['inputs'][1]]
+        args['inputs'] = [args['inputs'][0]]
     else:
         args['outputs'] = None
     op_lib = control_flow_ops_lib.Copy
@@ -104,8 +102,8 @@ def copy(inputs, **kwargs):
     return op_lib.blend('Copy', **args)
-@OpSchema.num_inputs(2)
-def masked_assign(inputs, mask, **kwargs):
+@OpSchema.num_inputs(3)
+def masked_assign(inputs, **kwargs):
     r"""Assign the value to ref where mask is **1**.
     .. math::
@@ -118,24 +116,22 @@ def masked_assign(inputs, **kwargs):
     Parameters
     ----------
     inputs : Sequence[dragon.Tensor]
-        The **ref** and **value**.
-    mask : dragon.Tensor
-        The mask, with the same size as **ref**.
+        The **ref**, **value** and **mask** tensors.
     Returns
     -------
     dragon.Tensor
-        The **ref**.
+        The **ref** tensor.
     """
     args = parse_args(locals())
     inputs[1] = ops.scalar_to_tensor(inputs[1], inputs[0].dtype)
     op_lib = control_flow_ops_lib.MaskedAssign
     if context.executing_eagerly():
-        return op_lib.instantiate().apply(inputs, mask)
+        return op_lib.instantiate().apply(inputs)
     else:
         args.update({
             'outputs': [args['inputs'][0]],
-            'inputs': [args['inputs'][1], mask],
+            'inputs': args['inputs'][1:],
         })
         return op_lib.blend(**args)
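With the packed signature, a call now looks like this (a hypothetical eager-mode sketch; it assumes the symbol is exported as `dragon.masked_assign` and a boolean mask is accepted):

```python
x = dragon.ones(shape=(4,))
mask = dragon.constant([1, 0, 1, 0], 'bool')  # hypothetical mask tensor
dragon.masked_assign([x, 0., mask])           # zeros x where mask is 1
```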
@@ -47,7 +47,7 @@ class Assign(Operator):
                 sizes[i], 'int64',
             )
-    def forward(self, ws, inputs, starts, sizes):
+    def forward(self, inputs, starts, sizes):
         return self.dispatch(
             [inputs[1]], [inputs[0]],
             callback=lambda ws, handle:
@@ -75,5 +75,5 @@ class MaskedAssign(Operator):
     def attributes(self):
         return {'op_type': 'MaskedAssign', 'arguments': {}}
-    def forward(self, inputs, mask):
-        return self.dispatch([inputs[1], mask], [inputs[0]], no_grad=True)
+    def forward(self, inputs):
+        return self.dispatch(inputs[1:], [inputs[0]], no_grad=True)
@@ -88,9 +88,7 @@ def l1_loss(inputs, reduction='mean', **kwargs):
     op_lib = loss_ops_lib.L1Loss
     if context.executing_eagerly():
         return op_lib \
-            .instantiate(
-                reduction=args['reduction'],
-            ).apply(inputs)
+            .instantiate(reduction=args['reduction']).apply(inputs)
     else:
         return op_lib.blend(**args)
...
@@ -16,49 +16,49 @@ from __future__ import print_function
 from dragon.core.framework.ops import Operator
-class Accumulate(Operator):
+class Affine(Operator):
     def __init__(self, key, dev, **kwargs):
-        super(Accumulate, self).__init__(key, dev, **kwargs)
-        self.alpha = kwargs.get('alpha', 1.)
-        self.beta = kwargs.get('beta', 1.)
+        super(Affine, self).__init__(key, dev, **kwargs)
+        self.axis = kwargs.get('axis', 1)
+        self.num_axes = kwargs.get('num_axes', 1)
     def attributes(self):
         return {
-            'op_type': 'Accumulate',
+            'op_type': 'Affine',
             'arguments': {
-                'alpha': self.alpha,
-                'beta': self.beta,
+                'axis': self.axis,
+                'num_axes': self.num_axes,
             }
         }
-    def forward(self, inputs, outputs=None):
-        if outputs is None:
-            outputs = [self.alloc() for _ in range(len(inputs))]
-        return self.dispatch(inputs, outputs, no_grad=True)
+    def forward(self, inputs):
+        return self.dispatch(inputs, [self.alloc()])
-class Affine(Operator):
+class Axpby(Operator):
     def __init__(self, key, dev, **kwargs):
-        super(Affine, self).__init__(key, dev, **kwargs)
-        self.axis = kwargs.get('axis', 1)
-        self.num_axes = kwargs.get('num_axes', 1)
+        super(Axpby, self).__init__(key, dev, **kwargs)
+        self.alpha = kwargs.get('alpha', 1.)
+        self.beta = kwargs.get('beta', 1.)
     def attributes(self):
         return {
-            'op_type': 'Affine',
+            'op_type': 'Axpby',
             'arguments': {
-                'axis': self.axis,
-                'num_axes': self.num_axes,
+                'alpha': self.alpha,
+                'beta': self.beta,
             }
         }
-    def forward(self, inputs):
-        return self.dispatch(inputs, [self.alloc()])
+    def forward(self, inputs, outputs=None):
+        if outputs is None:
+            outputs = [self.alloc() for _ in range(len(inputs))]
+        return self.dispatch(inputs, outputs, no_grad=True)
-class Binary(Operator):
+class BinaryOp(Operator):
     def __init__(self, key, dev, **kwargs):
-        super(Binary, self).__init__(key, dev, **kwargs)
+        super(BinaryOp, self).__init__(key, dev, **kwargs)
         self.op_type = kwargs.get('op_type', '')
     def attributes(self):
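For readers tracking the rename: `Axpby` carries the old `Accumulate` update and `Affine` the channel-wise scaling. A NumPy sketch of the intended semantics (illustrative, not these classes):

```python
import numpy as np

def axpby_ref(x, y, alpha=1., beta=1.):
    return alpha * x + beta * y  # the old 'Accumulate' update

def affine_ref(x, scale, bias=None):
    out = x * scale              # broadcast over the chosen axes
    return out + bias if bias is not None else out

print(axpby_ref(np.ones(3), np.full(3, 2.), alpha=0.5, beta=2.))  # [4.5 4.5 4.5]
```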
@@ -95,37 +95,18 @@ class Clip(Operator):
         return self.dispatch(inputs, [self.alloc()])
-class Dot(Operator):
-    def __init__(self, key, dev, **kwargs):
-        super(Dot, self).__init__(key, dev, **kwargs)
-        self.transA = kwargs.get('transA', False)
-        self.transB = kwargs.get('transB', False)
-    def attributes(self):
-        return {
-            'op_type': 'Dot',
-            'arguments': {
-                'transA': self.transA,
-                'transB': self.transB,
-            }
-        }
-    def forward(self, inputs):
-        return self.dispatch(inputs, [self.alloc()])
 class FullyConnected(Operator):
     def __init__(self, key, dev, **kwargs):
         super(FullyConnected, self).__init__(key, dev, **kwargs)
         self.axis = kwargs.get('axis', 1)
-        self.transW = kwargs.get('transW', True)
+        self.transpose_w = kwargs.get('transpose_w', True)
     def attributes(self):
         return {
             'op_type': 'FullyConnected',
             'arguments': {
                 'axis': self.axis,
-                'transW': self.transW,
+                'transW': self.transpose_w,
             }
         }
@@ -133,18 +114,18 @@ class FullyConnected(Operator):
         return self.dispatch(inputs, [self.alloc()])
-class Matmul(Operator):
+class MatMul(Operator):
     def __init__(self, key, dev, **kwargs):
-        super(Matmul, self).__init__(key, dev, **kwargs)
-        self.transA = kwargs.get('transA', False)
-        self.transB = kwargs.get('transB', False)
+        super(MatMul, self).__init__(key, dev, **kwargs)
+        self.transpose_a = kwargs.get('transpose_a', False)
+        self.transpose_b = kwargs.get('transpose_b', False)
     def attributes(self):
         return {
-            'op_type': 'Matmul',
+            'op_type': 'MatMul',
             'arguments': {
-                'transA': self.transA,
-                'transB': self.transB,
+                'transA': self.transpose_a,
+                'transB': self.transpose_b,
             }
         }
@@ -152,9 +133,9 @@ class Matmul(Operator):
         return self.dispatch(inputs, [self.alloc()])
-class Unary(Operator):
+class UnaryOp(Operator):
     def __init__(self, key, dev, **kwargs):
-        super(Unary, self).__init__(key, dev, **kwargs)
+        super(UnaryOp, self).__init__(key, dev, **kwargs)
         self.op_type = kwargs.get('op_type', '')
     def attributes(self):
...
@@ -38,10 +38,10 @@ def batch_norm(
     .. math::
         y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
-    The moving average of stats are calculated as:
+    The running average of statistics is calculated as:
     .. math::
-        x_{moving} \leftarrow momentum * x_{moving} + (1 - momentum) * x_{stat}
+        x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}}
     Note that the number of inputs should be **5**, i.e.,
     this operator is implemented as the fused version.
@@ -56,11 +56,11 @@ def batch_norm(
     axis : int, optional, default=-1
         The channel axis.
     momentum : float, optional, default=0.9
-        The momentum of moving average.
+        The momentum for the running average.
     eps : float, optional, default=1e-5
-        The epsilon.
+        The value of :math:`\epsilon`.
     use_stats : int, optional, default=-1
-        Whether to use global stats.
+        Whether to use the estimated statistics or not.
     Returns
     -------
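The running-statistics update above is a simple exponential average; a NumPy sketch of exactly that formula (illustrative, not the fused kernel):

```python
import numpy as np

def update_running(x_running, x_stat, momentum=0.9):
    return momentum * x_running + (1. - momentum) * x_stat

running_mean = np.zeros(3)
batch_mean = np.array([1., 2., 3.])
print(update_running(running_mean, batch_mean))  # [0.1, 0.2, 0.3]
```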
@@ -168,7 +168,7 @@ def instance_norm(inputs, axis=-1, eps=1e-5, **kwargs):
 @OpSchema.num_inputs(1)
-def lp_normalize(inputs, axis=None, p=2, eps=1e-5, reduction='sum', **kwargs):
+def lp_normalize(inputs, axis=None, p=2, eps=1e-12, reduction='sum', **kwargs):
     r"""Apply the lp normalization.
     The **Lp-Normalization** is defined as:
@@ -200,7 +200,7 @@ def lp_normalize(inputs, axis=None, p=2, eps=1e-5, reduction='sum', **kwargs):
         The order of the normalization.
     axis : Union[int, Sequence[int]], optional
         The axis to compute the norm.
-    eps : float, optional, default=1e-5
+    eps : float, optional, default=1e-12
         The value of :math:`\epsilon`.
     reduction : {'sum', 'mean'}, optional
         The reduction method for norm.
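The tighter default `eps=1e-12` only guards against near-zero norms. For the common L2 case the computation reduces to the sketch below (a reference, not the operator's kernel):

```python
import numpy as np

def l2_normalize_ref(x, axis=None, eps=1e-12):
    norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
    return x / np.maximum(norm, eps)

print(l2_normalize_ref(np.array([3., 4.])))  # [0.6, 0.8]
```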
@@ -326,9 +326,9 @@ def local_response_norm(
             beta=args['beta'],
             bias=args['bias'],
             data_format=data_format,
-        ).apply(inputs)
+        ).apply([inputs])
     else:
-        return op_lib.blend(**args)
+        return op_lib.blend('LRN', **args)
 @OpSchema.num_inputs(5)
@@ -349,10 +349,10 @@ def sync_batch_norm(
     .. math::
         \text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
-    The moving average of statistics are calculated as:
+    The running average of statistics is calculated as:
     .. math::
-        x_{moving} \leftarrow momentum * x_{moving} + (1 - momentum) * x_{stat}
+        x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}}
     Note that the number of inputs should be **5**, i.e.,
     this operator is implemented as the fused version.
@@ -367,11 +367,11 @@ def sync_batch_norm(
     axis : int, optional, default=-1
         The channel axis.
     momentum : float, optional, default=0.9
-        The momentum of moving average.
+        The momentum for the running average.
     eps : float, optional, default=1e-5
-        The epsilon.
+        The value of :math:`\epsilon`.
     use_stats : int, optional, default=-1
-        Whether to use global stats.
+        Whether to use the estimated statistics or not.
     process_group : ProcessGroup, optional
         The group for communication.
...
This diff could not be displayed because it is too large.