Commit adb6fa64 by Ting PAN

Add native ops test

Summary:
This commit tests the execution of native ops and verifies their results.
Several bugs found by these tests have been fixed.
1 parent df172cc8
Showing with 1138 additions and 1425 deletions
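To illustrate the verification approach described in the summary, here is a minimal sketch of checking a native op (the renamed ``axpby``) against a NumPy reference. The helper ``axpby_reference`` is written for this illustration only and is not part of the commit.

```python
import numpy as np

def axpby_reference(x, y, alpha=1.0, beta=1.0):
    # NumPy reference for the renamed ``axpby`` op: out = alpha * x + beta * y.
    return alpha * x + beta * y

# The test pattern: run an op, then compare against the reference
# within a floating-point tolerance.
x = np.random.randn(2, 3).astype('float32')
y = np.random.randn(2, 3).astype('float32')
out = axpby_reference(x, y, alpha=2.0, beta=0.5)
np.testing.assert_allclose(out, 2.0 * x + 0.5 * y, rtol=1e-5, atol=1e-5)
```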
......@@ -9,13 +9,13 @@
#
# ------------------------------------------------------------
"""Implementation for the ``Layer`` C++ class."""
"""The base layer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.autograph.tensor import RefTensor
from dragon.core.autograph.tensor import TensorRef
from dragon.core.eager import context as eager_context
from dragon.core.framework import context
......@@ -76,8 +76,8 @@ class Layer(object):
param_name = scoped_name + '/param:{}'.format(len(self._blobs))
# Set the name explicitly.
variable = RefTensor(param_name)
variable_grad = RefTensor(param_name + '_grad')
variable = TensorRef(param_name)
variable_grad = TensorRef(param_name + '_grad')
if filler is not None:
variable._register_as(**filler)
......
......@@ -455,8 +455,8 @@ class InnerProduct(Layer):
param = layer_param.inner_product_param
self.arguments = {
'axis': param.axis,
'num_output': param.num_output,
'transW': not param.transpose,
'out_channels': param.num_output,
'transpose_w': not param.transpose,
}
# Add weights and biases
self.add_blob(filler=self.get_filler(param, 'weight_filler'))
......@@ -522,7 +522,7 @@ class Normalize(Layer):
normalize_param {
across_spatial: false
channel_shared: false
eps: 1e-5
eps: 1e-12
scale_filler: {
type: "constant"
value: 1
......@@ -548,7 +548,7 @@ class Normalize(Layer):
self.add_blob(filler=self.get_filler(param, 'scale_filler'), value=1)
def __call__(self, bottom):
norm_out = [normalization_ops.l2_normalize(bottom, **self.l2norm_arguments)]
norm_out = [normalization_ops.lp_normalize(bottom, **self.l2norm_arguments)]
norm_out += [blob['data'] for blob in self._blobs]
return math_ops.affine(norm_out, **self.affine_arguments)
......
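For context on the ``Normalize`` changes above (``l2_normalize`` replaced by ``lp_normalize``, default ``eps`` moved to 1e-12), a rough NumPy sketch of what the layer computes. The reduction axis and the use of ``eps`` as a clamp are assumptions made for this illustration, not taken from the commit.

```python
import numpy as np

def normalize_reference(x, scale, eps=1e-12, axis=1):
    # L2-normalize along the channel axis, then apply the learned scale
    # (the affine step that follows the normalization in the layer above).
    norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
    return x / np.maximum(norm, eps) * scale

x = np.random.randn(2, 8, 4, 4).astype('float32')
scale = np.ones((1, 8, 1, 1), dtype='float32')
y = normalize_reference(x, scale)
```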
......@@ -65,7 +65,7 @@ class Convolution(Layer):
super(Convolution, self).__init__(layer_param)
param = layer_param.convolution_param
self.arguments = {
'num_output': param.num_output,
'out_channels': param.num_output,
'kernel_shape': [int(e) for e in param.kernel_size],
'strides': [int(e) for e in param.stride] if len(param.stride) > 0 else [1],
'pads': [int(e) for e in param.pad] if len(param.pad) > 0 else [0],
......@@ -187,7 +187,7 @@ class DepthwiseConv2d(Layer):
super(DepthwiseConv2d, self).__init__(layer_param)
param = layer_param.convolution_param
self.arguments = {
'num_output': param.num_output,
'out_channels': param.num_output,
'kernel_shape': [int(e) for e in param.kernel_size],
'strides': [int(e) for e in param.stride] if len(param.stride) > 0 else [1],
'pads': [int(e) for e in param.pad] if len(param.pad) > 0 else [0],
......
......@@ -9,7 +9,7 @@
#
# ------------------------------------------------------------
"""Implementation for the ``Net`` C++ class."""
"""The base net class."""
from __future__ import absolute_import
from __future__ import division
......@@ -20,8 +20,8 @@ from google.protobuf import text_format
from dragon.core.autograph import def_function
from dragon.core.autograph import grad_impl
from dragon.core.autograph.tensor import RefTensor
from dragon.core.autograph.tensor import Tensor
from dragon.core.autograph.tensor import TensorRef
from dragon.core.framework import workspace
from dragon.core.util import nest
from dragon.vm.caffe import layers as layer_factory
......@@ -84,17 +84,13 @@ class Net(object):
if len(self._net_proto.input) > 0:
shapes = self._net_proto.input_shape
for i, input in enumerate(self._net_proto.input):
for i, input_name in enumerate(self._net_proto.input):
shape = [e for e in shapes[i].dim] if i < len(shapes) else None
if input not in self._blobs:
data = Tensor(input, shape=shape, dtype='float32').placeholder()
self._blobs[input] = {
data = Tensor(input_name, shape, 'float32').placeholder()
self._blobs[input_name] = {
'data': data,
'diff': RefTensor(
data.id + '_grad',
shape=shape,
dtype=data.dtype
),
'diff': TensorRef(data.id + '_grad', shape, data.dtype),
}
for layer in self._net_proto.layer:
......@@ -145,7 +141,7 @@ class Net(object):
for i, blob in enumerate(layer._top):
self._blobs[blob] = {
'data': outputs[i],
'diff': RefTensor(outputs[i].id + '_grad'),
'diff': TensorRef(outputs[i].id + '_grad'),
}
self._net_outputs.add(blob)
......
......@@ -9,7 +9,7 @@
#
# ------------------------------------------------------------
"""Implementation for the ``Solver`` C++ class."""
"""The solver to update parameters."""
from __future__ import absolute_import
from __future__ import division
......@@ -19,9 +19,12 @@ import time
from google.protobuf import text_format
from dragon import updaters
from dragon.core.autograph import def_function
from dragon.core.framework import workspace
from dragon.core.training.adam import Adam
from dragon.core.training.rmsprop import RMSprop
from dragon.core.training.sgd import SGD
from dragon.core.training.sgd import Nesterov
from dragon.vm.caffe.net import Net
from dragon.vm.caffe.proto import caffe_pb2
......@@ -47,10 +50,10 @@ class Solver(object):
if self._param.iter_size > 1:
raise NotImplementedError('GradientAccum is deprecated.')
self._arguments = {
'scale_gradient': 1. / self._param.iter_size,
'clip_gradient': float(self._param.clip_gradients),
'l2_decay': float(self._param.weight_decay)
if str(self._param.regularization_type) == 'L2' else -1.,
'scale': 1. / self._param.iter_size,
'clip_norm': float(self._param.clip_gradients),
'weight_decay': float(self._param.weight_decay)
if str(self._param.regularization_type) == 'L2' else 0,
}
self._optimizer = None
self._net, self._test_nets = None, []
......@@ -415,7 +418,7 @@ class AdamSolver(Solver):
self._arguments['beta1'] = self._param.momentum
self._arguments['beta2'] = self._param.momentum2
self._arguments['eps'] = self._param.delta
self._optimizer = updaters.Adam(**self._arguments)
self._optimizer = Adam(**self._arguments)
class NesterovSolver(Solver):
......@@ -447,7 +450,7 @@ class NesterovSolver(Solver):
super(NesterovSolver, self).__init__(solver_file, is_root)
self._arguments['base_lr'] = self._param.base_lr
self._arguments['momentum'] = self._param.momentum
self._optimizer = updaters.Nesterov(**self._arguments)
self._optimizer = Nesterov(**self._arguments)
class RMSPropSolver(Solver):
......@@ -481,7 +484,7 @@ class RMSPropSolver(Solver):
self._arguments['base_lr'] = self._param.base_lr
self._arguments['decay'] = self._param.rms_decay
self._arguments['eps'] = self._param.delta
self._optimizer = updaters.RMSProp(**self._arguments)
self._optimizer = RMSprop(**self._arguments)
class SGDSolver(Solver):
......@@ -513,4 +516,4 @@ class SGDSolver(Solver):
super(SGDSolver, self).__init__(solver_file, is_root)
self._arguments['base_lr'] = self._param.base_lr
self._arguments['momentum'] = self._param.momentum
self._optimizer = updaters.SGD(**self._arguments)
self._optimizer = SGD(**self._arguments)
......@@ -9,9 +9,6 @@ dragon.math
`abs(...) <math/abs.html>`_
: Compute the absolute value of input.
`accumulate(...) <math/accumulate.html>`_
: Compute the element-wise accumulation from input to output.
`add(...) <math/add.html>`_
: Compute the element-wise addition.
......@@ -24,6 +21,9 @@ dragon.math
`argmin(...) <math/argmin.html>`_
: Compute the indices of minimum elements along the given axis.
`axpby(...) <math/axpby.html>`_
: Compute the element-wise addition from input to output, i.e., y = alpha * x + beta * y.
`ceil(...) <math/ceil.html>`_
: Compute the smallest integer not less than input.
......@@ -96,9 +96,6 @@ dragon.math
`moments(...) <math/moments.html>`_
: Compute the mean and variance of input along the given axes.
`moving_average(...) <math/moving_average.html>`_
: Compute the moving average of input to output.
`mul(...) <math/mul.html>`_
: Compute the element-wise multiplication.
......@@ -148,11 +145,11 @@ dragon.math
:hidden:
math/abs
math/accumulate
math/add
math/affine
math/argmax
math/argmin
math/axpby
math/ceil
math/clip
math/cos
......@@ -177,7 +174,6 @@ dragon.math
math/min
math/minimum
math/moments
math/moving_average
math/mul
math/negative
math/not_equal
......
accumulate
==========
axpby
=====
.. autofunction:: dragon.math.accumulate
.. autofunction:: dragon.math.axpby
.. raw:: html
......
moving_average
==============
.. autofunction:: dragon.math.moving_average
.. raw:: html
<style>
h1:before {
content: "dragon.math.";
color: #103d3e;
}
</style>
dragon.updaters
===============
dragon.optimizers
=================
.. only:: html
Classes
-------
`class Adam <updaters/Adam.html>`_
: The updater which implements the Adam algorithm.
`class Adam <optimizers/Adam.html>`_
: The optimizer to apply the Adam algorithm.
`[Kingma & Ba, 2014] <https://arxiv.org/abs/1412.6980>`_.
`class Nesterov <updaters/Nesterov.html>`_
: The updater which implements the NesterovSGD algorithm.
`class Nesterov <optimizers/Nesterov.html>`_
: The optimizer to apply the NesterovSGD algorithm.
`[Sutskever et.al, 2013] <http://www.cs.toronto.edu/~hinton/absps/momentum.pdf>`_.
`class RMSProp <updaters/RMSProp.html>`_
: The updater which implements the RMSprop algorithm.
`class RMSProp <optimizers/RMSprop.html>`_
: The optimizer to apply the RMSprop algorithm.
`[Hinton et.al, 2013] <http://www.cs.utoronto.ca/~bonner/courses/2016s/csc321/lectures/lec6.pdf>`_.
`class SGD <updaters/SGD.html>`_
: The updater which implements the MomentumSGD algorithm.
`class SGD <optimizers/SGD.html>`_
: The optimizer to apply the MomentumSGD algorithm.
`[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_.
.. toctree::
:hidden:
updaters/Adam
updaters/Nesterov
updaters/RMSProp
updaters/SGD
optimizers/Adam
optimizers/Nesterov
optimizers/Optimizer
optimizers/RMSprop
optimizers/SGD
.. raw:: html
......
Adam
====
.. autoclass:: dragon.updaters.Adam
.. autoclass:: dragon.optimizers.Adam
__init__
--------
.. automethod:: dragon.updaters.Adam.__init__
.. automethod:: dragon.optimizers.Adam.__init__
Methods
-------
apply_gradients
################
.. automethod:: dragon.updaters.Updater.apply_gradients
.. automethod:: dragon.optimizers.Optimizer.apply_gradients
:noindex:
.. raw:: html
<style>
h1:before {
content: "dragon.updaters.";
content: "dragon.optimizers.";
color: #103d3e;
}
</style>
Nesterov
========
.. autoclass:: dragon.updaters.Nesterov
.. autoclass:: dragon.optimizers.Nesterov
__init__
--------
.. automethod:: dragon.updaters.Nesterov.__init__
.. automethod:: dragon.optimizers.Nesterov.__init__
Methods
-------
apply_gradients
################
.. automethod:: dragon.updaters.Updater.apply_gradients
.. automethod:: dragon.optimizers.Optimizer.apply_gradients
:noindex:
.. raw:: html
<style>
h1:before {
content: "dragon.updaters.";
content: "dragon.optimizers.";
color: #103d3e;
}
</style>
Optimizer
=========
.. autoclass:: dragon.optimizers.Optimizer
__init__
--------
.. automethod:: dragon.optimizers.Optimizer.__init__
Methods
-------
apply_gradients
################
.. automethod:: dragon.optimizers.Optimizer.apply_gradients
.. raw:: html
<style>
h1:before {
content: "dragon.optimizers.";
color: #103d3e;
}
</style>
RMSProp
RMSprop
=======
.. autoclass:: dragon.updaters.RMSProp
.. autoclass:: dragon.optimizers.RMSprop
__init__
--------
.. automethod:: dragon.updaters.RMSProp.__init__
.. automethod:: dragon.optimizers.RMSprop.__init__
Methods
-------
apply_gradients
################
.. automethod:: dragon.updaters.Updater.apply_gradients
.. automethod:: dragon.optimizers.Optimizer.apply_gradients
:noindex:
.. raw:: html
<style>
h1:before {
content: "dragon.updaters.";
content: "dragon.optimizers.";
color: #103d3e;
}
</style>
SGD
===
.. autoclass:: dragon.updaters.SGD
.. autoclass:: dragon.optimizers.SGD
__init__
--------
.. automethod:: dragon.updaters.SGD.__init__
.. automethod:: dragon.optimizers.SGD.__init__
Methods
-------
apply_gradients
################
.. automethod:: dragon.updaters.Updater.apply_gradients
.. automethod:: dragon.optimizers.Optimizer.apply_gradients
:noindex:
.. raw:: html
<style>
h1:before {
content: "dragon.updaters.";
content: "dragon.optimizers.";
color: #103d3e;
}
</style>
......@@ -14,15 +14,15 @@ For using it, import as follows:
However, it will not help you much because you do not want to learn it.
We have extended it with following programming styles:
To address this, we have designed diverse programming styles for you:
Dragon
######
*Dragon* takes a very light-weight programming style.
*Dragon* is designed as a light-weight but professional style.
Our goal is to reduce unnecessary structures or interfaces. Therefore,
in addition to feed or fetch, the last thing is designing a function.
Native interfaces are encouraged to manipulate the backend engine
to perform the computation flexibly with data feeding or fetching.
This style involves the following components:
......@@ -38,15 +38,15 @@ Dragon
* `dragon.math <dragon/math.html>`_
* `dragon.metrics <dragon/metrics.html>`_
* `dragon.nn <dragon/nn.html>`_
* `dragon.optimizers <dragon/optimizers.html>`_
* `dragon.random <dragon/random.html>`_
* `dragon.updaters <dragon/updaters.html>`_
* `dragon.vision <dragon/vision.html>`_
* `dragon.workspace <dragon/workspace.html>`_
* `dragon.vision <dragon/vision.html>`_
Caffe
#####
*Caffe* is one of the most famous deep learning frameworks for Computer Vision.
*Caffe* is one of the most famous frameworks for vision.
Our work is very different from the official Python wrappers, a.k.a.
the *PyCaffe*, which comes from the exports of *BoostPython*
......@@ -102,7 +102,7 @@ PyTorch
*PyTorch* provides straightforward operations for research prototyping.
To bridge it, our *JIT* traces and dispatches the expressions,
To bridge it, our *JIT* traces and dispatches the operations,
as well as the rewriting of *GC* (Garbage Collection) to reuse
the memories and operators in turn.
......@@ -168,52 +168,52 @@ Modules
.. only:: html
`Module autograph <dragon/autograph.html>`_
: Public API for ``dragon.autograph`` namespace.
: Native API for ``dragon.autograph`` namespace.
`Module bitwise <dragon/bitwise.html>`_
: Public API for ``dragon.bitwise`` namespace.
: Native API for ``dragon.bitwise`` namespace.
`Module cuda <dragon/cuda.html>`_
: Public API for ``dragon.cuda`` namespace.
: Native API for ``dragon.cuda`` namespace.
`Module distributed <dragon/distributed.html>`_
: Public API for ``dragon.distributed`` namespace.
: Native API for ``dragon.distributed`` namespace.
`Module dlpack <dragon/dlpack.html>`_
: Public API for ``dragon.dlpack`` namespace.
: Native API for ``dragon.dlpack`` namespace.
`Module io <dragon/io.html>`_
: Public API for ``dragon.io`` namespace.
: Native API for ``dragon.io`` namespace.
`Module logging <dragon/logging.html>`_
: Public API for ``dragon.logging`` namespace.
: Native API for ``dragon.logging`` namespace.
`Module losses <dragon/losses.html>`_
: Public API for ``dragon.losses`` namespace.
: Native API for ``dragon.losses`` namespace.
`Module math <dragon/math.html>`_
: Public API for ``dragon.math`` namespace.
: Native API for ``dragon.math`` namespace.
`Module metrics <dragon/metrics.html>`_
: Public API for ``dragon.metrics`` namespace.
: Native API for ``dragon.metrics`` namespace.
`Module nn <dragon/nn.html>`_
: Public API for ``dragon.nn`` namespace.
: Native API for ``dragon.nn`` namespace.
`Module optimizers <dragon/optimizers.html>`_
: Native API for ``dragon.optimizers`` namespace.
`Module random <dragon/random.html>`_
: Public API for ``dragon.random`` namespace.
: Native API for ``dragon.random`` namespace.
`Module updaters <dragon/updaters.html>`_
: Public API for ``dragon.updaters`` namespace.
`Module workspace <dragon/workspace.html>`_
: Native API for ``dragon.workspace`` namespace.
`Module vision <dragon/vision.html>`_
: Public API for ``dragon.vision`` namespace.
: Native API for ``dragon.vision`` namespace.
`Module workspace <dragon/workspace.html>`_
: Public API for ``dragon.workspace`` namespace.
`Module workspace <dragon/workspace.html>`_
: Public API for ``dragon.workspace`` namespace.
: Native API for ``dragon.workspace`` namespace.
`Module vm.caffe <caffe.html>`_
: Virtual API for ``caffe`` namespace.
......@@ -317,10 +317,10 @@ Modules
dragon/math
dragon/metrics
dragon/nn
dragon/optimizers
dragon/random
dragon/updaters
dragon/vision
dragon/workspace
dragon/vision
caffe
caffe/layers
dali
......
......@@ -30,9 +30,6 @@ vm.torch
`abs(...) <torch/abs.html>`_
: Compute the absolute value of input.
`accumulate(...) <torch/accumulate.html>`_
: Compute the element-wise accumulation from input to output.
`add(...) <torch/add.html>`_
: Compute the element-wise addition.
......@@ -45,6 +42,9 @@ vm.torch
`argmin(...) <torch/argmin.html>`_
: Return the indices of minimum elements along the given axis.
`axpby(...) <torch/axpby.html>`_
: Compute the element-wise addition from input to output, i.e., y = alpha * x + beta * y.
`bitwise_not(...) <torch/bitwise_not.html>`_
: Compute the element-wise NOT bitwise operation.
......@@ -254,11 +254,11 @@ vm.torch
:hidden:
torch/abs
torch/accumulate
torch/add
torch/arange
torch/argmax
torch/argmin
torch/axpby
torch/bitwise_not
torch/bitwise_xor
torch/cat
......
accumulate
==========
axpby
=====
.. autofunction:: dragon.vm.torch.accumulate
.. autofunction:: dragon.vm.torch.axpby
.. raw:: html
......
......@@ -50,18 +50,18 @@ class CUDAObject {
*/
if (stream) cudaStreamDestroy(stream);
}
for (auto& e : cublas_handles_[i])
if (e) {
CUBLAS_CHECK(cublasDestroy_v2(e));
for (auto& handle : cublas_handles_[i])
if (handle) {
CUBLAS_CHECK(cublasDestroy(handle));
}
#ifdef USE_CUDNN
for (auto& e : cudnn_handles_[i])
if (e) {
CUDNN_CHECK(cudnnDestroy(e));
for (auto& handle : cudnn_handles_[i])
if (handle) {
CUDNN_CHECK(cudnnDestroy(handle));
}
#endif
#ifdef USE_NCCL
for (auto& e : nccl_comms_[i]) {
for (auto& comm : nccl_comms_[i]) {
/*!
* Temporarily disable the comm destroying,
* to avoid an unhandled error.
......@@ -74,17 +74,18 @@ class CUDAObject {
/*! \brief Return the specified cublas handle */
cublasHandle_t cublas_handle(int device_id, int stream_id) {
auto& handles = cublas_handles_[device_id];
if (handles.size() <= (unsigned)stream_id)
if (handles.size() <= (unsigned)stream_id) {
handles.resize(stream_id + 1, nullptr);
}
if (!handles[stream_id]) {
CUDADeviceGuard guard(device_id);
CUBLAS_CHECK(cublasCreate_v2(&handles[stream_id]));
CUBLAS_CHECK(
cublasSetStream_v2(handles[stream_id], stream(device_id, stream_id)));
CUBLAS_CHECK(cublasCreate(&handles[stream_id]));
auto& handle = handles[stream_id];
CUBLAS_CHECK(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasSetStream(handle, stream(device_id, stream_id)));
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
CUBLAS_CHECK(
cublasSetMathMode(handles[stream_id], CUBLAS_TENSOR_OP_MATH));
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
}
#endif
}
......@@ -95,13 +96,14 @@ class CUDAObject {
#ifdef USE_CUDNN
cudnnHandle_t cudnn_handle(int device_id, int stream_id) {
auto& handles = cudnn_handles_[device_id];
if (handles.size() <= (unsigned)stream_id)
if (handles.size() <= (unsigned)stream_id) {
handles.resize(stream_id + 1, nullptr);
}
if (!handles[stream_id]) {
CUDADeviceGuard guard(device_id);
CUDNN_CHECK(cudnnCreate(&handles[stream_id]));
CUDNN_CHECK(
cudnnSetStream(handles[stream_id], stream(device_id, stream_id)));
auto& handle = handles[stream_id];
CUDNN_CHECK(cudnnSetStream(handle, stream(device_id, stream_id)));
}
return handles[stream_id];
}
......@@ -144,7 +146,7 @@ class CUDAObject {
if (!streams[stream_id]) {
CUDADeviceGuard guard(device_id);
unsigned int flags =
!stream_id ? cudaStreamDefault : cudaStreamNonBlocking;
stream_id == 0 ? cudaStreamDefault : cudaStreamNonBlocking;
CUDA_CHECK(cudaStreamCreateWithFlags(&streams[stream_id], flags));
}
return streams[stream_id];
......
......@@ -80,7 +80,7 @@ Tensor* OperatorBase::Output(int i, const vec32_t& inputs) {
}
Tensor* OperatorBase::Buffer(const string& name) {
return ws()->CreateTensor(unique_name(name));
return ws()->CreateTensor("/share/buffer/" + handle_ + "/" + name);
}
string OperatorBase::TypeString(const Tensor& tensor, const Set<string>& types)
......
......@@ -133,11 +133,6 @@ class DRAGON_API OperatorBase {
return handle_;
}
/*! \brief Return the unique name in this operator */
const string unique_name(const string& name) const {
return "/mnt/" + handle_ + "/" + name;
}
/*! \brief Return the stored def */
const OperatorDef& def() const {
return def_;
......@@ -268,7 +263,6 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
using OperatorBase::dtype; \
using OperatorBase::data_format; \
using OperatorBase::handle; \
using OperatorBase::unique_name; \
using OperatorBase::def; \
using OperatorBase::ws
......@@ -277,17 +271,18 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
using Operator<Context>::allow_run; \
using Operator<Context>::ctx
#define STORE_INPUT_SPEC(i) \
*(ws()->CreateTensor(unique_name("Input[" + std::to_string(i) + "]")) \
->ReshapeLike(Input(i)) \
#define STORE_INPUT_SPEC(i) \
*(Buffer("X_spec:" + std::to_string(i)) \
->ReshapeLike(Input(i)) \
->set_meta(Input(i).meta()))
#define RESTORE_INPUT_SPEC(i) \
*(ws()->GetTensor(unique_name("Input[" + std::to_string(i) + "]")))
*(ws()->GetTensor( \
"/share/buffer/" + handle() + "/X_spec:" + std::to_string(i)))
/* Dispatchers */
#define XIsType(x, type) x.template IsType<type>()
#define XIsType(X, type) X.template IsType<type>()
template <typename... Types>
struct TensorTypes {};
......
......@@ -53,14 +53,11 @@ __global__ void _EluGrad(
const T* y,
T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = dy[i] *
(
#if __CUDA_ARCH__ >= 350
__ldg(y + i) > T(0) ? T(1) : alpha + __ldg(y + i)
dx[i] = dy[i] * (__ldg(y + i) > T(0) ? T(1) : alpha + __ldg(y + i));
#else
y[i] > T(0) ? T(1) : (alpha + y[i])
dx[i] = dy[i] * (y[i] > T(0) ? T(1) : (alpha + y[i]));
#endif
);
}
}
......
......@@ -14,28 +14,28 @@ void _Softmax(
const int inner_dim,
const T* x,
T* y) {
int row_ofs, col_ofs, yi;
int row_offset, col_offset, yi;
auto x_stride = axis_dim * inner_dim;
for (int i = 0; i < outer_dim; ++i) {
row_ofs = i * axis_dim * inner_dim;
row_offset = i * axis_dim * inner_dim;
for (int j = 0; j < inner_dim; ++j) {
col_ofs = row_ofs + j;
T val = x[col_ofs];
col_offset = row_offset + j;
T val = x[col_offset];
for (int k = 1; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim;
yi = col_offset + k * inner_dim;
val = std::max(val, x[yi]);
}
for (int k = 0; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim;
yi = col_offset + k * inner_dim;
y[yi] = std::exp(x[yi] - val);
}
val = y[col_ofs];
val = y[col_offset];
for (int k = 1; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim;
yi = col_offset + k * inner_dim;
val += y[yi];
}
for (int k = 0; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim;
yi = col_offset + k * inner_dim;
y[yi] /= val;
}
}
......@@ -60,19 +60,19 @@ void _SoftmaxGrad(
const T* dy,
const T* y,
T* dx) {
int row_ofs, col_ofs, yi;
int row_offset, col_offset, yi;
auto x_stride = axis_dim * inner_dim;
for (int i = 0; i < outer_dim; ++i) {
row_ofs = i * axis_dim * inner_dim;
row_offset = i * axis_dim * inner_dim;
for (int j = 0; j < inner_dim; ++j) {
col_ofs = row_ofs + j;
T val = dy[col_ofs] * y[col_ofs];
col_offset = row_offset + j;
T val = dy[col_offset] * y[col_offset];
for (int k = 1; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim;
yi = col_offset + k * inner_dim;
val += dy[yi] * y[yi];
}
for (int k = 0; k < axis_dim; ++k) {
yi = col_ofs + k * inner_dim;
yi = col_offset + k * inner_dim;
dx[yi] = (dy[yi] - val) * y[yi];
}
}
......
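The forward softmax kernel in this file walks an ``(outer_dim, axis_dim, inner_dim)`` layout and subtracts the per-position maximum before exponentiating. An equivalent NumPy reference, given only as a cross-check sketch:

```python
import numpy as np

def softmax_reference(x, outer_dim, axis_dim, inner_dim):
    # Numerically stable softmax along the middle axis of an
    # (outer, axis, inner) view of the flat input.
    x = x.reshape(outer_dim, axis_dim, inner_dim)
    x = x - x.max(axis=1, keepdims=True)
    e = np.exp(x)
    return (e / e.sum(axis=1, keepdims=True)).reshape(-1)

x = np.random.randn(2 * 5 * 3).astype('float32')
y = softmax_reference(x, 2, 5, 3)
assert np.allclose(y.reshape(2, 5, 3).sum(axis=1), 1.0, atol=1e-5)
```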
......@@ -53,11 +53,11 @@ void _CumSumReverse(
CPUContext* ctx) {
const int kStart = axis_dim - 1;
for (int n = 0; n < outer_dim; ++n) {
const int n_ofs = n * axis_dim;
const int n_offset = n * axis_dim;
for (int m = kStart; m >= 0; --m) {
const int nm_ofs = (n_ofs + m) * inner_dim;
const int nm_offset = (n_offset + m) * inner_dim;
for (int k = 0; k < inner_dim; ++k) {
const int i = nm_ofs + k;
const int i = nm_offset + k;
if (m < kStart) {
const int j = i + inner_dim;
y[i] = y[j] + x[exclusive ? j : i];
......
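For reference, the reverse cumulative sum computed by ``_CumSumReverse`` can be written in NumPy as below. The exclusive variant is an assumption based on the ``x[exclusive ? j : i]`` indexing and is shown for illustration only.

```python
import numpy as np

def cumsum_reverse_reference(x, axis, exclusive=False):
    # Reverse (right-to-left) cumulative sum; the exclusive form shifts
    # the result so each position excludes its own element.
    y = np.flip(np.cumsum(np.flip(x, axis), axis=axis), axis)
    if exclusive:
        y = y - x
    return y

x = np.arange(1, 7, dtype='float32').reshape(2, 3)
print(cumsum_reverse_reference(x, axis=1))                  # [[ 6.  5.  3.] [15. 11.  6.]]
print(cumsum_reverse_reference(x, axis=1, exclusive=True))  # [[ 5.  3.  0.] [11.  6.  0.]]
```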
......@@ -25,9 +25,9 @@ void _SetEye(const int n, const int m, const int k, T* y) {
const int n, const int m, const int k, T* y, CPUContext* ctx) { \
math::Set(n* m, cast::to<T>(0.f), y, ctx); \
if (k > 0) { \
_SetEye(n - k, m, k, y); \
if (m - k > 0) _SetEye(m - k, m, k, y); \
} else { \
_SetEye(n + k, m, 0, y - k * m); \
if (n + k > 0) _SetEye(n + k, m, 0, y - k * m); \
} \
}
......
......@@ -20,9 +20,9 @@ __global__ void _SetEye(const int n, const int m, const int k, T* y) {
template <>
__global__ void _SetEye<half>(const int n, const int m, const int k, half* y) {
const half kZero = __float2half(1.f);
const half kOne = __float2half(1.f);
CUDA_1D_KERNEL_LOOP(i, n) {
y[i * m + k + i] = kZero;
y[i * m + k + i] = kOne;
}
}
......@@ -39,26 +39,34 @@ void Eye<float16, CUDAContext>(
CUDAContext* ctx) {
math::Set(n * m, cast::to<float16>(0.f), y, ctx);
if (k > 0) {
_SetEye<<<CUDA_BLOCKS(n - k), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n - k, m, k, reinterpret_cast<half*>(y));
if (m - k > 0) {
_SetEye<<<CUDA_BLOCKS(m - k), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
m - k, m, k, reinterpret_cast<half*>(y));
}
} else {
_SetEye<<<CUDA_BLOCKS(n + k), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n + k, m, 0, reinterpret_cast<half*>(y - k * m));
if (n + k > 0) {
_SetEye<<<CUDA_BLOCKS(n + k), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n + k, m, 0, reinterpret_cast<half*>(y - k * m));
}
}
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Eye<T, CUDAContext>( \
const int n, const int m, const int k, T* y, CUDAContext* ctx) { \
math::Set(n* m, T(0), y, ctx); \
if (k > 0) { \
_SetEye<<<CUDA_BLOCKS(n - k), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n - k, m, k, y); \
} else { \
_SetEye<<<CUDA_BLOCKS(n + k), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n + k, m, 0, y - k * m); \
} \
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Eye<T, CUDAContext>( \
const int n, const int m, const int k, T* y, CUDAContext* ctx) { \
math::Set(n* m, T(0), y, ctx); \
if (k > 0) { \
if (m - k > 0) { \
_SetEye<<<CUDA_BLOCKS(m - k), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
m - k, m, k, y); \
} \
} else { \
if (n + k > 0) { \
_SetEye<<<CUDA_BLOCKS(n + k), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n + k, m, 0, y - k * m); \
} \
} \
}
DEFINE_KERNEL_LAUNCHER(bool);
......
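For context on the ``Eye`` guards added above: when the diagonal offset ``k`` reaches past the matrix extent there is nothing to set, so the kernel launch is skipped. A NumPy cross-check of the diagonal element count, written only for illustration (the helper below is not part of the commit):

```python
import numpy as np

def eye_diag_count(n, m, k):
    # Number of elements on diagonal ``k`` of an ``n x m`` matrix; the added
    # guards skip the kernel launch when this count drops to zero.
    return max(0, min(n, m - k)) if k >= 0 else max(0, min(m, n + k))

for n, m, k in [(4, 6, 2), (4, 6, 7), (4, 6, -5), (6, 4, -2)]:
    assert eye_diag_count(n, m, k) == int(np.count_nonzero(np.eye(n, m, k)))
```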
......@@ -35,21 +35,22 @@ void _BroadcastLossGrad<float16>(
} // namespace
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ReduceLoss<T, CPUContext>( \
const int count, \
const int num_masks, \
const float normalizer, \
const T* x, \
const int* mask, \
T* y, \
CPUContext* ctx) { \
float inv_scale = std::max( \
1e-5F, \
num_masks > 0 ? (float)math::Sum(num_masks, 1.f, mask, ctx) \
: normalizer); \
y[0] = math::Sum(count, 1.f / inv_scale, x, ctx); \
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ReduceLoss<T, CPUContext>( \
const int count, \
const int num_masks, \
const float normalizer, \
const T* x, \
const int* mask, \
T* y, \
CPUContext* ctx) { \
float inv_scale = std::max( \
1e-5F, \
num_masks > 0 && normalizer < 0.f \
? (float)math::Sum(num_masks, 1.f, mask, ctx) \
: normalizer); \
y[0] = math::Sum(count, 1.f / inv_scale, x, ctx); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
......@@ -64,8 +65,9 @@ void _BroadcastLossGrad<float16>(
CPUContext* ctx) { \
float inv_scale = std::max( \
1e-5F, \
num_masks > 0 ? (float)math::Sum(num_masks, 1.f, mask, ctx) \
: normalizer); \
num_masks > 0 && normalizer < 0.f \
? (float)math::Sum(num_masks, 1.f, mask, ctx) \
: normalizer); \
math::Scale(count, cast::to<float>(dy[0]) / inv_scale, dx, dx, ctx); \
} \
template <> \
......
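The condition change above (``num_masks > 0 && normalizer < 0.f``) makes the mask sum act as the divisor only when no explicit normalizer is supplied, with a negative value serving as the "auto" sentinel. A NumPy sketch of that rule, written only for illustration:

```python
import numpy as np

def reduce_loss_reference(losses, mask=None, normalizer=-1.0):
    # Use the mask sum as the divisor only when no explicit normalizer is
    # given (negative means "auto"), matching the new condition
    # ``num_masks > 0 && normalizer < 0``.
    if mask is not None and normalizer < 0:
        scale = max(float(mask.sum()), 1e-5)
    else:
        scale = max(float(normalizer), 1e-5)
    return losses.sum() / scale

losses = np.array([0.5, 1.5, 2.0], dtype='float32')
mask = np.array([1, 0, 1], dtype='int32')
print(reduce_loss_reference(losses, mask))                  # divides by the mask sum (2)
print(reduce_loss_reference(losses, mask, normalizer=3.0))  # divides by the explicit 3
```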
......@@ -152,15 +152,15 @@ __global__ void _ReduceLossGradWithMask<half>(
template <typename T>
__global__ void _BroadcastLossGrad(
const int nthreads,
const int rows,
const int cols,
const int dim1,
const int dim2,
const T* dy,
T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
dx[i] *= __ldg(dy + (i / rows) * cols + (i % cols));
dx[i] *= __ldg(dy + (i / dim1) * dim2 + (i % dim2));
#else
dx[i] *= dy[(i / rows) * cols + (i % cols)];
dx[i] *= dy[(i / dim1) * dim2 + (i % dim2)];
#endif
}
}
......@@ -168,18 +168,18 @@ __global__ void _BroadcastLossGrad(
template <>
__global__ void _BroadcastLossGrad<half>(
const int nthreads,
const int rows,
const int cols,
const int dim1,
const int dim2,
const half* dy,
half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
dx[i] = __float2half(
__half2float(dx[i]) *
__half2float(__ldg(dy + (i / rows) * cols + (i % cols))));
__half2float(__ldg(dy + (i / dim1) * dim2 + (i % dim2))));
#else
dx[i] = __float2half(
__half2float(dx[i]) * __half2float(dy[(i / rows) * cols + (i % cols)]));
__half2float(dx[i]) * __half2float(dy[(i / dim1) * dim2 + (i % dim2)]));
#endif
}
}
......@@ -197,7 +197,7 @@ void ReduceLoss<float16, CUDAContext>(
const int* mask,
float16* y,
CUDAContext* ctx) {
if (num_masks > 0) {
if (num_masks > 0 && normalizer < 0.f) {
_ReduceLossWithMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
num_masks,
reinterpret_cast<const half*>(x),
......@@ -221,7 +221,7 @@ void ReduceLossGrad<float16, CUDAContext>(
const int* mask,
float16* dx,
CUDAContext* ctx) {
if (num_masks > 0) {
if (num_masks > 0 && normalizer < 0.f) {
_ReduceMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
num_masks, const_cast<int*>(mask));
_ReduceLossGradWithMask<<<
......@@ -254,16 +254,15 @@ void BroadcastLossGrad<float16, CUDAContext>(
const float16* dy,
float16* dx,
CUDAContext* ctx) {
auto rows = outer_dim * axis_dim, cols = inner_dim;
auto nthreads = rows * cols;
auto nthreads = outer_dim * axis_dim * inner_dim;
_BroadcastLossGrad<<<
CUDA_BLOCKS(nthreads),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
nthreads,
rows,
cols,
axis_dim * inner_dim,
inner_dim,
reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx));
} // BroadcastLossGrad
......@@ -278,7 +277,7 @@ void BroadcastLossGrad<float16, CUDAContext>(
const int* mask, \
T* y, \
CUDAContext* ctx) { \
if (num_masks > 0) { \
if (num_masks > 0 && normalizer < 0.f) { \
_ReduceLossWithMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
num_masks, x, mask, y); \
} else { \
......@@ -297,7 +296,7 @@ void BroadcastLossGrad<float16, CUDAContext>(
const int* mask, \
T* dx, \
CUDAContext* ctx) { \
if (num_masks > 0) { \
if (num_masks > 0 && normalizer < 0.f) { \
_ReduceMask<<<1, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
num_masks, const_cast<int*>(mask)); \
_ReduceLossGradWithMask<<< \
......@@ -322,13 +321,13 @@ void BroadcastLossGrad<float16, CUDAContext>(
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
auto rows = outer_dim * axis_dim, cols = inner_dim; \
auto nthreads = rows * cols; \
auto nthreads = outer_dim * axis_dim * inner_dim; \
_BroadcastLossGrad<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(nthreads, rows, cols, dy, dx); \
ctx->cuda_stream()>>>( \
nthreads, axis_dim * inner_dim, inner_dim, dy, dx); \
}
DEFINE_KERNEL_LAUNCHER(float);
......
......@@ -7,31 +7,31 @@ namespace dragon {
namespace kernel {
template <>
void MixedPrecL2Decay<float16, CPUContext>(
void MixedPrecL2Penalty<float16, CPUContext>(
const int count,
const float alpha,
const float16* w,
const float16* x,
float* dx,
CPUContext* ctx) {
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
dx[i] += (cast::to<float>(w[i]) * alpha);
dx[i] += (cast::to<float>(x[i]) * alpha);
}
}
template <>
void MixedPrecUpdate<float16, CPUContext>(
const int count,
const float* updates,
float16* w,
const float* dx,
float16* x,
CPUContext* ctx) {
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
w[i] = cast::to<float16>(cast::to<float>(w[i]) - updates[i]);
x[i] = cast::to<float16>(cast::to<float>(x[i]) - dx[i]);
}
}
......
......@@ -9,24 +9,19 @@ namespace kernel {
namespace {
__global__ void _MixedPrecL2DecayHalf(
__global__ void _MixedPrecL2Penalty(
const int nthreads,
const float alpha,
const half* w,
const half* x,
float* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] += __half2float(w[i]) * alpha;
#endif
dx[i] += __half2float(x[i]) * alpha;
}
}
__global__ void
_MixedPrecUpdateHalf(const int nthreads, const float* updates, half* w) {
__global__ void _MixedPrecUpdate(const int nthreads, const float* dx, half* x) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
w[i] = __float2half(__half2float(w[i]) - updates[i]);
#endif
x[i] = __float2half(__half2float(x[i]) - dx[i]);
}
}
......@@ -35,30 +30,27 @@ _MixedPrecUpdateHalf(const int nthreads, const float* updates, half* w) {
/* ------------------- Launcher Separator ------------------- */
template <>
void MixedPrecL2Decay<float16, CUDAContext>(
void MixedPrecL2Penalty<float16, CUDAContext>(
const int count,
const float alpha,
const float16* w,
const float16* x,
float* dx,
CUDAContext* ctx) {
_MixedPrecL2DecayHalf<<<
_MixedPrecL2Penalty<<<
CUDA_BLOCKS(count),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(count, alpha, reinterpret_cast<const half*>(w), dx);
ctx->cuda_stream()>>>(count, alpha, reinterpret_cast<const half*>(x), dx);
}
template <>
void MixedPrecUpdate<float16, CUDAContext>(
const int count,
const float* updates,
float16* w,
const float* dx,
float16* x,
CUDAContext* ctx) {
_MixedPrecUpdateHalf<<<
CUDA_BLOCKS(count),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(count, updates, reinterpret_cast<half*>(w));
_MixedPrecUpdate<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count, dx, reinterpret_cast<half*>(x));
}
} // namespace kernel
......
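A rough NumPy sketch of the renamed mixed-precision routines above: ``MixedPrecL2Penalty`` accumulates an L2 term from the float16 weights into the float32 gradient, and ``MixedPrecUpdate`` applies the update and casts back to float16. The helper below is illustrative only, not the project API.

```python
import numpy as np

def mixed_prec_step(x_fp16, dx_fp32, l2_alpha):
    # L2 penalty accumulates into the float32 gradient from the float16
    # weights, then the update is applied and the result is cast back.
    dx_fp32 = dx_fp32 + x_fp16.astype('float32') * l2_alpha
    x_new = (x_fp16.astype('float32') - dx_fp32).astype('float16')
    return x_new, dx_fp32

x = np.random.randn(8).astype('float16')
dx = (np.random.randn(8) * 0.01).astype('float32')
x_new, dx_new = mixed_prec_step(x, dx, l2_alpha=1e-4)
```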
......@@ -116,15 +116,13 @@ __global__ void _AvgPool2dGradNCHW(
const T* dy,
T* dx) {
CUDA_1D_KERNEL_LOOP(xi, nthreads) {
const int w = xi % W;
const int h = (xi / W) % H;
const int w = xi % W + pad_w;
const int h = (xi / W) % H + pad_h;
const int c = (xi / W / H) % C;
const int n = xi / W / H / C;
const int phstart =
(h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1;
const int pwstart =
(w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
const int phend = min(h / stride_h + 1, out_h);
const int pwend = min(w / stride_w + 1, out_w);
......@@ -164,14 +162,12 @@ __global__ void _AvgPool2dGradNHWC(
T* dx) {
CUDA_1D_KERNEL_LOOP(xi, nthreads) {
const int c = xi % C;
const int w = (xi / C) % W;
const int h = (xi / C / W) % H;
const int w = (xi / C) % W + pad_w;
const int h = (xi / C / W) % H + pad_h;
const int n = xi / C / W / H;
const int phstart =
(h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1;
const int pwstart =
(w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
const int phend = min(h / stride_h + 1, out_h);
const int pwend = min(w / stride_w + 1, out_w);
......
......@@ -30,8 +30,8 @@ void _Im2Col2dNCHW(
const T* im,
T* col) {
int ih, iw;
const int im_ofs = H * W;
for (int c = 0; c < C; ++c, im += im_ofs) {
const int im_offset = H * W;
for (int c = 0; c < C; ++c, im += im_offset) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
ih = -pad_h + kh * dilation_h;
......@@ -117,8 +117,8 @@ void _Col2Im2dNCHW(
const T* col,
T* im) {
int ih, iw;
const int im_ofs = H * W;
for (int c = 0; c < C; ++c, im += im_ofs) {
const int im_offset = H * W;
for (int c = 0; c < C; ++c, im += im_offset) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
ih = -pad_h + kh * dilation_h;
......
......@@ -27,13 +27,13 @@ void _DepthwiseConv2dNCHW(
T* y) {
T sum_val;
int ih, iw, xi, wi;
int yc_ofs, xc_start, yc_start;
int yc_offset, xc_start, yc_start;
int ih_start, yh_start, iw_start;
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
yc_ofs = n * C + c;
xc_start = yc_ofs * H * W;
yc_start = yc_ofs * out_h;
yc_offset = n * C + c;
xc_start = yc_offset * H * W;
yc_start = yc_offset * out_h;
for (int oh = 0; oh < out_h; ++oh) {
ih_start = oh * stride_h - pad_h;
yh_start = (yc_start + oh) * out_w;
......
......@@ -46,7 +46,7 @@ void _ResizeLinearNCHW(
std::array<int, 4> idx = {0, 0, 0, 0};
std::array<int, 4> dims = {N, C, out_h, out_w};
float h_in, w_in, u, v, t, b, tl, tr, bl, br;
int ti, bi, li, ri, ofs, h_max = H - 1, w_max = W - 1;
int ti, bi, li, ri, offset, h_max = H - 1, w_max = W - 1;
for (int i = 0; i < count; ++i) {
h_in = TransformCoordinate(idx[2], scale_h, align_corners);
w_in = TransformCoordinate(idx[3], scale_w, align_corners);
......@@ -54,11 +54,11 @@ void _ResizeLinearNCHW(
bi = (h_in < h_max) ? std::ceil(h_in) : h_max;
ri = (w_in < w_max) ? std::ceil(w_in) : w_max;
v = h_in - ti, u = w_in - li;
ofs = (idx[0] * C + idx[1]) * H;
tl = (float)x[(ofs + ti) * W + li];
tr = (float)x[(ofs + ti) * W + ri];
bl = (float)x[(ofs + bi) * W + li];
br = (float)x[(ofs + bi) * W + ri];
offset = (idx[0] * C + idx[1]) * H;
tl = (float)x[(offset + ti) * W + li];
tr = (float)x[(offset + ti) * W + ri];
bl = (float)x[(offset + bi) * W + li];
br = (float)x[(offset + bi) * W + ri];
t = tl + (tr - tl) * u;
b = bl + (br - bl) * u;
y[i] = static_cast<T>(t + (b - t) * v);
......@@ -83,7 +83,7 @@ void _ResizeLinearNHWC(
std::array<int, 4> idx = {0, 0, 0, 0};
std::array<int, 4> dims = {N, out_h, out_w, C};
float h_in, w_in, u, v, t, b, tl, tr, bl, br;
int ti, bi, li, ri, ofs, h_max = H - 1, w_max = W - 1;
int ti, bi, li, ri, offset, h_max = H - 1, w_max = W - 1;
for (int i = 0; i < count; ++i) {
h_in = TransformCoordinate(idx[1], scale_h, align_corners);
w_in = TransformCoordinate(idx[2], scale_w, align_corners);
......@@ -91,11 +91,11 @@ void _ResizeLinearNHWC(
bi = (h_in < h_max) ? std::ceil(h_in) : h_max;
ri = (w_in < w_max) ? std::ceil(w_in) : w_max;
v = h_in - ti, u = w_in - li;
ofs = idx[0] * H;
tl = (float)x[((ofs + ti) * W + li) * C + idx[3]];
tr = (float)x[((ofs + ti) * W + ri) * C + idx[3]];
bl = (float)x[((ofs + bi) * W + li) * C + idx[3]];
br = (float)x[((ofs + bi) * W + ri) * C + idx[3]];
offset = idx[0] * H;
tl = (float)x[((offset + ti) * W + li) * C + idx[3]];
tr = (float)x[((offset + ti) * W + ri) * C + idx[3]];
bl = (float)x[((offset + bi) * W + li) * C + idx[3]];
br = (float)x[((offset + bi) * W + ri) * C + idx[3]];
t = tl + (tr - tl) * u;
b = bl + (br - bl) * u;
y[i] = static_cast<T>(t + (b - t) * v);
......@@ -120,7 +120,7 @@ void _ResizeLinearGradNCHW(
std::array<int, 4> idx = {0, 0, 0, 0};
std::array<int, 4> dims = {N, C, out_h, out_w};
float h_in, w_in, u, v, dt, db, tl, tr, bl, br;
int ti, bi, li, ri, ofs, h_max = H - 1, w_max = W - 1;
int ti, bi, li, ri, offset, h_max = H - 1, w_max = W - 1;
for (int i = 0; i < count; ++i) {
h_in = TransformCoordinate(idx[2], scale_h, align_corners);
w_in = TransformCoordinate(idx[3], scale_w, align_corners);
......@@ -128,13 +128,13 @@ void _ResizeLinearGradNCHW(
bi = (h_in < h_max) ? std::ceil(h_in) : h_max;
ri = (w_in < w_max) ? std::ceil(w_in) : w_max;
v = h_in - ti, u = w_in - li;
ofs = (idx[0] * C + idx[1]) * H;
offset = (idx[0] * C + idx[1]) * H;
dt = (1.f - v) * static_cast<float>(dy[i]);
db = v * static_cast<float>(dy[i]);
dx[(ofs + ti) * W + li] += (1.f - u) * dt; // tl
dx[(ofs + ti) * W + ri] += u * dt; // tr
dx[(ofs + bi) * W + li] += (1.f - u) * db; // bl
dx[(ofs + bi) * W + ri] += u * db; // br
dx[(offset + ti) * W + li] += (1.f - u) * dt; // tl
dx[(offset + ti) * W + ri] += u * dt; // tr
dx[(offset + bi) * W + li] += (1.f - u) * db; // bl
dx[(offset + bi) * W + ri] += u * db; // br
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -156,7 +156,7 @@ void _ResizeLinearGradNHWC(
std::array<int, 4> idx = {0, 0, 0, 0};
std::array<int, 4> dims = {N, out_h, out_w, C};
float h_in, w_in, u, v, dt, db, tl, tr, bl, br;
int ti, bi, li, ri, ofs, h_max = H - 1, w_max = W - 1;
int ti, bi, li, ri, offset, h_max = H - 1, w_max = W - 1;
for (int i = 0; i < count; ++i) {
h_in = TransformCoordinate(idx[1], scale_h, align_corners);
w_in = TransformCoordinate(idx[2], scale_w, align_corners);
......@@ -164,13 +164,13 @@ void _ResizeLinearGradNHWC(
bi = (h_in < h_max) ? std::ceil(h_in) : h_max;
ri = (w_in < w_max) ? std::ceil(w_in) : w_max;
v = h_in - ti, u = w_in - li;
ofs = idx[0] * H;
offset = idx[0] * H;
dt = (1.f - v) * static_cast<float>(dy[i]);
db = v * static_cast<float>(dy[i]);
dx[((ofs + ti) * W + li) * C + idx[3]] += (1.f - u) * dt; // tl
dx[((ofs + ti) * W + ri) * C + idx[3]] += u * dt; // tr
dx[((ofs + bi) * W + li) * C + idx[3]] += (1.f - u) * db; // bl
dx[((ofs + bi) * W + ri) * C + idx[3]] += u * db; // br
dx[((offset + ti) * W + li) * C + idx[3]] += (1.f - u) * dt; // tl
dx[((offset + ti) * W + ri) * C + idx[3]] += u * dt; // tr
dx[((offset + bi) * W + li) * C + idx[3]] += (1.f - u) * db; // bl
dx[((offset + bi) * W + ri) * C + idx[3]] += u * db; // br
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......
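The resize kernels above blend four neighbors (``tl``, ``tr``, ``bl``, ``br``) with fractional weights ``u`` and ``v``. A single-point NumPy sketch of that bilinear blend, ignoring the coordinate-transform and layout details (illustration only):

```python
import numpy as np

def bilinear_sample(img, h_in, w_in):
    # Interpolate one point from an (H, W) image, mirroring the
    # tl/tr/bl/br blend used by the resize kernels above.
    H, W = img.shape
    ti, li = int(np.floor(h_in)), int(np.floor(w_in))
    bi = int(np.ceil(h_in)) if h_in < H - 1 else H - 1
    ri = int(np.ceil(w_in)) if w_in < W - 1 else W - 1
    v, u = h_in - ti, w_in - li
    tl, tr, bl, br = img[ti, li], img[ti, ri], img[bi, li], img[bi, ri]
    t = tl + (tr - tl) * u
    b = bl + (br - bl) * u
    return t + (b - t) * v

img = np.arange(16, dtype='float32').reshape(4, 4)
print(bilinear_sample(img, 1.5, 2.25))  # 8.25, between rows 1-2 and cols 2-3
```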
......@@ -61,17 +61,17 @@ __global__ void _ResizeLinearNCHW(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li;
const int ofs = (n * C + c) * H;
const int offset = (n * C + c) * H;
#if __CUDA_ARCH__ >= 350
const float tl = __ldg(x + ((ofs + ti) * W + li));
const float tr = __ldg(x + ((ofs + ti) * W + ri));
const float bl = __ldg(x + ((ofs + bi) * W + li));
const float br = __ldg(x + ((ofs + bi) * W + ri));
const float tl = __ldg(x + ((offset + ti) * W + li));
const float tr = __ldg(x + ((offset + ti) * W + ri));
const float bl = __ldg(x + ((offset + bi) * W + li));
const float br = __ldg(x + ((offset + bi) * W + ri));
#else
const float tl = x[(ofs + ti) * W + li];
const float tr = x[(ofs + ti) * W + ri];
const float bl = x[(ofs + bi) * W + li];
const float br = x[(ofs + bi) * W + ri];
const float tl = x[(offset + ti) * W + li];
const float tr = x[(offset + ti) * W + ri];
const float bl = x[(offset + bi) * W + li];
const float br = x[(offset + bi) * W + ri];
#endif
const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u;
......@@ -109,11 +109,11 @@ __global__ void _ResizeLinearNCHW<half>(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li;
const int ofs = (n * C + c) * H;
const float tl = __half2float(__ldg(x + ((ofs + ti) * W + li)));
const float tr = __half2float(__ldg(x + ((ofs + ti) * W + ri)));
const float bl = __half2float(__ldg(x + ((ofs + bi) * W + li)));
const float br = __half2float(__ldg(x + ((ofs + bi) * W + ri)));
const int offset = (n * C + c) * H;
const float tl = __half2float(__ldg(x + ((offset + ti) * W + li)));
const float tr = __half2float(__ldg(x + ((offset + ti) * W + ri)));
const float bl = __half2float(__ldg(x + ((offset + bi) * W + li)));
const float br = __half2float(__ldg(x + ((offset + bi) * W + ri)));
const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u;
......@@ -151,17 +151,17 @@ __global__ void _ResizeLinearNHWC(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li;
const int ofs = n * H;
const int offset = n * H;
#if __CUDA_ARCH__ >= 350
const float tl = __ldg(x + (((ofs + ti) * W + li) * C + c));
const float tr = __ldg(x + (((ofs + ti) * W + ri) * C + c));
const float bl = __ldg(x + (((ofs + bi) * W + li) * C + c));
const float br = __ldg(x + (((ofs + bi) * W + ri) * C + c));
const float tl = __ldg(x + (((offset + ti) * W + li) * C + c));
const float tr = __ldg(x + (((offset + ti) * W + ri) * C + c));
const float bl = __ldg(x + (((offset + bi) * W + li) * C + c));
const float br = __ldg(x + (((offset + bi) * W + ri) * C + c));
#else
const float tl = x[((ofs + ti) * W + li) * C + c];
const float tr = x[((ofs + ti) * W + ri) * C + c];
const float bl = x[((ofs + bi) * W + li) * C + c];
const float br = x[((ofs + bi) * W + ri) * C + c];
const float tl = x[((offset + ti) * W + li) * C + c];
const float tr = x[((offset + ti) * W + ri) * C + c];
const float bl = x[((offset + bi) * W + li) * C + c];
const float br = x[((offset + bi) * W + ri) * C + c];
#endif
const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u;
......@@ -199,11 +199,15 @@ __global__ void _ResizeLinearNHWC<half>(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li;
const int ofs = n * H;
const float tl = __half2float(__ldg(x + (((ofs + ti) * W + li) * C + c)));
const float tr = __half2float(__ldg(x + (((ofs + ti) * W + ri) * C + c)));
const float bl = __half2float(__ldg(x + (((ofs + bi) * W + li) * C + c)));
const float br = __half2float(__ldg(x + (((ofs + bi) * W + ri) * C + c)));
const int offset = n * H;
const float tl =
__half2float(__ldg(x + (((offset + ti) * W + li) * C + c)));
const float tr =
__half2float(__ldg(x + (((offset + ti) * W + ri) * C + c)));
const float bl =
__half2float(__ldg(x + (((offset + bi) * W + li) * C + c)));
const float br =
__half2float(__ldg(x + (((offset + bi) * W + ri) * C + c)));
const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u;
......@@ -249,11 +253,11 @@ __global__ void _ResizeLinearGradNCHW(
const float db = v * ((float)dy[yi]);
#endif
const int ofs = (n * C + c) * H;
atomicAdd(&dx[(ofs + ti) * W + li], (1.f - u) * dt);
atomicAdd(&dx[(ofs + ti) * W + ri], u * dt);
atomicAdd(&dx[(ofs + bi) * W + li], (1.f - u) * db);
atomicAdd(&dx[(ofs + bi) * W + ri], u * db);
const int offset = (n * C + c) * H;
atomicAdd(&dx[(offset + ti) * W + li], (1.f - u) * dt);
atomicAdd(&dx[(offset + ti) * W + ri], u * dt);
atomicAdd(&dx[(offset + bi) * W + li], (1.f - u) * db);
atomicAdd(&dx[(offset + bi) * W + ri], u * db);
}
}
......@@ -290,11 +294,11 @@ __global__ void _ResizeLinearGradNCHW<half>(
const float dt = (1.f - v) * __half2float(__ldg(dy + yi));
const float db = v * __half2float(__ldg(dy + yi));
const int ofs = (n * C + c) * H;
atomicAdd(&dx[(ofs + ti) * W + li], (1.f - u) * dt);
atomicAdd(&dx[(ofs + ti) * W + ri], u * dt);
atomicAdd(&dx[(ofs + bi) * W + li], (1.f - u) * db);
atomicAdd(&dx[(ofs + bi) * W + ri], u * db);
const int offset = (n * C + c) * H;
atomicAdd(&dx[(offset + ti) * W + li], (1.f - u) * dt);
atomicAdd(&dx[(offset + ti) * W + ri], u * dt);
atomicAdd(&dx[(offset + bi) * W + li], (1.f - u) * db);
atomicAdd(&dx[(offset + bi) * W + ri], u * db);
#endif
}
}
......@@ -336,11 +340,11 @@ __global__ void _ResizeLinearGradNHWC(
const float db = v * ((float)dy[yi]);
#endif
const int ofs = n * H;
atomicAdd(&dx[((ofs + ti) * W + li) * C + c], (1.f - u) * dt);
atomicAdd(&dx[((ofs + ti) * W + ri) * C + c], u * dt);
atomicAdd(&dx[((ofs + bi) * W + li) * C + c], (1.f - u) * db);
atomicAdd(&dx[((ofs + bi) * W + ri) * C + c], u * db);
const int offset = n * H;
atomicAdd(&dx[((offset + ti) * W + li) * C + c], (1.f - u) * dt);
atomicAdd(&dx[((offset + ti) * W + ri) * C + c], u * dt);
atomicAdd(&dx[((offset + bi) * W + li) * C + c], (1.f - u) * db);
atomicAdd(&dx[((offset + bi) * W + ri) * C + c], u * db);
}
}
......
......@@ -11,7 +11,6 @@ template <typename T>
void CuDNNEluOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(),
act_desc_,
......@@ -33,7 +32,6 @@ template <typename T>
void CuDNNEluGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(),
act_desc_,
......
......@@ -7,7 +7,6 @@ template <class Context>
template <typename T>
void ReluOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
if (max_value_ > 0.f) {
kernel::ReluN(
X.count(),
......@@ -34,7 +33,6 @@ template <class Context>
template <typename T>
void ReluGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
if (max_value_ > 0.f) {
kernel::ReluNGrad(
Y.count(),
......
......@@ -9,7 +9,6 @@ template <typename T>
void CuDNNReluOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(),
......@@ -47,7 +46,6 @@ template <typename T>
void CuDNNReluGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(),
......
......@@ -9,7 +9,6 @@ template <typename T>
void CuDNNSigmoidOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(),
......@@ -43,7 +42,6 @@ template <typename T>
void CuDNNSigmoidGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(),
......
......@@ -9,10 +9,8 @@ template <typename T>
void CuDNNSoftmaxOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
CANONICALIZE_AXIS_WITH_TENSOR(X);
CuDNNSetTensorDesc<T>(
&input_desc_, {X.count(0, axis), X.dim(axis), X.count(axis + 1)});
CUDNN_CHECK(cudnnSoftmaxForward(
ctx()->cudnn_handle(),
CUDNN_SOFTMAX_ACCURATE,
......@@ -35,10 +33,8 @@ template <typename T>
void CuDNNSoftmaxGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(Y);
CuDNNSetTensorDesc<T>(
&input_desc_, {Y.count(0, axis), Y.dim(axis), Y.count(axis + 1)});
CUDNN_CHECK(cudnnSoftmaxBackward(
ctx()->cudnn_handle(),
CUDNN_SOFTMAX_ACCURATE,
......
......@@ -9,7 +9,6 @@ template <typename T>
void CuDNNTanhOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(),
......@@ -43,7 +42,6 @@ template <typename T>
void CuDNNTanhGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(),
......
......@@ -64,9 +64,7 @@ void CastOp<Context>::RunOnDevice() {
STORE_INPUT_SPEC(0);
DISPATCH_WITH_TENSOR(Input(0));
} else {
Buffer("X[" + std::to_string(0) + "]")
->ReshapeLike(*Output(0))
->set_meta(Output(0)->meta());
Buffer("X_spec:0")->ReshapeLike(*Output(0))->set_meta(Output(0)->meta());
DISPATCH_WITH_TENSOR((*Output(0)));
};
}
......
......@@ -26,7 +26,9 @@ namespace dragon {
axes_(OpArgs<int64_t>("axes")), \
keep_dims_(OpArg<int64_t>("keep_dims", 0)) {} \
USE_OPERATOR_FUNCTIONS; \
\
void RunOnDevice() override; \
\
template <typename T> \
void DoRunWithType(); \
\
......@@ -41,7 +43,9 @@ namespace dragon {
public: \
SIMPLE_CTOR_DTOR(name##GradientOp); \
USE_OPERATOR_FUNCTIONS; \
\
void RunOnDevice() override; \
\
template <typename T> \
void DoRunWithType(); \
};
......
......@@ -15,9 +15,8 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() {
auto inner_dim = X.count(axis + 1);
auto num_preds = outer_dim * inner_dim;
CHECK_EQ(num_preds, Input(1).count())
CHECK_EQ(X.count(), Input(1).count())
<< "\nNumber of preds must match the number of targets.";
Buffer("prob")->ReshapeLike(X);
auto* loss = ws()->template data<T, Context>({X.count()})[0];
......
......@@ -17,6 +17,8 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
CHECK_EQ(num_preds, Input(1).count())
<< "\nNumber of preds must match the number of targets.";
auto* X_prob = Buffer("prob")->ReshapeLike(X);
auto* prob = X_prob->template mutable_data<LogitType, Context>();
auto scratches = ws()->template data<Context>({
num_preds * sizeof(LogitType), // loss
......@@ -25,10 +27,6 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
auto* loss = static_cast<LogitType*>(scratches[0]);
auto* mask = static_cast<int*>(scratches[1]);
auto* prob = Buffer("prob")
->ReshapeLike(X)
->template mutable_data<LogitType, Context>();
kernel::Softmax(
outer_dim,
X.dim(axis),
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MATH_ACCUMULATE_OP_H_
#define DRAGON_OPERATORS_MATH_ACCUMULATE_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class AccumulateOp final : public Operator<Context> {
public:
AccumulateOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OpArg<float>("alpha", 1.f)),
beta_(OpArg<float>("beta", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType(Tensor* X, Tensor* Y);
protected:
float alpha_, beta_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MATH_ACCUMULATE_OP_H_
#include "dragon/core/workspace.h"
#include "dragon/operators/math/accumulate_op.h"
#include "dragon/operators/math/elementwise_ops.h"
#include "dragon/utils/math_functions.h"
namespace dragon {
template <class Context>
template <typename T>
void AccumulateOp<Context>::DoRunWithType(Tensor* X, Tensor* Y) {
void AxpbyOp<Context>::DoRunWithType(Tensor* X, Tensor* Y) {
CHECK_EQ(X->count(), Y->count());
auto* x = X->template data<T, Context>();
auto* y = Y->template mutable_data<T, Context>();
......@@ -26,41 +26,42 @@ void AccumulateOp<Context>::DoRunWithType(Tensor* X, Tensor* Y) {
}
template <class Context>
void AccumulateOp<Context>::RunOnDevice() {
void AxpbyOp<Context>::RunOnDevice() {
for (int i = 0; i < InputSize(); i++) {
Output(i)->ReshapeLike(Input(i));
if (XIsType(Input(i), int8_t)) {
DoRunWithType<int8_t>(&Input(i), Output(i));
} else if (XIsType(Input(i), uint8_t)) {
DoRunWithType<uint8_t>(&Input(i), Output(i));
} else if (XIsType(Input(i), int)) {
DoRunWithType<int>(&Input(i), Output(i));
} else if (XIsType(Input(i), int64_t)) {
DoRunWithType<int64_t>(&Input(i), Output(i));
} else if (XIsType(Input(i), float16)) {
DoRunWithType<float16>(&Input(i), Output(i));
} else if (XIsType(Input(i), float)) {
DoRunWithType<float>(&Input(i), Output(i));
} else if (XIsType(Input(i), double)) {
DoRunWithType<double>(&Input(i), Output(i));
auto &X = Input(i), *Y = Output(i);
Y->ReshapeLike(X);
if (XIsType(X, int8_t)) {
DoRunWithType<int8_t>(&X, Y);
} else if (XIsType(X, uint8_t)) {
DoRunWithType<uint8_t>(&X, Y);
} else if (XIsType(X, int)) {
DoRunWithType<int>(&X, Y);
} else if (XIsType(X, int64_t)) {
DoRunWithType<int64_t>(&X, Y);
} else if (XIsType(X, float16)) {
DoRunWithType<float16>(&X, Y);
} else if (XIsType(X, float)) {
DoRunWithType<float>(&X, Y);
} else if (XIsType(X, double)) {
DoRunWithType<double>(&X, Y);
} else
LOG(FATAL) << TypeString(
Input(i),
X,
{"int8", "uint8", "int32", "int64", "float16", "float32", "float64"});
}
}
DEPLOY_CPU(Accumulate);
DEPLOY_CPU(Axpby);
#ifdef USE_CUDA
DEPLOY_CUDA(Accumulate);
DEPLOY_CUDA(Axpby);
#endif
OPERATOR_SCHEMA(Accumulate)
OPERATOR_SCHEMA(Axpby)
/* X1, ... */
.NumInputs(1, INT_MAX)
/* Y1, ... */
.NumOutputs(1, INT_MAX);
NO_GRADIENT(Accumulate);
NO_GRADIENT(Axpby);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MATH_DOT_OP_H_
#define DRAGON_OPERATORS_MATH_DOT_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class DotOp final : public Operator<Context> {
public:
DotOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
transA_(OpArg<bool>("transA", false)),
transB_(OpArg<bool>("transB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DotImpl();
template <typename T>
void GemmImpl();
template <typename T>
void GemvImpl();
template <typename T>
void DoRunWithType();
protected:
int64_t transA_, transB_;
int64_t M_, K1_, K2_, N_;
int64_t M1_, N1_, M2_, N2_;
};
template <class Context>
class DotGradientOp final : public Operator<Context> {
public:
DotGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
transA_(OpArg<bool>("transA", false)),
transB_(OpArg<bool>("transB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DotImpl();
template <typename T>
void GemmImpl();
template <typename T>
void GemvImpl();
template <typename T>
void DoRunWithType();
protected:
int64_t transA_, transB_;
int64_t M_, K1_, K2_, N_;
int64_t M1_, N1_, M2_, N2_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MATH_DOT_OP_H_
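Judging from the declared helpers and the M/K/N members, the Dot operator presumably dispatches on the input ranks; the .cc body is not part of this diff, so the following is only an inferred sketch:
template <class Context>
template <typename T>
void DotOp<Context>::DoRunWithType() {
  auto &A = Input(0), &B = Input(1);
  if (A.ndim() == 1 && B.ndim() == 1) {
    DotImpl<T>();   // vector . vector -> scalar, requires K1_ == K2_
  } else if (A.ndim() >= 2 && B.ndim() == 1) {
    GemvImpl<T>();  // matrix . vector -> vector
  } else {
    GemmImpl<T>();  // matrix . matrix -> matrix, transA_/transB_ apply here
  }
}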
......@@ -18,7 +18,7 @@
namespace dragon {
#define DECLARE_SIMPLE_UNARY_OP(name) \
#define DECLARE_ELEMENTWISE_OP(name) \
template <class Context> \
class name##Op final : public Operator<Context> { \
public: \
......@@ -31,18 +31,23 @@ namespace dragon {
void DoRunWithType(); \
};
#define DECLARE_SIMPLE_BINARY_OP(name) \
template <class Context> \
class name##Op final : public Operator<Context> { \
public: \
SIMPLE_CTOR_DTOR(name##Op); \
USE_OPERATOR_FUNCTIONS; \
\
void RunOnDevice() override; \
\
template <typename T> \
void DoRunWithType(); \
};
template <class Context>
class AxpbyOp final : public Operator<Context> {
public:
AxpbyOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OpArg<float>("alpha", 1.f)),
beta_(OpArg<float>("beta", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType(Tensor* X, Tensor* Y);
protected:
float alpha_, beta_;
};
inline vec32_t CheckOutputAliases(
const Tensor& A,
......@@ -64,87 +69,61 @@ inline vec32_t CheckOutputAliases(
return available_aliases;
}
inline void IsBroadcast(
const Tensor& A,
const Tensor& B,
int& rows,
int& cols,
int& kind,
Tensor* Y = nullptr) {
kind = -2;
if (A.count() == B.count()) {
if (Y != nullptr) Y->ReshapeLike(A);
kind = -1;
} else if (B.count() < A.count()) {
if (Y != nullptr) Y->ReshapeLike(A);
if (utils::math::IsRowwiseBroadcast(A.dims(), B.dims(), &rows, &cols)) {
kind = 0;
} else if (utils::math::IsColwiseBroadcast(
A.dims(), B.dims(), &rows, &cols)) {
kind = 1;
}
} else {
if (Y != nullptr) Y->ReshapeLike(B);
if (utils::math::IsRowwiseBroadcast(A.dims(), B.dims(), &rows, &cols)) {
kind = 2;
} else if (utils::math::IsColwiseBroadcast(
A.dims(), B.dims(), &rows, &cols)) {
kind = 3;
}
}
}
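For readers of the helper above, a small illustrative caller (not part of the diff) showing how the kind codes are meant to be consumed: -2 marks incompatible shapes, -1 a plain element-wise pair, 0/1 a row-/column-wise broadcast of B against A, and 2/3 the same with the roles of A and B swapped.
template <class Context>
void ExampleBinaryDispatch(const Tensor& A, const Tensor& B, Tensor* Y) {
  int rows, cols, kind;
  IsBroadcast(A, B, rows, cols, kind, Y);  // also reshapes Y to the larger input
  switch (kind) {
    case -1:
      // Same count: run the plain element-wise kernel.
      break;
    case 0: case 1:
      // B is broadcast row-/column-wise against A over a rows x cols view.
      break;
    case 2: case 3:
      // A is broadcast row-/column-wise against B over a rows x cols view.
      break;
    default:
      LOG(FATAL) << "Could not broadcast " << A.DimString() << " with "
                 << B.DimString();
  }
}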
// Unary ElementwiseOp
DECLARE_ELEMENTWISE_OP(Abs);
DECLARE_ELEMENTWISE_OP(Ceil);
DECLARE_ELEMENTWISE_OP(Cos);
DECLARE_ELEMENTWISE_OP(Exp);
DECLARE_ELEMENTWISE_OP(Floor);
DECLARE_ELEMENTWISE_OP(IsInf);
DECLARE_ELEMENTWISE_OP(IsNaN);
DECLARE_ELEMENTWISE_OP(Log);
DECLARE_ELEMENTWISE_OP(Neg);
DECLARE_ELEMENTWISE_OP(Invert);
DECLARE_ELEMENTWISE_OP(Reciprocal);
DECLARE_ELEMENTWISE_OP(Round);
DECLARE_ELEMENTWISE_OP(Rsqrt);
DECLARE_ELEMENTWISE_OP(Sign);
DECLARE_ELEMENTWISE_OP(Sin);
DECLARE_ELEMENTWISE_OP(Sqrt);
DECLARE_ELEMENTWISE_OP(Square);
DECLARE_ELEMENTWISE_OP(AbsGradient);
DECLARE_ELEMENTWISE_OP(CosGradient);
DECLARE_ELEMENTWISE_OP(ExpGradient);
DECLARE_ELEMENTWISE_OP(LogGradient);
DECLARE_ELEMENTWISE_OP(NegGradient);
DECLARE_ELEMENTWISE_OP(ReciprocalGradient);
DECLARE_ELEMENTWISE_OP(RsqrtGradient);
DECLARE_ELEMENTWISE_OP(SignGradient);
DECLARE_ELEMENTWISE_OP(SinGradient);
DECLARE_ELEMENTWISE_OP(SqrtGradient);
DECLARE_ELEMENTWISE_OP(SquareGradient);
DECLARE_SIMPLE_UNARY_OP(Abs);
DECLARE_SIMPLE_UNARY_OP(Ceil);
DECLARE_SIMPLE_UNARY_OP(Cos);
DECLARE_SIMPLE_UNARY_OP(Exp);
DECLARE_SIMPLE_UNARY_OP(Floor);
DECLARE_SIMPLE_UNARY_OP(IsInf);
DECLARE_SIMPLE_UNARY_OP(IsNaN);
DECLARE_SIMPLE_UNARY_OP(Log);
DECLARE_SIMPLE_UNARY_OP(Neg);
DECLARE_SIMPLE_UNARY_OP(Invert);
DECLARE_SIMPLE_UNARY_OP(Reciprocal);
DECLARE_SIMPLE_UNARY_OP(Round);
DECLARE_SIMPLE_UNARY_OP(Rsqrt);
DECLARE_SIMPLE_UNARY_OP(Sign);
DECLARE_SIMPLE_UNARY_OP(Sin);
DECLARE_SIMPLE_UNARY_OP(Sqrt);
DECLARE_SIMPLE_UNARY_OP(Square);
DECLARE_SIMPLE_UNARY_OP(AbsGradient);
DECLARE_SIMPLE_UNARY_OP(CosGradient);
DECLARE_SIMPLE_UNARY_OP(ExpGradient);
DECLARE_SIMPLE_UNARY_OP(LogGradient);
DECLARE_SIMPLE_UNARY_OP(NegGradient);
DECLARE_SIMPLE_UNARY_OP(ReciprocalGradient);
DECLARE_SIMPLE_UNARY_OP(RsqrtGradient);
DECLARE_SIMPLE_UNARY_OP(SignGradient);
DECLARE_SIMPLE_UNARY_OP(SinGradient);
DECLARE_SIMPLE_UNARY_OP(SqrtGradient);
DECLARE_SIMPLE_UNARY_OP(SquareGradient);
#undef DECLARE_SIMPLE_UNARY_OP
// Binary ElementwiseOp
DECLARE_ELEMENTWISE_OP(Add);
DECLARE_ELEMENTWISE_OP(Sub);
DECLARE_ELEMENTWISE_OP(Mul);
DECLARE_ELEMENTWISE_OP(Div);
DECLARE_ELEMENTWISE_OP(Pow);
DECLARE_ELEMENTWISE_OP(Dot);
DECLARE_ELEMENTWISE_OP(Minimum);
DECLARE_ELEMENTWISE_OP(Maximum);
DECLARE_ELEMENTWISE_OP(Equal);
DECLARE_ELEMENTWISE_OP(NotEqual);
DECLARE_ELEMENTWISE_OP(Less);
DECLARE_ELEMENTWISE_OP(LessEqual);
DECLARE_ELEMENTWISE_OP(Greater);
DECLARE_ELEMENTWISE_OP(GreaterEqual);
DECLARE_ELEMENTWISE_OP(AddGradient);
DECLARE_ELEMENTWISE_OP(SubGradient);
DECLARE_ELEMENTWISE_OP(MulGradient);
DECLARE_ELEMENTWISE_OP(DivGradient);
DECLARE_ELEMENTWISE_OP(PowGradient);
DECLARE_ELEMENTWISE_OP(DotGradient);
DECLARE_ELEMENTWISE_OP(MinimumGradient);
DECLARE_ELEMENTWISE_OP(MaximumGradient);
DECLARE_SIMPLE_BINARY_OP(Add);
DECLARE_SIMPLE_BINARY_OP(Sub);
DECLARE_SIMPLE_BINARY_OP(Mul);
DECLARE_SIMPLE_BINARY_OP(Div);
DECLARE_SIMPLE_BINARY_OP(Pow);
DECLARE_SIMPLE_BINARY_OP(Minimum);
DECLARE_SIMPLE_BINARY_OP(Maximum);
DECLARE_SIMPLE_BINARY_OP(Equal);
DECLARE_SIMPLE_BINARY_OP(NotEqual);
DECLARE_SIMPLE_BINARY_OP(Less);
DECLARE_SIMPLE_BINARY_OP(LessEqual);
DECLARE_SIMPLE_BINARY_OP(Greater);
DECLARE_SIMPLE_BINARY_OP(GreaterEqual);
DECLARE_SIMPLE_BINARY_OP(AddGradient);
DECLARE_SIMPLE_BINARY_OP(SubGradient);
DECLARE_SIMPLE_BINARY_OP(MulGradient);
DECLARE_SIMPLE_BINARY_OP(DivGradient);
DECLARE_SIMPLE_BINARY_OP(PowGradient);
DECLARE_SIMPLE_BINARY_OP(MinimumGradient);
DECLARE_SIMPLE_BINARY_OP(MaximumGradient);
#undef DECLARE_SIMPLE_BINARY_OP
#undef DECLARE_ELEMENTWISE_OP
} // namespace dragon
......
......@@ -13,14 +13,14 @@ void FullyConnectedOp<Context>::DoRunWithType() {
// Determine the number of output channels
int64_t M = X.count(0, axis), K = X.count(axis), N;
if (num_output_ <= 0) {
if (out_channels_ <= 0) {
// Infer the "N" from the weights shape
N = W.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer the N from "
<< "the weights shape: " << W.DimString();
} else {
// Use a fixed "N" from the argument
N = num_output_;
N = out_channels_;
}
vec64_t Y_dims(axis + 1);
......@@ -82,14 +82,14 @@ void FullyConnectedGradientOp<Context>::DoRunWithType() {
// Determine the number of output channels
int64_t M = X.count(0, axis), K = X.count(axis), N;
if (num_output_ <= 0) {
if (out_channels_ <= 0) {
// Infer the "N" from the weights shape
N = W.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer the N from "
<< "the weights shape: " << W.DimString();
} else {
// Use a fixed "N" from the argument
N = num_output_;
N = out_channels_;
}
if (dX->has_name()) {
......
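As an illustration with made-up numbers: for a flattened input with K = X.count(axis) = 9216 features and a weight blob of shape (4096, 9216), leaving the output-channel argument unset makes the operator infer N = W.count() / K = 37748736 / 9216 = 4096, exactly the number of output channels the weights encode.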
......@@ -22,7 +22,7 @@ class FullyConnectedOp final : public Operator<Context> {
public:
FullyConnectedOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
num_output_(OpArg<int64_t>("num_output", 0)),
out_channels_(OpArg<int64_t>("out_channels", 0)),
transW_(OpArg<int64_t>("transW", 1)) {}
USE_OPERATOR_FUNCTIONS;
......@@ -32,7 +32,7 @@ class FullyConnectedOp final : public Operator<Context> {
void DoRunWithType();
protected:
int64_t num_output_, transW_;
int64_t out_channels_, transW_;
};
template <class Context>
......@@ -40,7 +40,7 @@ class FullyConnectedGradientOp final : public Operator<Context> {
public:
FullyConnectedGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
num_output_(OpArg<int64_t>("num_output", 0)),
out_channels_(OpArg<int64_t>("out_channels", 0)),
transW_(OpArg<int64_t>("transW", 1)) {}
USE_OPERATOR_FUNCTIONS;
......@@ -50,7 +50,7 @@ class FullyConnectedGradientOp final : public Operator<Context> {
void DoRunWithType();
protected:
int64_t num_output_, transW_;
int64_t out_channels_, transW_;
};
} // namespace dragon
......
......@@ -5,7 +5,7 @@ namespace dragon {
template <class Context>
template <typename T>
void MatmulOp<Context>::DoRunWithType() {
void MatMulOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), *Y = Output(0);
CHECK_GE(A.ndim(), 2) << "\nTensor(" << A.name() + ") must be a matrix"
......@@ -51,13 +51,13 @@ void MatmulOp<Context>::DoRunWithType() {
}
template <class Context>
void MatmulOp<Context>::RunOnDevice() {
void MatMulOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void MatmulGradientOp<Context>::DoRunWithType() {
void MatMulGradientOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), &dY = Input(2);
auto *dA = Output(0), *dB = Output(1);
......@@ -154,32 +154,32 @@ void MatmulGradientOp<Context>::DoRunWithType() {
}
template <class Context>
void MatmulGradientOp<Context>::RunOnDevice() {
void MatMulGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Matmul);
DEPLOY_CPU(MatMul);
#ifdef USE_CUDA
DEPLOY_CUDA(Matmul);
DEPLOY_CUDA(MatMul);
#endif
DEPLOY_CPU(MatmulGradient);
DEPLOY_CPU(MatMulGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(MatmulGradient);
DEPLOY_CUDA(MatMulGradient);
#endif
OPERATOR_SCHEMA(Matmul)
OPERATOR_SCHEMA(MatMul)
/* A, B */
.NumInputs(2)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(MatmulGradient)
OPERATOR_SCHEMA(MatMulGradient)
/* A, B, dY */
.NumInputs(3)
/* dA, dB */
.NumOutputs(2);
REGISTER_GRADIENT(Matmul, GenericGradientMaker);
REGISTER_GRADIENT(MatMul, GenericGradientMaker);
} // namespace dragon
......@@ -18,9 +18,9 @@
namespace dragon {
template <class Context>
class MatmulOp final : public Operator<Context> {
class MatMulOp final : public Operator<Context> {
public:
MatmulOp(const OperatorDef& def, Workspace* ws)
MatMulOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
transA_(OpArg<int64_t>("transA", 0)),
transB_(OpArg<int64_t>("transB", 0)) {}
......@@ -36,9 +36,9 @@ class MatmulOp final : public Operator<Context> {
};
template <class Context>
class MatmulGradientOp final : public Operator<Context> {
class MatMulGradientOp final : public Operator<Context> {
public:
MatmulGradientOp(const OperatorDef& def, Workspace* ws)
MatMulGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
transA_(OpArg<int64_t>("transA", 0)),
transB_(OpArg<int64_t>("transB", 0)) {}
......
......@@ -94,8 +94,8 @@ void GroupNormGradientOp<Context>::DoRunWithType() {
template <class Context>
void GroupNormGradientOp<Context>::RunOnDevice() {
DetermineBaseArguments();
Output(0)->ReshapeLike(Input(0));
if (XIsType(Input(0), float)) {
DoRunWithType<float, float>();
} else if (XIsType(Input(0), float16)) {
......
......@@ -42,9 +42,6 @@ class GroupNormOpBase : public Operator<Context> {
// Check the channels and groups
CHECK_EQ(C_ % G_, 0) << "\nThe " << C_ << " channels "
<< "can not be split into " << G_ << " groups.";
if (G_ == C_ && X.ndim() == 2) {
LOG(WARNING) << "The 2d input will output all zeros.";
}
}
protected:
......
#include "dragon/operators/training/adam_update_op.h"
#include "dragon/core/workspace.h"
#include "dragon/operators/training/update_ops.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
void AdamUpdateOp<Context>::Compute(Tensor* dX) {
auto* m = ws()->CreateTensor("/mnt/" + slot() + "/m")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
auto* v = ws()->CreateTensor("/mnt/" + slot() + "/v")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
void AdamUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
t_++;
auto beta1 = param("beta1");
auto beta2 = param("beta2");
auto beta1 = Parameter("beta1"), beta2 = Parameter("beta2");
auto coef = sqrt(1.f - pow(beta2, t_)) / (1.f - pow(beta1, t_));
kernel::AdamUpdate(
dX->count(),
param("base_lr") * coef * lr_mult(),
Parameter("base_lr") * coef * this->lr_mult_,
beta1,
beta2,
param("eps"),
Parameter("eps"),
dX->template mutable_data<float, Context>(),
m,
v,
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx());
}
......
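For reference, a scalar sketch of the standard Adam rule that kernel::AdamUpdate presumably applies element-wise; the lr argument already folds in the bias-correction coefficient sqrt(1 - beta2^t) / (1 - beta1^t) computed in ComputeUpdate above (illustrative only, not the kernel's actual code):
#include <cmath>
// The returned value is the step that UpdateOpBase later subtracts from X.
inline float AdamUpdateScalar(
    float g, float& m, float& v,
    float lr, float beta1, float beta2, float eps) {
  m = beta1 * m + (1.f - beta1) * g;      // first moment estimate
  v = beta2 * v + (1.f - beta2) * g * g;  // second moment estimate
  return lr * m / (std::sqrt(v) + eps);   // bias correction is folded into lr
}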
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_ADAM_UPDATE_OP_H_
#define DRAGON_OPERATORS_TRAINING_ADAM_UPDATE_OP_H_
#include "dragon/operators/training/update_op_base.h"
namespace dragon {
template <class Context>
class AdamUpdateOp final : public UpdateOpBase<Context> {
public:
AdamUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws), t_(0) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void Compute(Tensor* dX) override;
protected:
int t_;
// float lr_, beta1_, beta2_, eps_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_ADAM_UPDATE_OP_H_
#include "dragon/operators/training/nesterov_update_op.h"
#include "dragon/core/workspace.h"
#include "dragon/operators/training/update_ops.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
void NesterovUpdateOp<Context>::Compute(Tensor* dX) {
auto* m = ws()->CreateTensor("/mnt/" + slot() + "/m")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
void NesterovUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
kernel::NesterovUpdate(
dX->count(),
param("base_lr") * lr_mult(),
param("momentum"),
Parameter("base_lr") * this->lr_mult_,
Parameter("momentum"),
dX->template mutable_data<float, Context>(),
m,
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx());
}
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_NESTEROV_UPDATE_OP_H_
#define DRAGON_OPERATORS_TRAINING_NESTEROV_UPDATE_OP_H_
#include "dragon/operators/training/update_op_base.h"
namespace dragon {
template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> {
public:
NesterovUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void Compute(Tensor* dX) override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_NESTEROV_UPDATE_OP_H_
#include "dragon/operators/training/rmsprop_update_op.h"
#include "dragon/core/workspace.h"
#include "dragon/operators/training/update_ops.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
void RMSPropUpdateOp<Context>::Compute(Tensor* dX) {
auto* m = ws()->CreateTensor("/mnt/" + slot() + "/m")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
auto* v = ws()->CreateTensor("/mnt/" + slot() + "/v")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
void RMSpropUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
kernel::RMSPropUpdate(
dX->count(),
param("base_lr") * lr_mult(),
param("momentum"),
param("decay"),
param("eps"),
Parameter("base_lr") * this->lr_mult_,
Parameter("momentum"),
Parameter("decay"),
Parameter("eps"),
dX->template mutable_data<float, Context>(),
m,
v,
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx());
}
DEPLOY_CPU(RMSPropUpdate);
DEPLOY_CPU(RMSpropUpdate);
#ifdef USE_CUDA
DEPLOY_CUDA(RMSPropUpdate);
DEPLOY_CUDA(RMSpropUpdate);
#endif
OPERATOR_SCHEMA(RMSPropUpdate)
OPERATOR_SCHEMA(RMSpropUpdate)
/* dX */
.NumInputs(1)
/* X */
.NumOutputs(1);
NO_GRADIENT(RMSPropUpdate);
NO_GRADIENT(RMSpropUpdate);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_RMSPROP_UPDATE_OP_H_
#define DRAGON_OPERATORS_TRAINING_RMSPROP_UPDATE_OP_H_
#include "dragon/operators/training/update_op_base.h"
namespace dragon {
template <class Context>
class RMSPropUpdateOp final : public UpdateOpBase<Context> {
public:
RMSPropUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void Compute(Tensor* dX) override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_RMSPROP_UPDATE_OP_H_
#include "dragon/operators/training/sgd_update_op.h"
#include "dragon/core/workspace.h"
#include "dragon/operators/training/update_ops.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
void SGDUpdateOp<Context>::Compute(Tensor* dX) {
auto* m = ws()->CreateTensor("/mnt/" + slot() + "/m")
->ReshapeLike(*dX)
->template mutable_data<float, Context>();
void SGDUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
// Momentum correction, see arXiv:1706.02677
auto lr = param("base_lr") * lr_mult();
auto lr = Parameter("base_lr") * this->lr_mult_;
if (last_lr_ > 0) correction_ = lr / last_lr_;
last_lr_ = lr; // Record the last value
kernel::SGDUpdate(
dX->count(),
lr,
param("momentum") * correction_,
Parameter("momentum") * correction_,
dX->template mutable_data<float, Context>(),
m,
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx());
}
......
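A scalar sketch of the corrected momentum rule that kernel::SGDUpdate presumably implements; the correction factor lr_t / lr_{t-1} rescales the momentum history whenever the learning rate changes between steps, following arXiv:1706.02677 (illustrative only):
// u is the momentum buffer (the "m" slot); the returned value replaces the gradient.
inline float SGDUpdateScalar(
    float g, float& u, float lr, float momentum_with_correction) {
  u = momentum_with_correction * u + lr * g;
  return u;
}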
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_SGD_UPDATE_OP_H_
#define DRAGON_OPERATORS_TRAINING_SGD_UPDATE_OP_H_
#include "dragon/operators/training/update_op_base.h"
namespace dragon {
template <class Context>
class SGDUpdateOp final : public UpdateOpBase<Context> {
public:
SGDUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws), last_lr_(-1.f), correction_(1.f) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void Compute(Tensor* dX) override;
protected:
float last_lr_, correction_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_SGD_UPDATE_OP_H_
#include "dragon/operators/training/update_op_base.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/cast.h"
#include "dragon/operators/training/update_ops.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
float UpdateOpBase<Context>::param(const string& name) const {
return ws()
->GetTensor(slot_ + "/" + name)
->template mutable_data<float, CPUContext>()[0];
Tensor* UpdateOpBase<Context>::Slot(const string& name) {
return Buffer(Output(0)->name() + "/" + name);
}
template <class Context>
float UpdateOpBase<Context>::Parameter(const string& name) const {
auto* P = ws()->GetTensor("/share/hyper/" + handle() + "/" + name);
return P->template mutable_data<float, CPUContext>()[0];
}
template <class Context>
template <typename T>
void UpdateOpBase<Context>::Process(Tensor* dX, Tensor* X) {
void UpdateOpBase<Context>::AdjustGradient(Tensor* dX, Tensor* X) {
// Scale
auto scale_factor = param("scale_gradient");
if (scale_factor != 1.f) {
auto scale = Parameter("scale");
if (scale != 1.f) {
auto* dx = dX->template mutable_data<T, Context>();
math::Scale(dX->count(), scale_factor, dx, dx, ctx());
math::Scale(dX->count(), scale, dx, dx, ctx());
}
// Clip
auto clip_thresh = param("clip_gradient");
if (clip_thresh > 0.f) {
T sumsq_grad;
auto clip_norm = Parameter("clip_norm");
if (clip_norm > 0.f) {
auto* dx = dX->template mutable_data<T, Context>();
math::Dot(dX->count(), dx, dx, &sumsq_grad, ctx());
auto l2_norm = sqrt(cast::to<float>(sumsq_grad));
if (l2_norm > clip_thresh) {
math::Scale(dX->count(), clip_thresh / l2_norm, dx, dx, ctx());
auto grad_norm = std::sqrt(math::Dot(dX->count(), dx, dx, ctx()));
if (grad_norm > clip_norm) {
math::Scale(dX->count(), clip_norm / grad_norm, dx, dx, ctx());
}
}
// L2 Decay
auto l2_decay = param("l2_decay") * decay_mult_;
if (l2_decay > 0) {
// Penalty
auto weight_decay = Parameter("weight_decay");
if (weight_decay > 0.f) {
if (XIsType((*X), float16)) {
kernel::MixedPrecL2Decay(
kernel::MixedPrecL2Penalty(
X->count(),
l2_decay,
weight_decay * decay_mult_,
X->template data<float16, Context>(),
dX->template mutable_data<float, Context>(),
ctx());
} else {
math::Axpy(
X->count(),
l2_decay,
weight_decay * decay_mult_,
X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
ctx());
......@@ -56,7 +57,7 @@ void UpdateOpBase<Context>::Process(Tensor* dX, Tensor* X) {
template <class Context>
template <typename T>
void UpdateOpBase<Context>::Apply(Tensor* dX, Tensor* X) {
void UpdateOpBase<Context>::ApplyUpdate(Tensor* dX, Tensor* X) {
if (XIsType((*X), float16)) {
kernel::MixedPrecUpdate(
X->count(),
......@@ -64,9 +65,9 @@ void UpdateOpBase<Context>::Apply(Tensor* dX, Tensor* X) {
X->template mutable_data<float16, Context>(),
ctx());
} else {
math::Axpy(
math::Sub(
X->count(),
-1.f,
X->template data<T, Context>(),
dX->template data<T, Context>(),
X->template mutable_data<T, Context>(),
ctx());
......@@ -85,19 +86,19 @@ void UpdateOpBase<Context>::RunOnDevice() {
<< "\nGot" << X->DimString() << " and " << dX.DimString();
if (XIsType(dX, float)) {
Process<float>(&dX, X);
Compute(&dX);
Apply<float>(&dX, X);
AdjustGradient<float>(&dX, X);
ComputeUpdate(&dX);
ApplyUpdate<float>(&dX, X);
} else if (XIsType(dX, float16)) {
auto* dX_fp32 = ws()->CreateTensor(dX.name() + "/fp32");
auto* dX_cast = ws()->CreateTensor(dX.name() + "[float32]");
kernel::Cast(
dX.count(),
dX.template data<float16, Context>(),
dX_fp32->ReshapeLike(dX)->template mutable_data<float, Context>(),
dX_cast->ReshapeLike(dX)->template mutable_data<float, Context>(),
ctx());
Process<float>(dX_fp32, X);
Compute(dX_fp32);
Apply<float>(dX_fp32, X);
AdjustGradient<float>(dX_cast, X);
ComputeUpdate(dX_cast);
ApplyUpdate<float>(dX_cast, X);
} else {
LOG(FATAL) << TypeString(dX, {"float16", "float32"});
}
......
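As a worked example of the clipping branch: if the gradient has L2 norm sqrt(dot(dX, dX)) = 20 and clip_norm = 5, the gradient is rescaled by 5 / 20 = 0.25, while gradients whose norm is already below clip_norm pass through untouched. The scale, clip and weight-decay steps are applied in that order inside AdjustGradient, before ComputeUpdate runs and ApplyUpdate finally subtracts the step from X.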
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_UPDATE_OP_BASE_H_
#define DRAGON_OPERATORS_TRAINING_UPDATE_OP_BASE_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class UpdateOpBase : public Operator<Context> {
public:
UpdateOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
lr_mult_(OpArg<float>("lr_mult", 1.f)),
decay_mult_(OpArg<float>("decay_mult", 1.f)),
slot_(OpArg<string>("slot", "")) {
CHECK(!slot_.empty()) << "\nRequired a non-empty slot";
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
virtual void Compute(Tensor* dX) = 0;
template <typename T>
void Process(Tensor* dX, Tensor* X);
template <typename T>
void Apply(Tensor* dX, Tensor* X);
string slot() {
return slot_ + "/" + Output(0)->name();
}
float param(const string& name) const;
float lr_mult() const {
return lr_mult_;
}
protected:
string slot_;
float lr_mult_, decay_mult_;
};
#define USE_PARAM_UPDATE_FUNCTIONS \
using UpdateOpBase<Context>::slot; \
using UpdateOpBase<Context>::param; \
using UpdateOpBase<Context>::lr_mult
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_UPDATE_OP_BASE_H_
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_TRAINING_UPDATE_OPS_H_
#define DRAGON_OPERATORS_TRAINING_UPDATE_OPS_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class UpdateOpBase : public Operator<Context> {
public:
UpdateOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
lr_mult_(OpArg<float>("lr_mult", 1.f)),
decay_mult_(OpArg<float>("decay_mult", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
virtual void ComputeUpdate(Tensor* dX) = 0;
template <typename T>
void AdjustGradient(Tensor* dX, Tensor* X);
template <typename T>
void ApplyUpdate(Tensor* dX, Tensor* X);
Tensor* Slot(const string& name);
float Parameter(const string& name) const;
protected:
float lr_mult_, decay_mult_;
};
#define USE_PARAM_UPDATE_FUNCTIONS \
using UpdateOpBase<Context>::Slot; \
using UpdateOpBase<Context>::Parameter
template <class Context>
class SGDUpdateOp final : public UpdateOpBase<Context> {
public:
SGDUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws), last_lr_(-1.f), correction_(1.f) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void ComputeUpdate(Tensor* dX) override;
protected:
float last_lr_, correction_;
};
template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> {
public:
NesterovUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void ComputeUpdate(Tensor* dX) override;
};
template <class Context>
class RMSpropUpdateOp final : public UpdateOpBase<Context> {
public:
RMSpropUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void ComputeUpdate(Tensor* dX) override;
};
template <class Context>
class AdamUpdateOp final : public UpdateOpBase<Context> {
public:
AdamUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws), t_(0) {}
USE_OPERATOR_FUNCTIONS;
USE_PARAM_UPDATE_FUNCTIONS;
void ComputeUpdate(Tensor* dX) override;
protected:
int t_;
};
#undef USE_PARAM_UPDATE_FUNCTIONS
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_UPDATE_OPS_H_
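To illustrate the intended extension point, a hypothetical optimizer derived from UpdateOpBase (the class name and its behavior are made up for this sketch): ComputeUpdate only rewrites dX in place, while the base class handles gradient scaling, clipping, weight decay and the final X -= dX, and reads hyper-parameters such as base_lr from the shared workspace via Parameter().
// A plain gradient-descent op without momentum, sketched for illustration;
// a real operator would also need DEPLOY_CPU/DEPLOY_CUDA, an OPERATOR_SCHEMA
// and a NO_GRADIENT registration like the ops above.
template <class Context>
class VanillaUpdateOp final : public UpdateOpBase<Context> {
 public:
  VanillaUpdateOp(const OperatorDef& def, Workspace* ws)
      : UpdateOpBase<Context>(def, ws) {}
  USE_OPERATOR_FUNCTIONS;

  void ComputeUpdate(Tensor* dX) override {
    // dX <- (base_lr * lr_mult) * dX; ApplyUpdate then computes X -= dX.
    auto* dx = dX->template mutable_data<float, Context>();
    math::Scale(
        dX->count(), this->Parameter("base_lr") * this->lr_mult_, dx, dx, ctx());
  }
};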
......@@ -59,9 +59,9 @@ void BiasAddGradientOp<Context>::DoRunWithType() {
dB->Reshape({dY.dim(-1)});
}
math::ReduceSum(
3,
dims.size(),
dims.data(),
2,
axes.size(),
axes.data(),
1.f,
dY.template data<T, Context>(),
......
......@@ -16,7 +16,7 @@ void Conv2dOp<Context>::DoRunWithType() {
auto* y = Y->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) {
Wx(x + i * x_ofs_, w, y + i * y_ofs_);
Wx(x + i * x_offset_, w, y + i * y_offset_);
}
if (HasBias()) {
......@@ -46,7 +46,7 @@ void Conv2dGradientOp<Context>::DoRunWithType() {
auto* w = W.template data<T, Context>();
auto* dx = dX->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) {
Dx(dy + i * y_ofs_, w, dx + i * x_ofs_);
Dx(dy + i * y_offset_, w, dx + i * x_offset_);
}
}
......@@ -55,7 +55,7 @@ void Conv2dGradientOp<Context>::DoRunWithType() {
auto* x = X.template data<T, Context>();
auto* dw = dW->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) {
Dw(dy + i * y_ofs_, x + i * x_ofs_, dw, i > 0);
Dw(dy + i * y_offset_, x + i * x_offset_, dw, i > 0);
}
}
......
......@@ -73,8 +73,8 @@ void CuDNNConv2dOp<Context>::ResetDesc() {
filter_desc_,
CuDNNType<T>::type,
format_,
num_output_ / cudnn_group_,
channels_ / group_,
out_channels_ / cudnn_group_,
in_channels_ / group_,
kshape_[0],
kshape_[1]));
#else
......@@ -82,14 +82,15 @@ void CuDNNConv2dOp<Context>::ResetDesc() {
filter_desc_,
CuDNNType<T>::type,
format_,
num_output_ / cudnn_group_,
channels_ / group_,
out_channels_ / cudnn_group_,
in_channels_ / group_,
kshape_[0],
kshape_[1]));
#endif
// Determine the bias shape
if (HasBias()) {
CuDNNSetBiasDesc<T>(&bias_desc_, X.ndim(), num_output_, data_format());
CuDNNSetBiasDesc<T>(
&bias_desc_, X.ndim(), out_channels_, data_format());
}
}
// Set the conv configuration
......@@ -179,16 +180,16 @@ void CuDNNConv2dOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(),
CuDNNType<T>::one,
input_desc_,
x + x_ofs_ * g,
x + x_offset_ * g,
filter_desc_,
w + w_ofs_ * g,
w + w_offset_ * g,
conv_desc_,
fwd_algo_,
scratch,
cudnn_ws_nbytes_,
CuDNNType<T>::zero,
output_desc_,
y + y_ofs_ * g));
y + y_offset_ * g));
}
if (HasBias()) {
......@@ -217,11 +218,11 @@ void CuDNNConv2dOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape();
if (data_format() == "NCHW") {
x_ofs_ = Input(0).stride(0) / cudnn_group_;
y_ofs_ = Output(0)->stride(0) / cudnn_group_;
x_offset_ = Input(0).stride(0) / cudnn_group_;
y_offset_ = Output(0)->stride(0) / cudnn_group_;
} else if (data_format() == "NHWC") {
x_ofs_ = Input(0).dim(-1) / cudnn_group_;
y_ofs_ = Output(0)->dim(-1) / cudnn_group_;
x_offset_ = Input(0).dim(-1) / cudnn_group_;
y_offset_ = Output(0)->dim(-1) / cudnn_group_;
}
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
......@@ -294,8 +295,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
filter_desc_,
CuDNNType<T>::type,
format_,
num_output_ / cudnn_group_,
channels_ / group_,
out_channels_ / cudnn_group_,
in_channels_ / group_,
kshape_[0],
kshape_[1]));
#else
......@@ -303,14 +304,15 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
filter_desc_,
CuDNNType<T>::type,
format_,
num_output_ / cudnn_group_,
channels_ / group_,
out_channels_ / cudnn_group_,
in_channels_ / group_,
kshape_[0],
kshape_[1]));
#endif
// Determine the bias shape
if (HasBias()) {
CuDNNSetBiasDesc<T>(&bias_desc_, X.ndim(), num_output_, data_format());
CuDNNSetBiasDesc<T>(
&bias_desc_, X.ndim(), out_channels_, data_format());
}
}
// Set the conv configuration
......@@ -470,16 +472,16 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(),
CuDNNType<T>::one,
output_desc_,
x + x_ofs_ * g,
x + x_offset_ * g,
input_desc_,
dy + y_ofs_ * g,
dy + y_offset_ * g,
conv_desc_,
bwd_filter_algo_,
scratch,
cudnn_ws_nbytes_,
CuDNNType<T>::zero,
filter_desc_,
dw + w_ofs_ * g));
dw + w_offset_ * g));
}
}
......@@ -491,16 +493,16 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(),
CuDNNType<T>::one,
filter_desc_,
w + w_ofs_ * g,
w + w_offset_ * g,
input_desc_,
dy + y_ofs_ * g,
dy + y_offset_ * g,
conv_desc_,
bwd_data_algo_,
scratch,
cudnn_ws_nbytes_,
CuDNNType<T>::zero,
output_desc_,
dx + x_ofs_ * g));
dx + x_offset_ * g));
}
}
}
......@@ -518,11 +520,11 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape(true);
if (data_format() == "NCHW") {
x_ofs_ = Input(0).stride(0) / cudnn_group_;
y_ofs_ = Input(-1).stride(0) / cudnn_group_;
x_offset_ = Input(0).stride(0) / cudnn_group_;
y_offset_ = Input(-1).stride(0) / cudnn_group_;
} else if (data_format() == "NHWC") {
x_ofs_ = Input(0).dim(-1) / cudnn_group_;
y_ofs_ = Input(-1).dim(-1) / cudnn_group_;
x_offset_ = Input(0).dim(-1) / cudnn_group_;
y_offset_ = Input(-1).dim(-1) / cudnn_group_;
}
DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1));
......
......@@ -8,12 +8,7 @@ template <class Context>
template <typename T>
void ConvTranspose2dOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0);
ConvOpBase<Context>::Reshape();
// Fix the output shape for im2col/col2im
for (int i = 0; i < num_axes_; i++) {
out_shape_[i] = X.dim(axis_ + i);
}
TENSOR_FILL(W, w_shape_);
auto* x = X.template data<T, Context>();
......@@ -21,7 +16,7 @@ void ConvTranspose2dOp<Context>::DoRunWithType() {
auto* y = Y->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) {
Dx(x + i * x_ofs_, w, y + i * y_ofs_);
Dx(x + i * x_offset_, w, y + i * y_offset_);
}
if (HasBias()) {
......@@ -44,19 +39,14 @@ template <typename T>
void ConvTranspose2dGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
ConvOpBase<Context>::Reshape(true);
// Fix the output shape for im2col/col2im
for (int i = 0; i < num_axes_; i++) {
out_shape_[i] = X.dim(axis_ + i);
}
if (dX->has_name()) {
auto* dy = dY.template data<T, Context>();
auto* w = W.template data<T, Context>();
auto* dx = dX->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) {
Wx(dy + i * y_ofs_, w, dx + i * x_ofs_);
Wx(dy + i * y_offset_, w, dx + i * x_offset_);
}
}
......@@ -65,7 +55,7 @@ void ConvTranspose2dGradientOp<Context>::DoRunWithType() {
auto* dy = dY.template data<T, Context>();
auto* dw = dW->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) {
Dw(x + i * x_ofs_, dy + i * y_ofs_, dw, i > 0);
Dw(x + i * x_offset_, dy + i * y_offset_, dw, i > 0);
}
}
......
......@@ -71,8 +71,8 @@ void CuDNNConvTranspose2dOp<Context>::ResetDesc() {
filter_desc_,
CuDNNType<T>::type,
format_,
channels_ / cudnn_group_,
num_output_ / group_,
in_channels_ / cudnn_group_,
out_channels_ / group_,
kshape_[0],
kshape_[1]));
#else
......@@ -80,14 +80,15 @@ void CuDNNConvTranspose2dOp<Context>::ResetDesc() {
filter_desc_,
CuDNNType<T>::type,
format_,
channels_ / cudnn_group_,
num_output_ / group_,
in_channels_ / cudnn_group_,
out_channels_ / group_,
kshape_[0],
kshape_[1]));
#endif
// Determine the bias shape
if (HasBias()) {
CuDNNSetBiasDesc<T>(&bias_desc_, X.ndim(), num_output_, data_format());
CuDNNSetBiasDesc<T>(
&bias_desc_, X.ndim(), out_channels_, data_format());
}
}
// Set the conv configuration
......@@ -180,16 +181,16 @@ void CuDNNConvTranspose2dOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(),
CuDNNType<T>::one,
filter_desc_,
w + w_ofs_ * g,
w + w_offset_ * g,
input_desc_,
x + x_ofs_ * g,
x + x_offset_ * g,
conv_desc_,
fwd_algo_,
scratch,
cudnn_ws_nbytes_,
CuDNNType<T>::zero,
output_desc_,
y + y_ofs_ * g));
y + y_offset_ * g));
}
if (HasBias()) {
......@@ -218,11 +219,11 @@ void CuDNNConvTranspose2dOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape();
if (data_format() == "NCHW") {
x_ofs_ = Input(0).stride(0) / cudnn_group_;
y_ofs_ = Output(0)->stride(0) / cudnn_group_;
x_offset_ = Input(0).stride(0) / cudnn_group_;
y_offset_ = Output(0)->stride(0) / cudnn_group_;
} else if (data_format() == "NHWC") {
x_ofs_ = Input(0).dim(-1) / cudnn_group_;
y_ofs_ = Output(0)->dim(-1) / cudnn_group_;
x_offset_ = Input(0).dim(-1) / cudnn_group_;
y_offset_ = Output(0)->dim(-1) / cudnn_group_;
}
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
......@@ -293,8 +294,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
filter_desc_,
CuDNNType<T>::type,
format_,
channels_ / cudnn_group_,
num_output_ / group_,
in_channels_ / cudnn_group_,
out_channels_ / group_,
kshape_[0],
kshape_[1]));
#else
......@@ -302,14 +303,15 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
filter_desc_,
CuDNNType<T>::type,
format_,
channels_ / cudnn_group_,
num_output_ / group_,
in_channels_ / cudnn_group_,
out_channels_ / group_,
kshape_[0],
kshape_[1]));
#endif
// Determine the bias shape
if (HasBias()) {
CuDNNSetBiasDesc<T>(&bias_desc_, X.ndim(), num_output_, data_format());
CuDNNSetBiasDesc<T>(
&bias_desc_, X.ndim(), out_channels_, data_format());
}
}
// Set the conv configuration
......@@ -466,16 +468,16 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(),
CuDNNType<T>::one,
input_desc_,
dy + y_ofs_ * g,
dy + y_offset_ * g,
output_desc_,
x + x_ofs_ * g,
x + x_offset_ * g,
conv_desc_,
bwd_filter_algo_,
scratch,
cudnn_ws_nbytes_,
CuDNNType<T>::zero,
filter_desc_,
dw + w_ofs_ * g));
dw + w_offset_ * g));
}
}
......@@ -487,16 +489,16 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
ctx()->cudnn_handle(),
CuDNNType<T>::one,
input_desc_,
dy + y_ofs_ * g,
dy + y_offset_ * g,
filter_desc_,
w + w_ofs_ * g,
w + w_offset_ * g,
conv_desc_,
bwd_data_algo_,
scratch,
cudnn_ws_nbytes_,
CuDNNType<T>::zero,
output_desc_,
dx + x_ofs_ * g));
dx + x_offset_ * g));
}
}
}
......@@ -514,11 +516,11 @@ void CuDNNConvTranspose2dGradientOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape(true);
if (data_format() == "NCHW") {
x_ofs_ = Input(0).stride(0) / cudnn_group_;
y_ofs_ = Input(-1).stride(0) / cudnn_group_;
x_offset_ = Input(0).stride(0) / cudnn_group_;
y_offset_ = Input(-1).stride(0) / cudnn_group_;
} else if (data_format() == "NHWC") {
x_ofs_ = Input(0).dim(-1) / cudnn_group_;
y_ofs_ = Input(-1).dim(-1) / cudnn_group_;
x_offset_ = Input(0).dim(-1) / cudnn_group_;
y_offset_ = Input(-1).dim(-1) / cudnn_group_;
}
DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1));
......
......@@ -10,10 +10,11 @@ namespace dragon {
template <class Context>
void ConvOpBase<Context>::ComputeOutShape() {
auto X_dims = Input(0).dims();
out_shape_.clear();
for (int i = 0; i < num_axes_; i++) {
if (!Transposed()) {
auto idm = x_shape_[axis_ + i];
auto idm = X_dims[axis_ + i];
auto dk = dilation_[i] * (kshape_[i] - 1) + 1;
if (!str::find(padding_, "SAME")) {
// Explicit pads
......@@ -32,7 +33,7 @@ void ConvOpBase<Context>::ComputeOutShape() {
} // SAME_LOWER or SAME
}
} else {
auto idm = x_shape_[axis_ + i];
auto idm = X_dims[axis_ + i];
auto dk = dilation_[i] * (kshape_[i] - 1) + 1;
if (!str::find(padding_, "SAME")) {
// Explicit pads
......@@ -79,13 +80,11 @@ template <class Context>
template <typename T>
void ConvOpBase<Context>::Wx(const T* x, const T* w, T* y, bool skip) {
auto* col = x;
if (!is_1x1_) {
auto* scratch = ws()->template data<T, Context>({col_dim_})[0];
if (!skip) Im2Col(x, scratch);
col = scratch;
}
for (int g = 0; g < group_; g++) {
if (data_format() == "NCHW") {
math::Gemm(
......@@ -95,10 +94,10 @@ void ConvOpBase<Context>::Wx(const T* x, const T* w, T* y, bool skip) {
conv_out_dim_,
kernel_dim_,
1.f,
w + w_ofs_ * g,
col + col_ofs_ * g,
w + w_offset_ * g,
col + col_offset_ * g,
0.f,
y + output_ofs_ * g,
y + out_offset_ * g,
ctx());
} else if (data_format() == "NHWC") {
math::Gemm(
......@@ -121,10 +120,11 @@ template <class Context>
template <typename T>
void ConvOpBase<Context>::Pb(const T* bias, T* y) {
if (data_format() == "NCHW") {
kernel::BiasAdd(Input(0).dim(0), num_output_, out_dim_, y, bias, y, ctx());
kernel::BiasAdd(
Input(0).dim(0), out_channels_, out_dim_, y, bias, y, ctx());
} else if (data_format() == "NHWC") {
kernel::BiasAdd(
Input(0).dim(0) * out_dim_, num_output_, 1, y, bias, y, ctx());
Input(0).dim(0) * out_dim_, out_channels_, 1, y, bias, y, ctx());
}
}
......@@ -141,10 +141,10 @@ void ConvOpBase<Context>::Dx(const T* dy, const T* w, T* dx) {
conv_out_dim_,
conv_out_channels_ / group_,
1.f,
w + w_ofs_ * g,
dy + output_ofs_ * g,
w + w_offset_ * g,
dy + out_offset_ * g,
0.f,
col + col_ofs_ * g,
col + col_offset_ * g,
ctx());
} else if (data_format() == "NHWC") {
math::Gemm(
......@@ -168,13 +168,11 @@ template <class Context>
template <typename T>
void ConvOpBase<Context>::Dw(const T* dy, const T* x, T* dw, bool accum) {
auto* col = x;
if (!is_1x1_) {
auto* scratch = ws()->template data<T, Context>({col_dim_})[0];
Im2Col(x, scratch);
col = scratch;
}
for (int g = 0; g < group_; g++) {
if (data_format() == "NCHW") {
math::Gemm(
......@@ -184,10 +182,10 @@ void ConvOpBase<Context>::Dw(const T* dy, const T* x, T* dw, bool accum) {
kernel_dim_,
conv_out_dim_,
1.f,
dy + output_ofs_ * g,
col + col_ofs_ * g,
dy + out_offset_ * g,
col + col_offset_ * g,
accum ? 1.f : 0.f,
dw + w_ofs_ * g,
dw + w_offset_ * g,
ctx());
} else if (data_format() == "NHWC") {
math::Gemm(
......@@ -211,10 +209,10 @@ template <typename T>
void ConvOpBase<Context>::Db(const T* dy, T* db) {
vec32_t dims, axes;
if (data_format() == "NCHW") {
dims = {(int)Input(0).dim(0), (int)num_output_, (int)out_dim_};
dims = {(int)Input(0).dim(0), (int)out_channels_, (int)out_dim_};
axes = {0, 2};
} else if (data_format() == "NHWC") {
dims = {(int)Input(0).dim(0), (int)out_dim_, (int)num_output_};
dims = {(int)Input(0).dim(0), (int)out_dim_, (int)out_channels_};
axes = {0, 1};
}
math::ReduceSum(3, dims.data(), 2, axes.data(), 1.f, dy, db, ctx());
......@@ -223,16 +221,15 @@ void ConvOpBase<Context>::Db(const T* dy, T* db) {
template <class Context>
void ConvOpBase<Context>::Setup(int num_axes) {
num_axes_ = num_axes;
auto at = [&](const vec64_t& vec, int i) {
return i < vec.size() ? vec[i] : vec[0];
};
auto pads = OpArgs<int64_t>("pads");
auto strides = OpArgs<int64_t>("strides");
auto kshape = OpArgs<int64_t>("kernel_shape");
auto dilations = OpArgs<int64_t>("dilations");
auto at = [&](const vec64_t& vec, int i) {
return i < vec.size() ? vec[i] : vec[0];
};
for (int i = 0; i < num_axes; i++) {
pad_l_.push_back(at(pads, i));
stride_.push_back(at(strides, i));
......@@ -241,8 +238,9 @@ void ConvOpBase<Context>::Setup(int num_axes) {
}
if ((int64_t)pads.size() == (num_axes * 2)) {
for (int i = 0; i < num_axes; i++)
for (int i = 0; i < num_axes; i++) {
pad_r_.push_back(pads[num_axes + i]);
}
} else {
pad_r_.assign(pad_l_.begin(), pad_l_.end());
}
......@@ -264,63 +262,56 @@ void ConvOpBase<Context>::Reshape(bool backward) {
auto* Y_ref = backward ? &Input(-1) : Output(0);
// Determine the in/out channels
channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
if (num_output_ <= 0) {
in_channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
if (out_channels_ <= 0) {
// Infer the out channels from the weights shape
num_output_ = W.count() / channels_;
for (int i = 0; i < num_axes_; i++)
num_output_ /= kshape_[i];
CHECK_GT(num_output_, 0) << "\nFailed to infer the out channels "
<< "from weights: " << W.DimString();
out_channels_ = W.count() / (in_channels_ / group_);
for (int i = 0; i < num_axes_; i++) {
out_channels_ /= kshape_[i];
}
CHECK_GT(out_channels_, 0) << "\nFailed to infer the out channels "
<< "from weights: " << W.DimString();
}
if (Transposed()) {
conv_out_channels_ = channels_;
conv_in_channels_ = num_output_;
conv_out_channels_ = in_channels_;
conv_in_channels_ = out_channels_;
} else {
conv_out_channels_ = num_output_;
conv_in_channels_ = channels_;
conv_out_channels_ = out_channels_;
conv_in_channels_ = in_channels_;
}
// Determine the weight and bias shape
// The weight shape is always assumed to be in NCHW format,
// so that the fans can be computed correctly
w_shape_ = {conv_out_channels_, conv_in_channels_ / group_};
for (int i = 0; i < num_axes_; i++)
for (int i = 0; i < num_axes_; i++) {
w_shape_.push_back(kshape_[i]);
b_shape_ = {num_output_};
}
b_shape_ = {out_channels_};
// Determine the Y shape
x_shape_ = X.dims();
// Determine the output shape
ComputeOutShape();
if (backward) {
if (Output(0)->has_name()) Output(0)->ReshapeLike(X);
if (Output(1)->has_name()) Output(1)->ReshapeLike(W);
if (Output(2)->has_name()) Output(2)->Reshape({num_output_});
if (Output(2)->has_name()) Output(2)->Reshape({out_channels_});
} else {
vec64_t Y_dims{X.dim(0)};
if (data_format() == "NCHW") {
y_shape_ = {X.dim(0), num_output_};
for (int i = 0; i < num_axes_; i++)
y_shape_.push_back(out_shape_[i]);
Y_dims.push_back(out_channels_);
for (int i = 0; i < num_axes_; i++) {
Y_dims.push_back(out_shape_[i]);
}
} else if (data_format() == "NHWC") {
y_shape_ = {X.dim(0)};
for (int i = 0; i < num_axes_; i++)
y_shape_.push_back(out_shape_[i]);
y_shape_.push_back(num_output_);
}
Output(0)->Reshape(y_shape_);
}
// Determine the input shape for im2col/col2im
in_shape_.clear();
for (int i = 0; i < num_axes_; i++) {
if (Transposed()) {
in_shape_.push_back(Y_ref->dim(axis_ + i));
} else {
in_shape_.push_back(X.dim(axis_ + i));
for (int i = 0; i < num_axes_; i++) {
Y_dims.push_back(out_shape_[i]);
}
Y_dims.push_back(out_channels_);
}
Output(0)->Reshape(Y_dims);
}
// Determine the out spatial dim
// Determine the output dim
auto end_axis = X.ndim() - 1;
if (data_format() == "NCHW") {
if (Transposed()) {
......@@ -338,25 +329,31 @@ void ConvOpBase<Context>::Reshape(bool backward) {
out_dim_ = Y_ref->count(axis_, end_axis);
}
// Determine the misc
x_ofs_ = X.stride(0);
y_ofs_ = Y_ref->stride(0);
// Compute the miscellaneous
x_offset_ = X.stride(0);
y_offset_ = Y_ref->stride(0);
kernel_dim_ = conv_in_channels_ / group_;
for (int i = 0; i < num_axes_; i++)
for (int i = 0; i < num_axes_; i++) {
kernel_dim_ *= kshape_[i];
col_ofs_ = kernel_dim_ * conv_out_dim_;
w_ofs_ = conv_out_channels_ * kernel_dim_ / group_;
output_ofs_ = conv_out_channels_ * conv_out_dim_ / group_;
}
col_offset_ = kernel_dim_ * conv_out_dim_;
w_offset_ = conv_out_channels_ * kernel_dim_ / group_;
out_offset_ = conv_out_channels_ * conv_out_dim_ / group_;
// Determine the workspace size for col buffer
col_dim_ = kernel_dim_ * group_;
// Compute the arguments for im2col/col2im
in_shape_.clear();
for (int i = 0; i < num_axes_; i++) {
if (Transposed()) {
col_dim_ *= x_shape_[axis_ + i];
in_shape_.push_back(Y_ref->dim(axis_ + i));
out_shape_[i] = X.dim(axis_ + i);
} else {
col_dim_ *= out_shape_[i];
in_shape_.push_back(X.dim(axis_ + i));
}
}
col_dim_ = kernel_dim_ * group_;
for (int i = 0; i < num_axes_; i++) {
col_dim_ *= out_shape_[i];
}
}
#define INSTANTIATE_API(Context, T) \
......
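To make the offset bookkeeping concrete (illustrative numbers): a non-transposed NCHW convolution with in_channels = 64, out_channels = 128, group = 2, a 3x3 kernel and a 56x56 output gives kernel_dim_ = (64 / 2) * 9 = 288, conv_out_dim_ = 56 * 56 = 3136, col_offset_ = 288 * 3136 = 903168, w_offset_ = 128 * 288 / 2 = 18432 and out_offset_ = 128 * 3136 / 2 = 200704, so each of the two groups works on its own contiguous slice of the column buffer, the weights and the output.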
......@@ -25,7 +25,7 @@ class ConvOpBase : public Operator<Context> {
ConvOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
padding_(OpArg<string>("padding", "VALID")),
num_output_(OpArg<int64_t>("num_output", 0)),
out_channels_(OpArg<int64_t>("out_channels", 0)),
group_(OpArg<int64_t>("group", 1)) {
if (data_format() == "NCHW") {
axis_ = 2;
......@@ -42,18 +42,13 @@ class ConvOpBase : public Operator<Context> {
vec64_t kshape_, stride_;
vec64_t pad_l_, pad_r_, dilation_;
vec64_t in_shape_, out_shape_;
vec64_t x_shape_, y_shape_;
vec64_t w_shape_, b_shape_;
vec64_t in_shape_, w_shape_, b_shape_, out_shape_;
string padding_;
int64_t is_1x1_, num_output_, group_;
int64_t group_;
int64_t axis_, num_axes_;
int64_t channels_, out_dim_;
int64_t conv_in_channels_, conv_out_channels_;
int64_t conv_out_dim_, kernel_dim_, col_dim_;
int64_t col_ofs_, output_ofs_;
int64_t w_ofs_, x_ofs_, y_ofs_;
int64_t in_channels_, out_channels_, out_dim_;
int64_t x_offset_, w_offset_, y_offset_;
DECLARE_ARGS_WITH_DESC(int64_t, output_shape);
DECLARE_ARGS_WITH_DESC(int64_t, output_padding);
......@@ -133,37 +128,42 @@ class ConvOpBase : public Operator<Context> {
LOG(FATAL) << "ConvNd has not been implemented.";
}
}
int64_t is_1x1_;
int64_t kernel_dim_, col_dim_;
int64_t col_offset_, out_offset_;
int64_t conv_in_channels_, conv_out_channels_, conv_out_dim_;
};
DEFINE_ARGS_WITH_DESC(int64_t, ConvOpBase, output_shape);
DEFINE_ARGS_WITH_DESC(int64_t, ConvOpBase, output_padding);
#define USE_CONVOLUTION_FUNCTIONS \
using ConvOpBase<Context>::Setup; \
using ConvOpBase<Context>::Reshape; \
using ConvOpBase<Context>::Transposed; \
using ConvOpBase<Context>::HasBias; \
using ConvOpBase<Context>::Wx; \
using ConvOpBase<Context>::Pb; \
using ConvOpBase<Context>::Dx; \
using ConvOpBase<Context>::Dw; \
using ConvOpBase<Context>::Db; \
using ConvOpBase<Context>::kshape_; \
using ConvOpBase<Context>::stride_; \
using ConvOpBase<Context>::pad_l_; \
using ConvOpBase<Context>::pad_r_; \
using ConvOpBase<Context>::dilation_; \
using ConvOpBase<Context>::group_; \
using ConvOpBase<Context>::channels_; \
using ConvOpBase<Context>::num_output_; \
using ConvOpBase<Context>::axis_; \
using ConvOpBase<Context>::num_axes_; \
using ConvOpBase<Context>::x_ofs_; \
using ConvOpBase<Context>::y_ofs_; \
using ConvOpBase<Context>::w_ofs_; \
using ConvOpBase<Context>::w_shape_; \
using ConvOpBase<Context>::b_shape_; \
using ConvOpBase<Context>::in_shape_; \
#define USE_CONVOLUTION_FUNCTIONS \
using ConvOpBase<Context>::Setup; \
using ConvOpBase<Context>::Reshape; \
using ConvOpBase<Context>::Transposed; \
using ConvOpBase<Context>::HasBias; \
using ConvOpBase<Context>::Wx; \
using ConvOpBase<Context>::Pb; \
using ConvOpBase<Context>::Dx; \
using ConvOpBase<Context>::Dw; \
using ConvOpBase<Context>::Db; \
using ConvOpBase<Context>::kshape_; \
using ConvOpBase<Context>::stride_; \
using ConvOpBase<Context>::pad_l_; \
using ConvOpBase<Context>::pad_r_; \
using ConvOpBase<Context>::dilation_; \
using ConvOpBase<Context>::group_; \
using ConvOpBase<Context>::in_channels_; \
using ConvOpBase<Context>::out_channels_; \
using ConvOpBase<Context>::axis_; \
using ConvOpBase<Context>::num_axes_; \
using ConvOpBase<Context>::x_offset_; \
using ConvOpBase<Context>::w_offset_; \
using ConvOpBase<Context>::y_offset_; \
using ConvOpBase<Context>::in_shape_; \
using ConvOpBase<Context>::w_shape_; \
using ConvOpBase<Context>::b_shape_; \
using ConvOpBase<Context>::out_shape_
} // namespace dragon
......
......@@ -10,14 +10,15 @@ template <typename T>
void DepthwiseConv2dOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0);
group_ = channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
CHECK_EQ(channels_, num_output_) << "\nExcepted in/out channels unchanged.";
group_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
ConvOpBase<Context>::Reshape();
CHECK_EQ(in_channels_, out_channels_)
<< "\nExcepted in/out channels to be same.";
TENSOR_FILL(W, w_shape_);
kernel::DepthwiseConv2d(
Input(0).dim(0),
channels_,
in_channels_,
in_shape_[0],
in_shape_[1],
out_shape_[0],
......@@ -54,13 +55,13 @@ void DepthwiseConv2dGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
group_ = channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
group_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
ConvOpBase<Context>::Reshape(true);
if (dX->has_name()) {
kernel::DepthwiseConv2dGrad(
X.dim(0),
channels_,
in_channels_,
in_shape_[0],
in_shape_[1],
out_shape_[0],
......@@ -83,7 +84,7 @@ void DepthwiseConv2dGradientOp<Context>::DoRunWithType() {
if (dW->has_name()) {
kernel::DepthwiseConv2dWGrad(
X.dim(0),
channels_,
in_channels_,
in_shape_[0],
in_shape_[1],
out_shape_[0],
......
......@@ -12,14 +12,15 @@ template <typename T>
void CuDNNDepthwiseConv2dOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0);
group_ = channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
CHECK_EQ(channels_, num_output_) << "\nExcepted in/out channels unchanged.";
group_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
ConvOpBase<Context>::Reshape();
CHECK_EQ(in_channels_, out_channels_)
<< "\nExcepted in/out channels to be same.";
TENSOR_FILL(W, w_shape_);
kernel::DepthwiseConv2d(
X.dim(0),
channels_,
in_channels_,
in_shape_[0],
in_shape_[1],
out_shape_[0],
......@@ -40,7 +41,7 @@ void CuDNNDepthwiseConv2dOp<Context>::DoRunWithType() {
if (HasBias()) {
TENSOR_FILL(Input(2), b_shape_);
CuDNNSetBiasDesc<T>(&bias_desc_, 4, num_output_, data_format());
CuDNNSetBiasDesc<T>(&bias_desc_, 4, out_channels_, data_format());
CuDNNSetTensorDesc<T>(&output_desc_, Y->dims(), data_format());
CUDNN_CHECK(cudnnAddTensor(
ctx()->cudnn_handle(),
......@@ -64,13 +65,13 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
group_ = channels_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
group_ = data_format() == "NCHW" ? X.dim(1) : X.dim(-1);
ConvOpBase<Context>::Reshape(true);
if (dX->has_name()) {
kernel::DepthwiseConv2dGrad(
X.dim(0),
channels_,
in_channels_,
in_shape_[0],
in_shape_[1],
out_shape_[0],
......@@ -93,7 +94,7 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::DoRunWithType() {
if (dW->has_name()) {
kernel::DepthwiseConv2dWGrad(
X.dim(0),
channels_,
in_channels_,
in_shape_[0],
in_shape_[1],
out_shape_[0],
......@@ -115,7 +116,7 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::DoRunWithType() {
if (dB->has_name()) {
CuDNNSetTensorDesc<T>(&input_desc_, Input(-1).dims(), data_format());
CuDNNSetBiasDesc<T>(&bias_desc_, 4, num_output_, data_format());
CuDNNSetBiasDesc<T>(&bias_desc_, 4, out_channels_, data_format());
CUDNN_CHECK(cudnnConvolutionBackwardBias(
ctx()->cudnn_handle(),
CuDNNType<T>::one,
......
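In the depthwise case the group count is tied to the channel count, so each input channel is convolved with its own k x k filter; following the weight shape computed in ConvOpBase::Reshape, the filter blob then has shape (out_channels, 1, kH, kW), e.g. (32, 1, 3, 3) for a 32-channel 3x3 depthwise convolution.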
......@@ -50,7 +50,7 @@ void SpaceToDepthOp<Context>::DoRunWithType() {
if (data_format() == "NCHW") {
for (int i = 0; i < num_axes; i++) {
perm.insert(perm.begin() + 1, perm.back());
perm.pop_back(); // CRD mode
perm.pop_back(); // DCR mode
}
}
......
......@@ -10,61 +10,65 @@ package dragon;
// Store the serialized Tensor objects.
message TensorProto {
repeated int32 dims = 1;
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1;
INT32 = 2;
BYTE = 3;
STRING = 4;
// Less-commonly used data types.
BOOL = 5;
UINT8 = 6;
INT8 = 7;
UINT16 = 8;
INT16 = 9;
INT64 = 10;
FLOAT16 = 12;
DOUBLE = 13;
}
optional DataType data_type = 2 [default = FLOAT];
// For float.
repeated float float_data = 3 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16
// Note about float16: in storage we will basically convert float16 byte-wise
// to unsigned short and then store them in the int32_data field.
repeated int32 int32_data = 4 [packed = true];
// For bytes.
optional bytes byte_data = 5;
// For strings.
repeated bytes string_data = 6;
// For double.
repeated double double_data = 9 [packed = true];
// For int64.
repeated int64 int64_data = 10 [packed = true];
// Store the raw data, contents are serialized as little-endian.
optional bytes raw_data = 13;
// Optionally, a name for the tensor.
optional string name = 7;
repeated int32 dims = 1;
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1;
INT32 = 2;
BYTE = 3;
STRING = 4;
// Less-commonly used data types.
BOOL = 5;
UINT8 = 6;
INT8 = 7;
UINT16 = 8;
INT16 = 9;
INT64 = 10;
FLOAT16 = 12;
DOUBLE = 13;
}
optional DataType data_type = 2 [default = FLOAT];
// For float.
repeated float float_data = 3 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16
// Note about float16: in storage we convert float16 byte-wise to unsigned short
// and store the values in the int32_data field.
repeated int32 int32_data = 4 [packed = true];
// For bytes.
optional bytes byte_data = 5;
// For strings.
repeated bytes string_data = 6;
// For double.
repeated double double_data = 9 [packed = true];
// For int64.
repeated int64 int64_data = 10 [packed = true];
// Store the raw data, contents are serialized as little-endian.
optional bytes raw_data = 13;
// Optionally, a name for the tensor.
optional string name = 7;
}
// Record the filler of a Tensor.
// This structure is kept for backward compatibility
// with caffe1, which relies on implicit initializers.
message TensorFillerProto {
optional string tensor = 1;
optional string type = 2 [default = 'constant'];
optional float value = 3 [default = 0];
optional float low = 4 [default = 0];
optional float high = 5 [default = 1];
optional float mean = 6 [default = 0];
optional float std = 7 [default = 1];
optional float scale = 8 [default = 3];
enum VarianceNorm { FAN_IN = 0; FAN_OUT = 1; FAN_AVG=2; }
optional VarianceNorm variance_norm = 9 [default = FAN_IN];
optional string tensor = 1;
optional string type = 2 [default = 'constant'];
optional float value = 3 [default = 0];
optional float low = 4 [default = 0];
optional float high = 5 [default = 1];
optional float mean = 6 [default = 0];
optional float std = 7 [default = 1];
optional float scale = 8 [default = 3];
enum VarianceNorm {
FAN_IN = 0;
FAN_OUT = 1;
FAN_AVG = 2;
}
optional VarianceNorm variance_norm = 9 [default = FAN_IN];
}
// Store multiple TensorProto objects in one single proto.
......@@ -74,99 +78,99 @@ message TensorProtos {
// DeviceType that Dragon currently supports.
enum DeviceTypeProto {
// The default device.
PROTO_CPU = 0;
// NVIDIA's CUDA Environment.
PROTO_CUDA = 1;
// CAMBRICON's CNML Environment.
PROTO_CNML = 2;
// The default device.
PROTO_CPU = 0;
// NVIDIA's CUDA Environment.
PROTO_CUDA = 1;
// CAMBRICON's CNML Environment.
PROTO_CNML = 2;
}
// Device-specific options.
message DeviceOption {
// The type of device to dispatch executions.
optional DeviceTypeProto device_type = 1 [default = PROTO_CPU];
// The index of this device.
optional int32 device_id = 2 [default = 0];
// The random seed to start the random generator.
optional uint32 random_seed = 3 [default = 3];
// The type of device to dispatch executions.
optional DeviceTypeProto device_type = 1 [default = PROTO_CPU];
// The index of this device.
optional int32 device_id = 2 [default = 0];
// The random seed to start the random generator.
optional uint32 random_seed = 3 [default = 3];
}
// A named argument containing either singular float, integer and string
// values, or repeated float, int and string arrays.
message Argument {
// The name of this argument.
optional string name = 1;
// Store the float32 value.
optional float f = 2;
// Store the bool, int32, int64 value.
optional int64 i = 3;
// Store the string value.
optional bytes s = 4;
// Store the float32 values.
repeated float floats = 7;
// Store the bool, int32, int64 values.
repeated int64 ints = 8;
// Store the string values.
repeated bytes strings = 9;
// The name of this argument.
optional string name = 1;
// Store the float32 value.
optional float f = 2;
// Store the bool, int32, int64 value.
optional int64 i = 3;
// Store the string value.
optional bytes s = 4;
// Store the float32 values.
repeated float floats = 7;
// Store the bool, int32, int64 values.
repeated int64 ints = 8;
// Store the string values.
repeated bytes strings = 9;
}
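A hedged sketch of filling the scalar and repeated slots by hand; in the framework this is typically done through `proto_util.make_argument`, as seen later in this diff, and the generated-module path is an assumption:

```python
from dragon.proto import dragon_pb2  # assumed generated module

axis = dragon_pb2.Argument(name='axis', i=1)                 # single int64 value
pads = dragon_pb2.Argument(name='pads', ints=[0, 0, 0, 0])   # repeated int64 values
```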
// Operator Definition
message OperatorDef {
// The name of inputs.
repeated string input = 1;
// The name of outputs.
repeated string output = 2;
// The optional name of this operator.
optional string name = 3;
// The operator type.
optional string type = 4;
// The arguments.
repeated Argument arg = 5;
// The device option that the operator should run under.
optional DeviceOption device_option = 6;
// The optional unique key for this operator.
// Set it to persist operators in the eager mode.
optional string cache_key = 7;
// The name of inputs.
repeated string input = 1;
// The name of outputs.
repeated string output = 2;
// The optional name of this operator.
optional string name = 3;
// The operator type.
optional string type = 4;
// The arguments.
repeated Argument arg = 5;
// The device option that the operator should run under.
optional DeviceOption device_option = 6;
// The optional unique key for this operator.
// Set it to persist operators in the eager mode.
optional string cache_key = 7;
}
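A hedged sketch of assembling an OperatorDef with the generated classes; the module path, the "Relu" type string, and the tensor names are placeholders, not taken from this commit:

```python
from dragon.proto import dragon_pb2  # assumed generated module

op = dragon_pb2.OperatorDef(
    name='relu_1',
    type='Relu',
    input=['x'],
    output=['y'],
    arg=[dragon_pb2.Argument(name='alpha', f=0.0)],
    device_option=dragon_pb2.DeviceOption(
        device_type=dragon_pb2.PROTO_CUDA, device_id=0),
)
```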
// Record the gradient information
message GradientProto {
// The derivative target.
optional string cost = 1;
// The target with respect to?
optional string wrt = 2;
// The external gradient
optional string external = 3;
// The derivative target.
optional string cost = 1;
// The tensor to differentiate with respect to.
optional string wrt = 2;
// The external gradient
optional string external = 3;
}
// Graph Definition
message GraphDef {
// The graph name.
optional string name = 1;
// The graph name.
optional string name = 1;
// The operators to execute.
repeated OperatorDef op = 2;
// The operators to execute.
repeated OperatorDef op = 2;
// The type of graph.
optional string graph_type = 3;
// The type of graph.
optional string graph_type = 3;
// The device option for this graph.
optional DeviceOption device_option = 5;
// The device option for this graph.
optional DeviceOption device_option = 5;
// The arguments.
repeated Argument arg = 6;
// The arguments.
repeated Argument arg = 6;
// The name of inputs.
repeated string input = 7;
// The name of outputs.
repeated string output = 8;
// The name of inputs.
repeated string input = 7;
// The name of outputs.
repeated string output = 8;
// The gradients information.
repeated GradientProto gradient = 9;
// The gradient information.
repeated GradientProto gradient = 9;
}
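A similar hedged sketch at the graph level, wiring one operator and one gradient record together; all names are placeholders and the generated-module path is an assumption:

```python
from dragon.proto import dragon_pb2  # assumed generated module

graph = dragon_pb2.GraphDef(
    name='demo_graph',
    op=[dragon_pb2.OperatorDef(type='Relu', input=['x'], output=['y'])],
    input=['x'],
    output=['y'],
    gradient=[dragon_pb2.GradientProto(cost='y', wrt='x')],
)
```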
......@@ -28,10 +28,10 @@ from dragon._api import losses
from dragon._api import math
from dragon._api import metrics
from dragon._api import nn
from dragon._api import optimizers
from dragon._api import random
from dragon._api import updaters
from dragon._api import vision
from dragon._api import workspace
from dragon._api import vision
# Virtual API
from dragon import vm
......@@ -56,7 +56,7 @@ from dragon.core.framework.context import name_scope
from dragon.core.framework.workspace import get_workspace
from dragon.core.framework.workspace import reset_workspace
from dragon.core.ops import tensorbind_eager as _
from dragon.core.ops import tensorbind_symbolic as _
from dragon.core.ops import tensorbind_symbol as _
from dragon.core.ops.array_ops import arange
from dragon.core.ops.array_ops import broadcast_to
from dragon.core.ops.array_ops import cast
......
......@@ -24,9 +24,9 @@ from dragon.core.ops.array_ops import min
from dragon.core.ops.array_ops import moments
from dragon.core.ops.array_ops import sum
from dragon.core.ops.math_ops import abs
from dragon.core.ops.math_ops import accumulate
from dragon.core.ops.math_ops import add
from dragon.core.ops.math_ops import affine
from dragon.core.ops.math_ops import axpby
from dragon.core.ops.math_ops import ceil
from dragon.core.ops.math_ops import clip
from dragon.core.ops.math_ops import cos
......@@ -45,7 +45,6 @@ from dragon.core.ops.math_ops import log
from dragon.core.ops.math_ops import matmul
from dragon.core.ops.math_ops import maximum
from dragon.core.ops.math_ops import minimum
from dragon.core.ops.math_ops import moving_average
from dragon.core.ops.math_ops import mul
from dragon.core.ops.math_ops import negative
from dragon.core.ops.math_ops import not_equal
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.core.training.adam import Adam
from dragon.core.training.optimizer import Optimizer
from dragon.core.training.rmsprop import RMSprop
from dragon.core.training.sgd import Nesterov
from dragon.core.training.sgd import SGD
__all__ = [_s for _s in dir() if not _s.startswith('_')]
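A hedged usage sketch of the new public `dragon.optimizers` namespace; only the class names come from the exports above, and the constructor keyword names are assumptions:

```python
import dragon

# Hypothetical keyword names; consult the optimizer classes for the real ones.
opt = dragon.optimizers.SGD(base_lr=0.01, momentum=0.9)
```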
......@@ -29,7 +29,7 @@ from dragon.core.eager import context as eager_context
from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import context
from dragon.core.framework import workspace
from dragon.core.training import updater
from dragon.core.training import optimizer
from dragon.core.util import decorator
from dragon.core.util import inspect
from dragon.core.util import nest
......@@ -265,7 +265,7 @@ class FunctionGuard(object):
dummies.append(obj)
executables = [function_lib.create_function(inputs, outputs)]
for obj in dummies:
if isinstance(obj, updater.Updater):
if isinstance(obj, optimizer.Optimizer):
executables.append(function_lib.create_function(updater=obj))
self.inputs = inputs
self.outputs = returns
......
......@@ -78,22 +78,22 @@ def add_phase(graph_def, targets):
graph_def.arg.extend([proto_util.make_argument('phase', phase)])
def add_update_ops(graph_def, updater):
def add_update_ops(graph_def, optimizer):
"""Add the update operators for graph."""
if updater is None:
if optimizer is None:
return
grads, update_ops = [], []
extra_arguments = updater._extra_kwargs
extra_arguments['slot'] = updater._slot
extra_arguments = optimizer._extra_kwargs
extra_arguments['handle'] = optimizer._op_handle
# Generate update operators according to the optimizer.
for e in updater._param_group:
for e in optimizer._param_group:
(param, grad), arguments = e
if workspace.has_tensor(grad):
grads.append(grad)
arguments = dict(arguments, **extra_arguments)
update_ops.append(
proto_util.make_operator_def(
op_type=updater._op_type,
op_type=optimizer._op_type,
inputs=[grad],
outputs=[param],
name=OpDef.get_name(),
......@@ -102,7 +102,7 @@ def add_update_ops(graph_def, updater):
else:
logging.info('Skip to update Tensor({}).'.format(param))
# Insert a reduce op if the process group is found.
process_group = updater._process_group
process_group = optimizer._process_group
if process_group is not None:
update_ops.insert(
0, proto_util.make_operator_def(
......@@ -139,12 +139,15 @@ class Function(object):
# Collect the forward operators.
requires_grad = False
for output in outputs:
for i, output in enumerate(outputs):
op_info.merge_from(output)
op_info.add_target(output.id)
if output._grad is not None and \
output._grad.required():
requires_grad = True
try:
grad_info = output._grad
if grad_info and grad_info.required():
requires_grad = True
except AttributeError:
raise ValueError('Output[%d] is not a symbolic tensor.' % i)
# Handle givens.
if givens is not None:
......@@ -169,23 +172,23 @@ class Function(object):
])
del op_def.input[:len(op_def.input) // 2]
# Sort out the topology of states.
# Sort out the states.
op_defs = sorted(op_info._defs.items(), key=lambda d: d[0])
forward_ops = copy.deepcopy([v for k, v in op_defs])
# Generate the backward operators.
if requires_grad:
input_grads = {}
input_grads, grad_targets = {}, []
for output in outputs:
if hasattr(output, '_grad'):
grad_info = output._grad
if grad_info is not None:
if grad_info.input is not None:
input_grads[output.id] = grad_info.input.id
grad_info = output._grad
if grad_info is not None:
if grad_info.input is not None:
input_grads[output.id] = output._grad.input.id
grad_targets.append(output.id)
forward_ops, gradient_ops, _ = \
grad_maker.GradientMaker.make(
forward_ops=forward_ops,
targets=list(op_info._targets),
targets=grad_targets,
input_grads=input_grads,
)
else:
......
......@@ -13,7 +13,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.autograph.tensor import RefTensor
from dragon.core.autograph.tensor import TensorRef
from dragon.core.eager import context
from dragon.core.util import nest
......@@ -120,19 +120,12 @@ def gradients(ys, xs, grad_ys=None):
if grad_ys is not None:
y._grad.set_input(grad_ys[i])
for x in xs:
if not hasattr(x, '_grad') or \
x._grad is None:
if not hasattr(x, '_grad') or x._grad is None:
x._grad = GradientInfo(x)
y._grad.add_wrt(x.id)
x._grad.add_cost(y)
if i == 0:
dxs.append(
RefTensor(
name=x.id + '_grad',
shape=x.shape,
dtype=x.dtype,
)
)
dxs.append(TensorRef(x.id + '_grad', x.shape, x.dtype))
# Return the packed gradients.
return dxs
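A hedged sketch of calling the reshaped helper above; the import path, the top-level `Tensor`/`math.abs` usage, and the graph-mode setup are assumptions, and only the list-in/list-out contract and the `<id>_grad` naming come from the code:

```python
import dragon
from dragon.core.autograph.grad_impl import gradients  # assumed module path

x = dragon.Tensor('x', shape=(2, 3), dtype='float32')  # symbolic input (assumed ctor use)
y = dragon.math.abs(x)                                  # any symbolic op output works here
dx, = gradients(ys=[y], xs=[x])                         # a TensorRef named '<x.id>_grad'
```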
......@@ -15,7 +15,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.autograph.tensor import RefTensor
from dragon.core.autograph.tensor import TensorRef
from dragon.core.autograph import op_spec
from dragon.core.framework import context
from dragon.core.framework import proto_util
......@@ -76,26 +76,24 @@ class OpDef(object):
outputs = []
name_scope = context.get_name_scope()
for i in range(num_outputs):
outputs.append(RefTensor(
outputs.append(TensorRef(
workspace.get_dummy_name(
name_scope + (name if name else op_type),
suffix=':{}'.format(i),
domain='Tensor'))
)
domain='Tensor')))
else:
outputs = nest.flatten(outputs)
num_outputs = len(outputs)
# Construct Def.
op_idx, op_name = OpDef.get_index_and_name()
op_info._defs[op_idx] = \
proto_util.make_operator_def(
name=op_name,
op_type=op_type,
inputs=[input.id for input in inputs],
outputs=[output.id for output in outputs],
device_option=proto_util.get_default_device_option(),
**kwargs)
op_info._defs[op_idx] = proto_util.make_operator_def(
name=op_name,
op_type=op_type,
inputs=[input.id for input in inputs],
outputs=[output.id for output in outputs],
device_option=proto_util.get_default_device_option(),
**kwargs)
# Blend the op for outputs.
for output in outputs:
......
......@@ -147,7 +147,7 @@ def cast_spec(args, inputs, outputs):
outputs[0].dtype = args['dtype']
try:
outputs[0].shape = inputs[0].shape[:]
except TypeError:
except (TypeError, IndexError):
pass
return outputs
......@@ -192,7 +192,10 @@ def conv_spec(args, inputs, outputs):
out_shape = inputs[0].shape[:]
channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
out_shape[channel_axis] = args['num_output']
if 'out_channels' in args:
out_shape[channel_axis] = args['out_channels']
else:
out_shape[channel_axis] = inputs[1].shape[0]
for i in range(len(out_shape) - 2):
input_size = out_shape[i + spatial_axis]
k = args['kernel_shape'][i]
......@@ -219,7 +222,10 @@ def conv_transpose_spec(args, inputs, outputs):
out_shape = inputs[0].shape[:]
channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
out_shape[channel_axis] = args['num_output']
if 'out_channels' in args:
out_shape[channel_axis] = args['out_channels']
else:
out_shape[channel_axis] = inputs[1].shape[1]
for i in range(len(out_shape) - 2):
k = args['kernel_shape'][i]
s = args['strides'][i]
......@@ -274,20 +280,16 @@ def depth_to_space_spec(args, inputs, outputs):
@register('Dot')
def dot_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype
ta, tb = args['transA'], args['transB']
try:
if len(inputs[0].shape) == 1:
a_shape, b_shape = inputs[0].shape[:], inputs[1].shape[:]
if len(a_shape) == 1 and len(b_shape) == 1:
outputs[0].shape = []
return outputs
except TypeError:
pass
try:
if len(inputs[0].shape) >= 2 and len(inputs[1].shape) in (1, 2):
out_shape = inputs[0].shape[1:] if ta else inputs[0].shape[:-1]
if len(inputs[1].shape) == 2:
out_shape.append(inputs[1].shape[0] if tb else inputs[1].shape[1])
outputs[0].shape = out_shape
return outputs
elif len(a_shape) == 2 and len(b_shape) == 2:
outputs[0].shape = [a_shape[0], b_shape[1]]
elif len(a_shape) == 0 and len(b_shape) == 0:
outputs[0].shape = []
elif len(a_shape) >= 2 and len(b_shape) == 1:
outputs[0].shape = a_shape[:-1]
except TypeError:
pass
return outputs
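Worked shape examples implied directly by the branches above:

```python
# Shape inference performed by dot_spec:
#   a.shape == [3],    b.shape == [3]     ->  out.shape == []      (vector dot)
#   a.shape == [2, 3], b.shape == [3, 4]  ->  out.shape == [2, 4]  (matrix product)
#   a.shape == [],     b.shape == []      ->  out.shape == []      (scalar product)
#   a.shape == [2, 3], b.shape == [3]     ->  out.shape == [2]     (matrix-vector)
```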
......@@ -298,6 +300,7 @@ def dot_spec(args, inputs, outputs):
'L1Loss',
'L2Loss',
'SigmoidCrossEntropy',
'SigmoidFocalLoss',
'SmoothL1Loss',
])
def eltwise_loss_spec(args, inputs, outputs):
......@@ -426,22 +429,22 @@ def flatten_spec(args, inputs, outputs):
@register('FullyConnected')
def fully_connected_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype
axis, num_output = args['axis'], args['num_output']
axis, out_channels = args['axis'], args.get('out_channels', None)
while axis < 0:
try:
axis += len(inputs[0].shape)
except TypeError:
return outputs
outputs[0].shape = [None] * (axis + 1)
if num_output is None:
if out_channels is None:
try:
if args['transW']:
num_output = inputs[1].shape[0]
out_channels = inputs[1].shape[0]
else:
num_output = inputs[1].shape[1]
out_channels = inputs[1].shape[1]
except (TypeError, IndexError):
num_output = None
outputs[0].shape[axis] = num_output
out_channels = None
outputs[0].shape[axis] = out_channels
try:
outputs[0].shape[:axis] = inputs[0].shape[:axis]
except TypeError:
......@@ -488,7 +491,7 @@ def index_select_spec(args, inputs, outputs):
return outputs
@register(['IsInf', 'InNaN'])
@register(['IsInf', 'IsNaN'])
def is_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = 'bool'
......@@ -507,7 +510,7 @@ def masked_select_spec(args, inputs, outputs):
return outputs
@register('Matmul')
@register('MatMul')
def matmul_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype
ta, tb = args['transA'], args['transB']
......@@ -758,7 +761,7 @@ def resize_spec(args, inputs, outputs):
@register(['RoiPool', 'RoiAlign'])
def roi_pool_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype
pool_h, pool_w = args['pool_h'], args['pool_w']
pool_h, pool_w = args['pooled_h'], args['pooled_w']
out_shape = None
try:
out_shape = inputs[0].shape[:]
......@@ -814,7 +817,6 @@ def slice_spec(args, inputs, outputs):
@register([
'NLLLoss',
'SigmoidFocalLoss',
'SoftmaxCrossEntropy',
'SparseSoftmaxCrossEntropy',
])
......
......@@ -420,7 +420,7 @@ class Tensor(types.TensorMetaclass):
The constant contains the value.
"""
return RefTensor('', dtype=dtype)._from_constant(value, name)
return Tensor('', dtype=dtype)._from_constant(value, name)
def _register_as(self, type, **kwargs):
"""Fill self with the specific type of filler."""
......@@ -463,13 +463,12 @@ class Tensor(types.TensorMetaclass):
"""Convert the value to a tensor."""
if not isinstance(value, numpy.ndarray):
value = numpy.array(value, self.dtype if self.dtype else 'float32')
return RefTensor(
return TensorRef(
name=workspace.get_dummy_name(
basename=context.get_name_scope() +
(name if name else 'Const'),
suffix=':0',
domain='Tensor'
),
domain='Tensor'),
shape=list(value.shape),
dtype=str(value.dtype),
).set_value(value)
......@@ -560,8 +559,8 @@ class Tensor(types.TensorMetaclass):
return self.__div__(other)
class RefTensor(object):
"""Create a reference tensor not involved with name scope."""
class TensorRef(object):
"""Create a reference not involved with name scope."""
def __new__(cls, name, shape=None, dtype=None):
tensor = Tensor('', shape=shape, dtype=dtype)
......
......@@ -9,7 +9,7 @@
#
# ------------------------------------------------------------
"""Some useful mappings are defined here."""
"""Constant mappings."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -104,7 +104,7 @@ class Operator(object):
"""Generate the OpDef from attributes."""
attributes = self.attributes()
self._def = proto_util.make_operator_cdef(
name='Generic',
name=attributes.get('name', 'GenericOp'),
cache_key=self._cache_key,
op_type=attributes['op_type'],
device_option=proto_util.get_device_option(
......
......@@ -134,10 +134,14 @@ def make_operator_cdef(
op_def = backend.OperatorDef()
op_def.ParseFrom(
make_operator_def(
op_type, inputs, outputs, name,
cache_key, device_option, arg, **kwargs
).SerializeToString()
)
op_type,
inputs,
outputs,
name,
cache_key,
device_option,
arg,
**kwargs).SerializeToString())
return op_def
......
......@@ -9,12 +9,7 @@
#
# ------------------------------------------------------------
"""Wrappers for the Workspace of C++ backend.
Flexible API is provided to manage the global resources
between the Python threads (quite different from C++).
"""
"""Generic interfaces of current default workspace."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -268,7 +268,7 @@ def leaky_relu(inputs, alpha=0.2, **kwargs):
@OpSchema.num_inputs(1)
def log_softmax(inputs, axis=1, **kwargs):
def log_softmax(inputs, axis=-1, **kwargs):
r"""Apply the composite of logarithm and softmax.
The **LogSoftmax** function is defined as:
......@@ -287,7 +287,7 @@ def log_softmax(inputs, axis=1, **kwargs):
----------
inputs : dragon.Tensor
The input tensor.
axis : int, optional, default=1
axis : int, optional, default=-1
The axis to reduce.
Returns
......@@ -351,7 +351,7 @@ def prelu(inputs, channel_shared=False, data_format='NCHW', **kwargs):
if context.executing_eagerly():
return op_lib \
.instantiate(data_format=data_format) \
.apply([inputs])
.apply(inputs)
else:
return op_lib.blend(**args)
......@@ -373,7 +373,7 @@ def relu(inputs, **kwargs):
Examples:
```python
x = dragon.constant([-1, 0, 1], 'float32')
x = dragon.constant([-1., 0., 1.])
print(dragon.nn.relu(x, inplace=False))
```
......@@ -449,10 +449,10 @@ def selu(inputs, alpha=1.67326, gamma=1.0507, **kwargs):
.. math::
\text{SELU}(x) = \gamma *
\begin{cases}
x, & \text{ if } x \geq 0 \\
\alpha * (e^{x} - 1), & \text{ otherwise }
\end{cases}
\begin{cases}
x, & \text{ if } x \geq 0 \\
\alpha * (e^{x} - 1), & \text{ otherwise }
\end{cases}
Examples:
......@@ -561,9 +561,8 @@ def softmax(inputs, axis=-1, **kwargs):
op_lib = activation_ops_lib.Softmax
if context.executing_eagerly():
return op_lib \
.instantiate(
axis=axis,
).apply([inputs], inplace=inplace)
.instantiate(axis=axis) \
.apply([inputs], inplace=inplace)
else:
return op_lib.blend(**args)
......
......@@ -64,11 +64,14 @@ def arange(start, stop=None, step=1, dtype='int64', **kwargs):
"""
args = parse_args(locals())
args['dtype'] = args['dtype'].lower()
op_lib = array_ops_lib.Arange
if stop is None:
args['slice'] = (start, step)
args['slice'] = (float(start), float(step))
else:
args['slice'] = (start, stop, step)
args['slice'] = (float(start), float(stop), float(step))
args.pop('start')
args.pop('stop')
args.pop('step')
op_lib = array_ops_lib.Arange
trainable = args.pop('trainable') if 'trainable' in args else False
if context.executing_eagerly():
return op_lib.instantiate(
......@@ -269,6 +272,8 @@ def cast(inputs, dtype, **kwargs):
.instantiate(dtype=dtype) \
.apply([inputs], inplace=inplace)
else:
if inputs.dtype == dtype:
return inputs
if inplace:
args['inputs'], args['outputs'] = [], [inputs]
return op_lib.blend(**args)
......@@ -627,16 +632,14 @@ def index_select(inputs, indices, axis=0, **kwargs):
return op_lib.blend(**args)
@OpSchema.num_inputs(1)
def masked_select(inputs, mask, **kwargs):
@OpSchema.num_inputs(2)
def masked_select(inputs, **kwargs):
"""Select the elements where the given mask is **1**.
Parameters
----------
inputs : dragon.Tensor
The input tensor.
mask : dragon.Tensor
The mask, with the same size as ``inputs``.
inputs : Sequence[dragon.Tensor]
The input and mask tensors.
Returns
-------
......@@ -647,9 +650,8 @@ def masked_select(inputs, mask, **kwargs):
args = parse_args(locals())
op_lib = array_ops_lib.MaskedSelect
if context.executing_eagerly():
return op_lib.instantiate().apply([inputs, mask])
return op_lib.instantiate().apply(inputs)
else:
args['inputs'] = [args['inputs'], args.pop('mask')]
return op_lib.blend(**args)
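A hedged usage sketch after the signature change above, passing the data and the mask together through `inputs`; whether `masked_select` and a `'bool'` dtype are exposed at the top level exactly like this is an assumption:

```python
import dragon

x = dragon.constant([[1, 2], [3, 4]], 'int64')
mask = dragon.constant([[1, 0], [0, 1]], 'bool')  # dtype name is an assumption
y = dragon.masked_select([x, mask])               # 1-D result: [1, 4]
```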
......@@ -1047,7 +1049,7 @@ def pad(inputs, pads, mode='constant', value=0, **kwargs):
.instantiate(
ndim=len(pads_begin),
value=args['value'],
mode=mode,
mode=args['mode'],
).apply([inputs], args['pads'])
else:
return op_lib.blend(**args)
......@@ -1278,7 +1280,9 @@ def split(
size_splits = None
if slice_points is not None:
if len(slice_points) + 1 != num_splits:
raise ValueError('Excepted %d values for <slice_points>.')
raise ValueError(
'Expected %d values for <slice_points>.'
% (num_splits - 1))
if context.executing_eagerly():
return op_lib \
.instantiate(
......
......@@ -61,38 +61,36 @@ def assign(inputs, starts=None, sizes=None, **kwargs):
@OpSchema.num_inputs(1, 2)
def copy(inputs, **kwargs):
r"""Copy the value to ref.
.. math:: \text{Ref}[:] = \text{Value}[:]
"""Copy the input.
Examples:
```python
# Copy the content from ``x`` to ``xx``
# Copy ``x`` to ``y``
x = dragon.ones(shape=(2, 3))
xx = dragon.zeros(shape=(2, 4))
dragon.copy([xx, x])
y = dragon.zeros(shape=(2, 4))
dragon.copy([x, y])
# Create a new tensor initialized from ``x``
xxx = dragon.copy(x)
# Copy to a new tensor from ``x``
y = dragon.copy(x)
```
Parameters
----------
inputs : Sequence[dragon.Tensor]
The **ref** and **value**.
inputs : Union[dragon.Tensor, Sequence[dragon.Tensor]]
The input tensor.
Returns
-------
dragon.Tensor
The **ref**.
The output tensor.
"""
args = parse_args(locals())
inputs = nest.flatten(inputs)
if len(inputs) == 2:
args['inputs'] = [inputs[1]]
args['outputs'] = [inputs[0]]
args['inputs'] = nest.flatten(inputs)
if len(args['inputs']) == 2:
args['outputs'] = [args['inputs'][1]]
args['inputs'] = [args['inputs'][0]]
else:
args['outputs'] = None
op_lib = control_flow_ops_lib.Copy
......@@ -104,8 +102,8 @@ def copy(inputs, **kwargs):
return op_lib.blend('Copy', **args)
@OpSchema.num_inputs(2)
def masked_assign(inputs, mask, **kwargs):
@OpSchema.num_inputs(3)
def masked_assign(inputs, **kwargs):
r"""Assign the value to ref where mask is **1**.
.. math::
......@@ -118,24 +116,22 @@ def masked_assign(inputs, mask, **kwargs):
Parameters
----------
inputs : Sequence[dragon.Tensor]
The **ref** and **value**.
mask : dragon.Tensor
The mask, with the same size as **ref**.
The **ref**, **value** and **mask** tensors.
Returns
-------
dragon.Tensor
The **ref**.
The **ref** tensor.
"""
args = parse_args(locals())
inputs[1] = ops.scalar_to_tensor(inputs[1], inputs[0].dtype)
op_lib = control_flow_ops_lib.MaskedAssign
if context.executing_eagerly():
return op_lib.instantiate().apply(inputs, mask)
return op_lib.instantiate().apply(inputs)
else:
args.update({
'outputs': [args['inputs'][0]],
'inputs': [args['inputs'][1], mask],
'inputs': args['inputs'][1:],
})
return op_lib.blend(**args)
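A hedged usage sketch after the change above: the ref, value, and mask now all travel through `inputs`; whether `masked_assign` is exposed at the top level exactly like this, and the `'bool'` mask dtype, are assumptions:

```python
import dragon

x = dragon.zeros(shape=(2, 3))                           # the ref tensor to update
mask = dragon.constant([[1, 0, 1], [0, 1, 0]], 'bool')   # dtype name is an assumption
dragon.masked_assign([x, 5, mask])                       # scalar 5 is promoted to x.dtype
```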
......@@ -47,7 +47,7 @@ class Assign(Operator):
sizes[i], 'int64',
)
def forward(self, ws, inputs, starts, sizes):
def forward(self, inputs, starts, sizes):
return self.dispatch(
[inputs[1]], [inputs[0]],
callback=lambda ws, handle:
......@@ -75,5 +75,5 @@ class MaskedAssign(Operator):
def attributes(self):
return {'op_type': 'MaskedAssign', 'arguments': {}}
def forward(self, inputs, mask):
return self.dispatch([inputs[1], mask], [inputs[0]], no_grad=True)
def forward(self, inputs):
return self.dispatch(inputs[1:], [inputs[0]], no_grad=True)
......@@ -88,9 +88,7 @@ def l1_loss(inputs, reduction='mean', **kwargs):
op_lib = loss_ops_lib.L1Loss
if context.executing_eagerly():
return op_lib \
.instantiate(
reduction=args['reduction'],
).apply(inputs)
.instantiate(reduction=args['reduction']).apply(inputs)
else:
return op_lib.blend(**args)
......
......@@ -16,49 +16,49 @@ from __future__ import print_function
from dragon.core.framework.ops import Operator
class Accumulate(Operator):
class Affine(Operator):
def __init__(self, key, dev, **kwargs):
super(Accumulate, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.)
self.beta = kwargs.get('beta', 1.)
super(Affine, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
self.num_axes = kwargs.get('num_axes', 1)
def attributes(self):
return {
'op_type': 'Accumulate',
'op_type': 'Affine',
'arguments': {
'alpha': self.alpha,
'beta': self.beta,
'axis': self.axis,
'num_axes': self.num_axes,
}
}
def forward(self, inputs, outputs=None):
if outputs is None:
outputs = [self.alloc() for _ in range(len(inputs))]
return self.dispatch(inputs, outputs, no_grad=True)
def forward(self, inputs):
return self.dispatch(inputs, [self.alloc()])
class Affine(Operator):
class Axpby(Operator):
def __init__(self, key, dev, **kwargs):
super(Affine, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
self.num_axes = kwargs.get('num_axes', 1)
super(Axpby, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.)
self.beta = kwargs.get('beta', 1.)
def attributes(self):
return {
'op_type': 'Affine',
'op_type': 'Axpby',
'arguments': {
'axis': self.axis,
'num_axes': self.num_axes,
'alpha': self.alpha,
'beta': self.beta,
}
}
def forward(self, inputs):
return self.dispatch(inputs, [self.alloc()])
def forward(self, inputs, outputs=None):
if outputs is None:
outputs = [self.alloc() for _ in range(len(inputs))]
return self.dispatch(inputs, outputs, no_grad=True)
class Binary(Operator):
class BinaryOp(Operator):
def __init__(self, key, dev, **kwargs):
super(Binary, self).__init__(key, dev, **kwargs)
super(BinaryOp, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
def attributes(self):
......@@ -95,37 +95,18 @@ class Clip(Operator):
return self.dispatch(inputs, [self.alloc()])
class Dot(Operator):
def __init__(self, key, dev, **kwargs):
super(Dot, self).__init__(key, dev, **kwargs)
self.transA = kwargs.get('transA', False)
self.transB = kwargs.get('transB', False)
def attributes(self):
return {
'op_type': 'Dot',
'arguments': {
'transA': self.transA,
'transB': self.transB,
}
}
def forward(self, inputs):
return self.dispatch(inputs, [self.alloc()])
class FullyConnected(Operator):
def __init__(self, key, dev, **kwargs):
super(FullyConnected, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
self.transW = kwargs.get('transW', True)
self.transpose_w = kwargs.get('transpose_w', True)
def attributes(self):
return {
'op_type': 'FullyConnected',
'arguments': {
'axis': self.axis,
'transW': self.transW,
'transW': self.transpose_w,
}
}
......@@ -133,18 +114,18 @@ class FullyConnected(Operator):
return self.dispatch(inputs, [self.alloc()])
class Matmul(Operator):
class MatMul(Operator):
def __init__(self, key, dev, **kwargs):
super(Matmul, self).__init__(key, dev, **kwargs)
self.transA = kwargs.get('transA', False)
self.transB = kwargs.get('transB', False)
super(MatMul, self).__init__(key, dev, **kwargs)
self.transpose_a = kwargs.get('transpose_a', False)
self.transpose_b = kwargs.get('transpose_b', False)
def attributes(self):
return {
'op_type': 'Matmul',
'op_type': 'MatMul',
'arguments': {
'transA': self.transA,
'transB': self.transB,
'transA': self.transpose_a,
'transB': self.transpose_b,
}
}
......@@ -152,9 +133,9 @@ class Matmul(Operator):
return self.dispatch(inputs, [self.alloc()])
class Unary(Operator):
class UnaryOp(Operator):
def __init__(self, key, dev, **kwargs):
super(Unary, self).__init__(key, dev, **kwargs)
super(UnaryOp, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
def attributes(self):
......
This diff could not be displayed because it is too large.