Commit 494774d3 by Ting PAN

Optimize training update operators

Summary:
This commit fuses the weight decay and mixed precision conversion
into update kernels to get lower training latency.
1 parent fb47d86f
Showing with 1659 additions and 1282 deletions
...@@ -418,9 +418,9 @@ class Normalize(Layer): ...@@ -418,9 +418,9 @@ class Normalize(Layer):
def __call__(self, bottom): def __call__(self, bottom):
if len(self.blobs) == 0: if len(self.blobs) == 0:
self.build(bottom) self.build(bottom)
outputs = [normalization_ops.lp_normalize(bottom, **self.norm_args)] outputs = [normalization_ops.lp_norm(bottom, **self.norm_args)]
outputs += [blob['data'] for blob in self.blobs] outputs += [blob['data'] for blob in self.blobs]
return array_ops.channel_affine(outputs, **self.scale_args) return math_ops.affine(outputs, **self.scale_args)
class Permute(Layer): class Permute(Layer):
...@@ -591,8 +591,7 @@ class Scale(Layer): ...@@ -591,8 +591,7 @@ class Scale(Layer):
param = layer_param.scale_param param = layer_param.scale_param
self.axis = param.axis self.axis = param.axis
self.num_axes = param.num_axes self.num_axes = param.num_axes
end_axis = -1 if self.num_axes < 1 else self.axis + self.num_axes - 1 self.call_args = {'axis': list(range(self.axis, self.axis + self.num_axes))}
self.call_args = {'axis': self.axis, 'end_axis': end_axis}
self.filler = caffe_pb2.FillerParameter(type='constant', value=1) self.filler = caffe_pb2.FillerParameter(type='constant', value=1)
self.filler = param.filler if param.HasField('filler') else self.filler self.filler = param.filler if param.HasField('filler') else self.filler
self.bias_filler = param.bias_filler self.bias_filler = param.bias_filler
...@@ -609,7 +608,7 @@ class Scale(Layer): ...@@ -609,7 +608,7 @@ class Scale(Layer):
if len(self.blobs) == 0: if len(self.blobs) == 0:
self.build(bottom) self.build(bottom)
inputs = [bottom] + [blob['data'] for blob in self.blobs] inputs = [bottom] + [blob['data'] for blob in self.blobs]
return array_ops.channel_affine(inputs, **self.call_args) return math_ops.affine(inputs, **self.call_args)
class Slice(Layer): class Slice(Layer):
......
...@@ -16,8 +16,8 @@ from __future__ import print_function ...@@ -16,8 +16,8 @@ from __future__ import print_function
from dragon.core.framework import workspace from dragon.core.framework import workspace
from dragon.core.io.kpl_record import KPLRecordDataset from dragon.core.io.kpl_record import KPLRecordDataset
from dragon.core.ops import array_ops
from dragon.core.ops import framework_ops from dragon.core.ops import framework_ops
from dragon.core.ops import normalization_ops
from dragon.utils import vision from dragon.utils import vision
from dragon.vm.caffe.core.layer import Layer from dragon.vm.caffe.core.layer import Layer
...@@ -121,5 +121,5 @@ class Data(Layer): ...@@ -121,5 +121,5 @@ class Data(Layer):
data._shape = (self.data_args['batch_size'], data._shape = (self.data_args['batch_size'],
None, None, len(self.norm_args['mean'])) None, None, len(self.norm_args['mean']))
label._shape = (self.data_args['batch_size'], None) label._shape = (self.data_args['batch_size'], None)
data = array_ops.channel_normalize(data, **self.norm_args) data = normalization_ops.channel_norm(data, **self.norm_args)
return data, label return data, label
...@@ -9,6 +9,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) ...@@ -9,6 +9,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# ---[ Compiler flags # ---[ Compiler flags
if (MSVC) if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_ENABLE_EXTENDED_ALIGNED_STORAGE")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}
/wd4003 /wd4114 /wd4003 /wd4114
......
...@@ -36,16 +36,6 @@ dragon ...@@ -36,16 +36,6 @@ dragon
`cast(...) <dragon/cast.html>`_ `cast(...) <dragon/cast.html>`_
: Cast the data type of input. : Cast the data type of input.
`channel_affine(...) <dragon/channel_affine.html>`_
: Apply affine transformation to each channel of input.
`channel_normalize(...) <dragon/channel_normalize.html>`_
: Apply normalization to each channel of input.
`channel_shuffle(...) <dragon/channel_shuffle.html>`_
: Apply group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`concat(...) <dragon/concat.html>`_ `concat(...) <dragon/concat.html>`_
: Concatenate the inputs along the given axis. : Concatenate the inputs along the given axis.
...@@ -211,9 +201,6 @@ dragon ...@@ -211,9 +201,6 @@ dragon
dragon/boolean_mask dragon/boolean_mask
dragon/broadcast_to dragon/broadcast_to
dragon/cast dragon/cast
dragon/channel_affine
dragon/channel_normalize
dragon/channel_shuffle
dragon/concat dragon/concat
dragon/constant dragon/constant
dragon/device dragon/device
......
...@@ -24,6 +24,9 @@ dragon.cuda ...@@ -24,6 +24,9 @@ dragon.cuda
`memory_allocated(...) <cuda/memory_allocated.html>`_ `memory_allocated(...) <cuda/memory_allocated.html>`_
: Return the size of memory used by tensors in current workspace. : Return the size of memory used by tensors in current workspace.
`set_cublas_flags(...) <cuda/set_cublas_flags.html>`_
: Set the flags of cuBLAS library.
`set_cudnn_flags(...) <cuda/set_cudnn_flags.html>`_ `set_cudnn_flags(...) <cuda/set_cudnn_flags.html>`_
: Set the flags of cuDNN library. : Set the flags of cuDNN library.
...@@ -44,6 +47,7 @@ dragon.cuda ...@@ -44,6 +47,7 @@ dragon.cuda
cuda/get_device_capability cuda/get_device_capability
cuda/is_available cuda/is_available
cuda/memory_allocated cuda/memory_allocated
cuda/set_cublas_flags
cuda/set_cudnn_flags cuda/set_cudnn_flags
cuda/set_default_device cuda/set_default_device
cuda/set_device cuda/set_device
......
channel_affine set_cublas_flags
============== ================
.. autofunction:: dragon.channel_affine .. autofunction:: dragon.cuda.set_cublas_flags
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon."; content: "dragon.cuda.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -12,12 +12,18 @@ dragon.math ...@@ -12,12 +12,18 @@ dragon.math
`add(...) <math/add.html>`_ `add(...) <math/add.html>`_
: Compute the element-wise addition. : Compute the element-wise addition.
`affine(...) <math/affine.html>`_
: Apply the affine transformation to input.
`argmax(...) <math/argmax.html>`_ `argmax(...) <math/argmax.html>`_
: Compute the index of maximum elements along the given axis. : Compute the index of maximum elements along the given axis.
`argmin(...) <math/argmin.html>`_ `argmin(...) <math/argmin.html>`_
: Compute the index of minimum elements along the given axis. : Compute the index of minimum elements along the given axis.
`atan2(...) <math/atan2.html>`_
: Compute the element-wise arc-tangent of two arguments.
`ceil(...) <math/ceil.html>`_ `ceil(...) <math/ceil.html>`_
: Compute the smallest integer not less than input. : Compute the smallest integer not less than input.
...@@ -81,9 +87,6 @@ dragon.math ...@@ -81,9 +87,6 @@ dragon.math
`logical_xor(...) <math/logical_xor.html>`_ `logical_xor(...) <math/logical_xor.html>`_
: Compute the element-wise XOR logical operation. : Compute the element-wise XOR logical operation.
`lp_normalize(...) <math/lp_normalize.html>`_
: Apply the lp normalization.
`matmul(...) <math/matmul.html>`_ `matmul(...) <math/matmul.html>`_
: Compute the matrix multiplication. : Compute the matrix multiplication.
...@@ -158,8 +161,10 @@ dragon.math ...@@ -158,8 +161,10 @@ dragon.math
math/abs math/abs
math/add math/add
math/affine
math/argmax math/argmax
math/argmin math/argmin
math/atan2
math/ceil math/ceil
math/clip math/clip
math/cos math/cos
...@@ -181,7 +186,6 @@ dragon.math ...@@ -181,7 +186,6 @@ dragon.math
math/logical_not math/logical_not
math/logical_or math/logical_or
math/logical_xor math/logical_xor
math/lp_normalize
math/matmul math/matmul
math/max math/max
math/maximum math/maximum
......
lp_normalize affine
============ ======
.. autofunction:: dragon.math.lp_normalize .. autofunction:: dragon.math.affine
.. raw:: html .. raw:: html
......
channel_normalize atan2
================= =====
.. autofunction:: dragon.channel_normalize .. autofunction:: dragon.math.atan2
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon."; content: "dragon.math.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -28,6 +28,13 @@ dragon.nn ...@@ -28,6 +28,13 @@ dragon.nn
`bias_add(...) <nn/bias_add.html>`_ `bias_add(...) <nn/bias_add.html>`_
: Add the bias across channels to input. : Add the bias across channels to input.
`channel_norm(...) <nn/channel_norm.html>`_
: Apply the normalization to each channel of input.
`channel_shuffle(...) <nn/channel_shuffle.html>`_
: Apply the group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`conv(...) <nn/conv.html>`_ `conv(...) <nn/conv.html>`_
: Apply the n-dimension convolution. : Apply the n-dimension convolution.
...@@ -107,6 +114,9 @@ dragon.nn ...@@ -107,6 +114,9 @@ dragon.nn
`log_softmax(...) <nn/log_softmax.html>`_ `log_softmax(...) <nn/log_softmax.html>`_
: Compute the composite of logarithm and softmax. : Compute the composite of logarithm and softmax.
`lp_norm(...) <nn/lp_norm.html>`_
: Apply the lp normalization.
`moments(...) <nn/moments.html>`_ `moments(...) <nn/moments.html>`_
: Compute the mean and variance of input along the given axis. : Compute the mean and variance of input along the given axis.
...@@ -157,6 +167,8 @@ dragon.nn ...@@ -157,6 +167,8 @@ dragon.nn
nn/RNN nn/RNN
nn/batch_norm nn/batch_norm
nn/bias_add nn/bias_add
nn/channel_norm
nn/channel_shuffle
nn/conv nn/conv
nn/conv_transpose nn/conv_transpose
nn/conv1d nn/conv1d
...@@ -180,6 +192,7 @@ dragon.nn ...@@ -180,6 +192,7 @@ dragon.nn
nn/leaky_relu nn/leaky_relu
nn/local_response_norm nn/local_response_norm
nn/log_softmax nn/log_softmax
nn/lp_norm
nn/moments nn/moments
nn/pool nn/pool
nn/pool1d nn/pool1d
......
channel_norm
============
.. autofunction:: dragon.nn.channel_norm
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
channel_shuffle channel_shuffle
=============== ===============
.. autofunction:: dragon.channel_shuffle .. autofunction:: dragon.nn.channel_shuffle
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon."; content: "dragon.nn.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
lp_norm
=======
.. autofunction:: dragon.nn.lp_norm
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
...@@ -21,6 +21,9 @@ vm.tensorflow.math ...@@ -21,6 +21,9 @@ vm.tensorflow.math
`argmin(...) <math/argmin.html>`_ `argmin(...) <math/argmin.html>`_
: Compute the index of minimum elements along the given axis. : Compute the index of minimum elements along the given axis.
`atan2(...) <math/atan2.html>`_
: Compute the element-wise arc-tangent of two arguments.
`ceil(...) <math/ceil.html>`_ `ceil(...) <math/ceil.html>`_
: Compute the smallest integer not less than input. : Compute the smallest integer not less than input.
...@@ -134,6 +137,7 @@ vm.tensorflow.math ...@@ -134,6 +137,7 @@ vm.tensorflow.math
math/add_n math/add_n
math/argmax math/argmax
math/argmin math/argmin
math/atan2
math/ceil math/ceil
math/cos math/cos
math/cumsum math/cumsum
......
channel_normalize atan2
================= =====
.. autofunction:: dragon.vm.torch.channel_normalize .. autofunction:: dragon.vm.tensorflow.math.atan2
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "torch."; content: "tf.math.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -51,6 +51,9 @@ vm.torch ...@@ -51,6 +51,9 @@ vm.torch
`argsort(...) <torch/argsort.html>`_ `argsort(...) <torch/argsort.html>`_
: Return the index of sorted elements along the given dimension. : Return the index of sorted elements along the given dimension.
`atan2(...) <torch/atan2.html>`_
: Compute the element-wise arc-tangent of two arguments.
`baddbmm(...) <torch/baddbmm.html>`_ `baddbmm(...) <torch/baddbmm.html>`_
: Add input to the result of batched matrix-matrix multiplication. : Add input to the result of batched matrix-matrix multiplication.
...@@ -75,12 +78,6 @@ vm.torch ...@@ -75,12 +78,6 @@ vm.torch
`ceil(...) <torch/ceil.html>`_ `ceil(...) <torch/ceil.html>`_
: Compute the smallest integer not less than input. : Compute the smallest integer not less than input.
`channel_affine(...) <torch/channel_affine.html>`_
: Apply affine transformation to each channel of input.
`channel_normalize(...) <torch/channel_normalize.html>`_
: Apply normalization to each channel of input.
`chunk(...) <torch/chunk.html>`_ `chunk(...) <torch/chunk.html>`_
: Split input into a specific number of chunks. : Split input into a specific number of chunks.
...@@ -345,6 +342,7 @@ vm.torch ...@@ -345,6 +342,7 @@ vm.torch
torch/argmax torch/argmax
torch/argmin torch/argmin
torch/argsort torch/argsort
torch/atan2
torch/baddbmm torch/baddbmm
torch/bitwise_and torch/bitwise_and
torch/bitwise_not torch/bitwise_not
...@@ -353,8 +351,6 @@ vm.torch ...@@ -353,8 +351,6 @@ vm.torch
torch/bmm torch/bmm
torch/cat torch/cat
torch/ceil torch/ceil
torch/channel_affine
torch/channel_normalize
torch/chunk torch/chunk
torch/clamp torch/clamp
torch/cos torch/cos
......
...@@ -73,6 +73,10 @@ argsort ...@@ -73,6 +73,10 @@ argsort
####### #######
.. automethod:: dragon.vm.torch.Tensor.argsort .. automethod:: dragon.vm.torch.Tensor.argsort
atan2
#####
.. automethod:: dragon.vm.torch.Tensor.atan2
backward backward
######## ########
.. automethod:: dragon.vm.torch.Tensor.backward .. automethod:: dragon.vm.torch.Tensor.backward
...@@ -699,6 +703,7 @@ zero\_ ...@@ -699,6 +703,7 @@ zero\_
.. _torch.argmax(...): argmax.html .. _torch.argmax(...): argmax.html
.. _torch.argmin(...): argmin.html .. _torch.argmin(...): argmin.html
.. _torch.argsort(...): argsort.html .. _torch.argsort(...): argsort.html
.. _torch.atan2(...): atan2.html
.. _torch.baddbmm(...): baddbmm.html .. _torch.baddbmm(...): baddbmm.html
.. _torch.bitwise_and(...): bitwise_and.html .. _torch.bitwise_and(...): bitwise_and.html
.. _torch.bitwise_not(...): bitwise_not.html .. _torch.bitwise_not(...): bitwise_not.html
......
channel_affine atan2
============== =====
.. autofunction:: dragon.vm.torch.channel_affine .. autofunction:: dragon.vm.torch.atan2
.. raw:: html .. raw:: html
......
...@@ -6,12 +6,16 @@ vm.torch.backends ...@@ -6,12 +6,16 @@ vm.torch.backends
Modules Modules
------- -------
`Module cuda <backends/cuda.html>`_
: The CUDA backend module.
`Module cudnn <backends/cudnn.html>`_ `Module cudnn <backends/cudnn.html>`_
: The cuDNN backend module. : The cuDNN backend module.
.. toctree:: .. toctree::
:hidden: :hidden:
backends/cuda
backends/cudnn backends/cudnn
.. raw:: html .. raw:: html
......
cuda
====
Properties
----------
matmul.allow_tf32
#################
.. data:: dragon.vm.torch.backends.cuda.matmul.allow_tf32
:annotation: = False
The flag that allows TF32 math type for matmul or not.
Functions
---------
is_built
########
.. automethod:: dragon.vm.torch.backends.cuda.is_built
.. raw:: html
<style>
h1:before {
content: "torch.backends.";
color: #103d3e;
}
</style>
...@@ -24,8 +24,8 @@ vm.torch.nn ...@@ -24,8 +24,8 @@ vm.torch.nn
`class AdaptiveMaxPool3d <nn/AdaptiveMaxPool3d.html>`_ `class AdaptiveMaxPool3d <nn/AdaptiveMaxPool3d.html>`_
: Apply the 3d adaptive max pooling. : Apply the 3d adaptive max pooling.
`class AffineChannel <nn/AffineChannel.html>`_ `class Affine <nn/Affine.html>`_
: Apply affine transformation along the channels. : Apply the affine transformation.
`class AvgPool1d <nn/AvgPool1d.html>`_ `class AvgPool1d <nn/AvgPool1d.html>`_
: Apply the 1d average pooling. : Apply the 1d average pooling.
...@@ -312,7 +312,7 @@ vm.torch.nn ...@@ -312,7 +312,7 @@ vm.torch.nn
nn/AdaptiveMaxPool1d nn/AdaptiveMaxPool1d
nn/AdaptiveMaxPool2d nn/AdaptiveMaxPool2d
nn/AdaptiveMaxPool3d nn/AdaptiveMaxPool3d
nn/AffineChannel nn/Affine
nn/AvgPool1d nn/AvgPool1d
nn/AvgPool2d nn/AvgPool2d
nn/AvgPool3d nn/AvgPool3d
......
AffineChannel Affine
============= ======
.. autoclass:: dragon.vm.torch.nn.AffineChannel .. autoclass:: dragon.vm.torch.nn.Affine
__init__ __init__
-------- --------
.. automethod:: dragon.vm.torch.nn.AffineChannel.__init__ .. automethod:: dragon.vm.torch.nn.Affine.__init__
.. _torch.channel_affine(...): ../channel_affine.html .. _torch.nn.functional.affine(...): functional/affine.html
.. raw:: html .. raw:: html
......
...@@ -24,6 +24,9 @@ vm.torch.nn.functional ...@@ -24,6 +24,9 @@ vm.torch.nn.functional
`adaptive_max_pool3d(...) <functional/adaptive_max_pool3d.html>`_ `adaptive_max_pool3d(...) <functional/adaptive_max_pool3d.html>`_
: Apply the 3d adaptive max pooling to input. : Apply the 3d adaptive max pooling to input.
`affine(...) <functional/affine.html>`_
: Apply the affine transformation to input.
`avg_pool1d(...) <functional/avg_pool1d.html>`_ `avg_pool1d(...) <functional/avg_pool1d.html>`_
: Apply the 1d average pooling to input. : Apply the 1d average pooling to input.
...@@ -40,8 +43,11 @@ vm.torch.nn.functional ...@@ -40,8 +43,11 @@ vm.torch.nn.functional
`binary_cross_entropy_with_logits(...) <functional/binary_cross_entropy_with_logits.html>`_ `binary_cross_entropy_with_logits(...) <functional/binary_cross_entropy_with_logits.html>`_
: Compute the sigmoid cross entropy with contiguous target. : Compute the sigmoid cross entropy with contiguous target.
`channel_norm(...) <nn/channel_norm.html>`_
: Apply the normalization to each channel of input.
`channel_shuffle(...) <functional/channel_shuffle.html>`_ `channel_shuffle(...) <functional/channel_shuffle.html>`_
: Apply group shuffle to each channel of input. : Apply the group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_. `[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`conv1d(...) <functional/conv1d.html>`_ `conv1d(...) <functional/conv1d.html>`_
...@@ -229,11 +235,13 @@ vm.torch.nn.functional ...@@ -229,11 +235,13 @@ vm.torch.nn.functional
functional/adaptive_max_pool1d functional/adaptive_max_pool1d
functional/adaptive_max_pool2d functional/adaptive_max_pool2d
functional/adaptive_max_pool3d functional/adaptive_max_pool3d
functional/affine
functional/avg_pool1d functional/avg_pool1d
functional/avg_pool2d functional/avg_pool2d
functional/avg_pool3d functional/avg_pool3d
functional/batch_norm functional/batch_norm
functional/binary_cross_entropy_with_logits functional/binary_cross_entropy_with_logits
functional/channel_norm
functional/channel_shuffle functional/channel_shuffle
functional/conv1d functional/conv1d
functional/conv2d functional/conv2d
......
affine
======
.. autofunction:: dragon.vm.torch.nn.functional.affine
.. _torch.nn.affine(...): ../Affine.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
channel_norm
============
.. autofunction:: dragon.vm.torch.nn.functional.channel_norm
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
...@@ -56,15 +56,16 @@ class CUDAObjects { ...@@ -56,15 +56,16 @@ class CUDAObjects {
auto& handle = handles[stream_id]; auto& handle = handles[stream_id];
CUBLAS_CHECK(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST)); CUBLAS_CHECK(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasSetStream(handle, stream(device_id, stream_id))); CUBLAS_CHECK(cublasSetStream(handle, stream(device_id, stream_id)));
}
auto& handle = handles[stream_id];
#if CUDA_VERSION >= 11000 #if CUDA_VERSION >= 11000
if (cudnn_allow_tf32_) { if (cublas_allow_tf32_) {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH)); CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
} else { } else {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
} }
#endif #endif
} return handle;
return handles[stream_id];
} }
/*! \brief Return the specified cudnn handle */ /*! \brief Return the specified cudnn handle */
...@@ -150,6 +151,9 @@ class CUDAObjects { ...@@ -150,6 +151,9 @@ class CUDAObjects {
Map<string, ncclComm_t> nccl_comms_[CUDA_MAX_DEVICES]; Map<string, ncclComm_t> nccl_comms_[CUDA_MAX_DEVICES];
#endif #endif
/*! \brief The flag that allows cuBLAS TF32 math type or not */
bool cublas_allow_tf32_ = false;
/*! \brief The flag that uses cuDNN or not */ /*! \brief The flag that uses cuDNN or not */
bool cudnn_enabled_ = true; bool cudnn_enabled_ = true;
......
...@@ -20,32 +20,32 @@ namespace dragon { ...@@ -20,32 +20,32 @@ namespace dragon {
/*! /*!
* \brief Registry to create class instances. * \brief Registry to create class instances.
*/ */
template <class KeyType, class ObjectType, class... Args> template <class KeyT, class ClassT, class... Args>
class Registry { class Registry {
public: public:
typedef std::function<ObjectType*(Args...)> Creator; typedef std::function<ClassT*(Args...)> Creator;
/*! \brief Create an instance of specified class */ /*! \brief Create an instance of specified class */
ObjectType* Create(const KeyType& key, Args... args) { ClassT* Create(const KeyT& key, Args... args) {
CHECK(registry_.count(key)) << "\nKey(" << key << ") has not registered."; CHECK(registry_.count(key)) << "\nKey(" << key << ") has not registered.";
return registry_[key](args...); return registry_[key](args...);
} }
/*! \brief Return whether the specified class is registered */ /*! \brief Return whether the specified class is registered */
bool Has(const KeyType& key) { bool Has(const KeyT& key) {
return (registry_.count(key)) != 0; return (registry_.count(key)) != 0;
} }
/*! \brief Register a class with the creator */ /*! \brief Register a class with the creator */
void Register(const KeyType& key, Creator creator) { void Register(const KeyT& key, Creator creator) {
CHECK(!registry_.count(key)) CHECK(!registry_.count(key))
<< "\nKey(" << key << ") has already registered."; << "\nKey(" << key << ") has already registered.";
registry_[key] = creator; registry_[key] = creator;
} }
/*! \brief Return the key of registered classes */ /*! \brief Return the key of registered classes */
vector<KeyType> keys() { vector<KeyT> keys() {
vector<KeyType> ret; vector<KeyT> ret;
for (const auto& it : registry_) { for (const auto& it : registry_) {
ret.push_back(it.first); ret.push_back(it.first);
} }
...@@ -54,50 +54,49 @@ class Registry { ...@@ -54,50 +54,49 @@ class Registry {
private: private:
/*! \brief The registry map */ /*! \brief The registry map */
Map<KeyType, Creator> registry_; Map<KeyT, Creator> registry_;
}; };
/*! /*!
* \brief Register creator into the registry. * \brief Register creator into the registry.
*/ */
template <class KeyType, class ObjectType, class... Args> template <class KeyT, class ClassT, class... Args>
class Registerer { class Registerer {
public: public:
/*! \brief Constructor with key and creator */ /*! \brief Constructor with key and creator */
Registerer( Registerer(
const KeyType& key, const KeyT& key,
Registry<KeyType, ObjectType, Args...>* registry, Registry<KeyT, ClassT, Args...>* registry,
typename Registry<KeyType, ObjectType, Args...>::Creator creator, typename Registry<KeyT, ClassT, Args...>::Creator creator,
const string& help_msg = "") { const string& help_msg = "") {
registry->Register(key, creator); registry->Register(key, creator);
} }
/*! \brief Return the default creator */ /*! \brief Return the default creator */
template <class DerivedType> template <class DerivedT>
static ObjectType* DefaultCreator(Args... args) { static ClassT* DefaultCreator(Args... args) {
return new DerivedType(args...); return new DerivedT(args...);
} }
}; };
// Used in *.h files // Used in *.h files.
#define DECLARE_TYPED_REGISTRY(RegistryName, KeyType, ObjectType, ...) \ #define DECLARE_TYPED_REGISTRY(RegistryName, KeyT, ClassT, ...) \
DRAGON_API Registry<KeyType, ObjectType, ##__VA_ARGS__>* RegistryName(); \ DRAGON_API Registry<KeyT, ClassT, ##__VA_ARGS__>* RegistryName(); \
typedef Registerer<KeyType, ObjectType, ##__VA_ARGS__> \ typedef Registerer<KeyT, ClassT, ##__VA_ARGS__> Registerer##RegistryName;
Registerer##RegistryName;
// Used in *.cc files.
// Used in *.cc files #define DEFINE_TYPED_REGISTRY(RegistryName, KeyT, ClassT, ...) \
#define DEFINE_TYPED_REGISTRY(RegistryName, KeyType, ObjectType, ...) \ Registry<KeyT, ClassT, ##__VA_ARGS__>* RegistryName() { \
Registry<KeyType, ObjectType, ##__VA_ARGS__>* RegistryName() { \ static Registry<KeyT, ClassT, ##__VA_ARGS__>* registry = \
static Registry<KeyType, ObjectType, ##__VA_ARGS__>* registry = \ new Registry<KeyT, ClassT, ##__VA_ARGS__>(); \
new Registry<KeyType, ObjectType, ##__VA_ARGS__>(); \
return registry; \ return registry; \
} }
#define DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ #define DECLARE_REGISTRY(RegistryName, ClassT, ...) \
DECLARE_TYPED_REGISTRY(RegistryName, string, ObjectType, ##__VA_ARGS__) DECLARE_TYPED_REGISTRY(RegistryName, string, ClassT, ##__VA_ARGS__)
#define DEFINE_REGISTRY(RegistryName, ObjectType, ...) \ #define DEFINE_REGISTRY(RegistryName, ClassT, ...) \
DEFINE_TYPED_REGISTRY(RegistryName, string, ObjectType, ##__VA_ARGS__) DEFINE_TYPED_REGISTRY(RegistryName, string, ClassT, ##__VA_ARGS__)
#define REGISTER_TYPED_CLASS(RegistryName, key, ...) \ #define REGISTER_TYPED_CLASS(RegistryName, key, ...) \
static Registerer##RegistryName ANONYMOUS_VARIABLE(g_##RegistryName)( \ static Registerer##RegistryName ANONYMOUS_VARIABLE(g_##RegistryName)( \
......
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
void _ChannelAffine(
const int N,
const int S,
const int C,
const T* x,
const T* scale,
const T* bias,
T* y) {
if (S == 1) {
if (bias != nullptr) {
EigenArrayMap<T>(y, C, N) = (ConstEigenArrayMap<T>(x, C, N).colwise() *
ConstEigenVectorArrayMap<T>(scale, C))
.colwise() +
ConstEigenVectorArrayMap<T>(bias, C);
} else {
EigenArrayMap<T>(y, C, N) = ConstEigenArrayMap<T>(x, C, N).colwise() *
ConstEigenVectorArrayMap<T>(scale, C);
}
return;
}
for (int i = 0; i < N; ++i) {
for (int j = 0; j < C; ++j) {
if (bias != nullptr) {
EigenVectorArrayMap<T>(y, S) =
ConstEigenVectorArrayMap<T>(x, S) * scale[j] + bias[j];
} else {
EigenVectorArrayMap<T>(y, S) =
ConstEigenVectorArrayMap<T>(x, S) * scale[j];
}
x += S;
y += S;
}
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void ChannelAffine<float16, CPUContext>(
const int N,
const int S,
const int C,
const float16* x,
const float16* w,
const float16* b,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ChannelAffine<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const T* x, \
const T* scale, \
const T* bias, \
T* y, \
CPUContext* ctx) { \
_ChannelAffine(N, S, C, x, scale, bias, y); \
}
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T, typename AccT>
__global__ void _ChannelAffine(
const int NxCxS,
const int S,
const int C,
const T* x,
const T* scale,
T* y) {
CUDA_1D_KERNEL_LOOP(i, NxCxS) {
y[i] = convert::To<T>(
convert::To<AccT>(x[i]) *
convert::To<AccT>(__ldg(scale + (i / S) % C)));
}
}
template <typename T, typename AccT>
__global__ void _ChannelAffine(
const int NxCxS,
const int S,
const int C,
const T* x,
const T* scale,
const T* bias,
T* y) {
CUDA_1D_KERNEL_LOOP(i, NxCxS) {
const int j = (i / S) % C;
y[i] = convert::To<T>(
fma(convert::To<AccT>(x[i]),
convert::To<AccT>(__ldg(scale + j)),
convert::To<AccT>(__ldg(bias + j))));
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ChannelAffine<T, CUDAContext>( \
const int N, \
const int S, \
const int C, \
const T* x, \
const T* scale, \
const T* bias, \
T* y, \
CUDAContext* ctx) { \
const auto NxCxS = N * C * S; \
if (bias != nullptr) { \
_ChannelAffine<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<CUDA_BLOCKS(NxCxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxCxS, \
S, \
C, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<const math::ScalarType<T>::type*>(scale), \
reinterpret_cast<const math::ScalarType<T>::type*>(bias), \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
} else { \
_ChannelAffine<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<CUDA_BLOCKS(NxCxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxCxS, \
S, \
C, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<const math::ScalarType<T>::type*>(scale), \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
} \
}
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
void _ChannelShuffle(
const int N,
const int S,
const int G,
const int K,
const T* x,
T* y) {
for (int i = 0; i < N; ++i) {
for (int gi = 0; gi < G; ++gi) {
for (int ki = 0; ki < K; ++ki) {
std::memcpy(
y + ((i * K + ki) * G + gi) * S,
x + ((i * G + gi) * K + ki) * S,
S * sizeof(T));
}
}
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ChannelShuffle<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const int G, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_ChannelShuffle(N, S, G, C / G, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
__global__ void _ChannelShuffle(
const int NxCxS,
const int S,
const int G,
const int K,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(index, NxCxS) {
const int j = index % S;
const int gi = index / S % G;
const int ki = index / S / G % K;
const int i = index / S / G / K;
y[index] = x[((i * G + gi) * K + ki) * S + j];
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ChannelShuffle<T, CUDAContext>( \
const int N, \
const int S, \
const int C, \
const int G, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
const auto NxCxS = N * C * S; \
_ChannelShuffle<<< \
CUDA_BLOCKS(NxCxS), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(NxCxS, S, G, C / G, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
...@@ -52,14 +52,23 @@ __global__ void _ComputeCounts( ...@@ -52,14 +52,23 @@ __global__ void _ComputeCounts(
CUDAContext* ctx) { \ CUDAContext* ctx) { \
math::Copy(dim, x, y, ctx); \ math::Copy(dim, x, y, ctx); \
auto policy = thrust::cuda::par.on(ctx->cuda_stream()); \ auto policy = thrust::cuda::par.on(ctx->cuda_stream()); \
auto* data = reinterpret_cast<math::ScalarType<T>::type*>(y); \
thrust::device_vector<int> order1(dim), order2(dim); \ thrust::device_vector<int> order1(dim), order2(dim); \
thrust::sequence(policy, order1.begin(), order1.end()); \ thrust::sequence(policy, order1.begin(), order1.end()); \
thrust::sequence(policy, order2.begin(), order2.end()); \ thrust::sequence(policy, order2.begin(), order2.end()); \
thrust::sort_by_key( \ thrust::sort_by_key( \
policy, y, y + dim, order1.begin(), math::LessFunctor<T>()); \ policy, \
data, \
data + dim, \
order1.begin(), \
math::LessFunctor<math::ScalarType<T>::type>()); \
auto last = thrust::unique_by_key( \ auto last = thrust::unique_by_key( \
policy, y, y + dim, order2.begin(), math::EqualFunctor<T>()); \ policy, \
int n = num[0] = last.first - y; \ data, \
data + dim, \
order2.begin(), \
math::EqualFunctor<math::ScalarType<T>::type>()); \
int n = num[0] = last.first - data; \
if (inverse_index) { \ if (inverse_index) { \
_RemapInverse<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _RemapInverse<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
dim, n, order1.data(), order2.data(), inverse_index); \ dim, n, order1.data(), order2.data(), inverse_index); \
......
...@@ -8,7 +8,7 @@ namespace kernels { ...@@ -8,7 +8,7 @@ namespace kernels {
namespace { namespace {
template <typename InputT, typename OutputT> template <typename InputT, typename OutputT>
void _ChannelNormalize( void _ChannelNorm(
const int axis, const int axis,
const int num_dims, const int num_dims,
const int64_t* x_strides, const int64_t* x_strides,
...@@ -19,15 +19,14 @@ void _ChannelNormalize( ...@@ -19,15 +19,14 @@ void _ChannelNormalize(
OutputT* y) { OutputT* y) {
const auto N = math::utils::Prod(num_dims, y_dims); const auto N = math::utils::Prod(num_dims, y_dims);
vec64_t idx(num_dims, 0); vec64_t idx(num_dims, 0);
int64_t xi, wi;
for (int yi = 0; yi < N; ++yi) { for (int yi = 0; yi < N; ++yi) {
xi = 0; int64_t xi = 0, wi;
for (int d = num_dims - 1; d >= 0; --d) { for (int d = num_dims - 1; d >= 0; --d) {
xi += idx[d] * x_strides[d]; xi += idx[d] * x_strides[d];
if (d == axis) wi = idx[d]; if (d == axis) wi = idx[d];
} }
y[yi] = const float val = convert::To<float>(x[xi]);
convert::To<OutputT>((convert::To<float>(x[xi]) - mean[wi]) / std[wi]); y[yi] = convert::To<OutputT>((val - mean[wi]) / std[wi]);
math::utils::IncreaseIndexInDims(num_dims, y_dims, idx.data()); math::utils::IncreaseIndexInDims(num_dims, y_dims, idx.data());
} }
} }
...@@ -38,7 +37,7 @@ void _ChannelNormalize( ...@@ -38,7 +37,7 @@ void _ChannelNormalize(
#define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void ChannelNormalize<InputT, OutputT, CPUContext>( \ void ChannelNorm<InputT, OutputT, CPUContext>( \
const int axis, \ const int axis, \
const int num_dims, \ const int num_dims, \
const int64_t* x_strides, \ const int64_t* x_strides, \
...@@ -48,7 +47,7 @@ void _ChannelNormalize( ...@@ -48,7 +47,7 @@ void _ChannelNormalize(
const float* std, \ const float* std, \
OutputT* y, \ OutputT* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_ChannelNormalize(axis, num_dims, x_strides, y_dims, x, mean, std, y); \ _ChannelNorm(axis, num_dims, x_strides, y_dims, x, mean, std, y); \
} }
DEFINE_KERNEL_LAUNCHER(uint8_t, float16); DEFINE_KERNEL_LAUNCHER(uint8_t, float16);
......
...@@ -11,7 +11,7 @@ namespace kernels { ...@@ -11,7 +11,7 @@ namespace kernels {
namespace { namespace {
template <typename InputT, typename OutputT, int D> template <typename InputT, typename OutputT, int D>
__global__ void _ChannelNormalize( __global__ void _ChannelNorm(
const int N, const int N,
const int axis, const int axis,
const int num_dims, const int num_dims,
...@@ -40,7 +40,7 @@ __global__ void _ChannelNormalize( ...@@ -40,7 +40,7 @@ __global__ void _ChannelNormalize(
#define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void ChannelNormalize<InputT, OutputT, CUDAContext>( \ void ChannelNorm<InputT, OutputT, CUDAContext>( \
const int axis, \ const int axis, \
const int num_dims, \ const int num_dims, \
const int64_t* x_strides, \ const int64_t* x_strides, \
...@@ -57,11 +57,7 @@ __global__ void _ChannelNormalize( ...@@ -57,11 +57,7 @@ __global__ void _ChannelNormalize(
X_strides.data[i] = x_strides[i]; \ X_strides.data[i] = x_strides[i]; \
Y_dims.data[i] = y_dims[i]; \ Y_dims.data[i] = y_dims[i]; \
} \ } \
_ChannelNormalize<<< \ _ChannelNorm<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
CUDA_BLOCKS(N), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
N, axis, num_dims, X_strides, Y_dims, x, mean, std, y); \ N, axis, num_dims, X_strides, Y_dims, x, mean, std, y); \
} }
......
...@@ -8,7 +8,7 @@ namespace kernels { ...@@ -8,7 +8,7 @@ namespace kernels {
namespace { namespace {
template <typename T> template <typename T>
void _L1Normalize( void _L1Norm(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -28,7 +28,7 @@ void _L1Normalize( ...@@ -28,7 +28,7 @@ void _L1Normalize(
} }
template <typename T> template <typename T>
void _L2Normalize( void _L2Norm(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -48,7 +48,7 @@ void _L2Normalize( ...@@ -48,7 +48,7 @@ void _L2Normalize(
} }
template <typename T> template <typename T>
void _L1NormalizeGrad( void _L1NormGrad(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -73,7 +73,7 @@ void _L1NormalizeGrad( ...@@ -73,7 +73,7 @@ void _L1NormalizeGrad(
} }
template <typename T> template <typename T>
void _L2NormalizeGrad( void _L2NormGrad(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -101,7 +101,7 @@ void _L2NormalizeGrad( ...@@ -101,7 +101,7 @@ void _L2NormalizeGrad(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> template <>
void L1Normalize<float16, CPUContext>( void L1Norm<float16, CPUContext>(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -114,7 +114,7 @@ void L1Normalize<float16, CPUContext>( ...@@ -114,7 +114,7 @@ void L1Normalize<float16, CPUContext>(
} }
template <> template <>
void L2Normalize<float16, CPUContext>( void L2Norm<float16, CPUContext>(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -127,7 +127,7 @@ void L2Normalize<float16, CPUContext>( ...@@ -127,7 +127,7 @@ void L2Normalize<float16, CPUContext>(
} }
template <> template <>
void L1NormalizeGrad<float16, CPUContext>( void L1NormGrad<float16, CPUContext>(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -138,10 +138,10 @@ void L1NormalizeGrad<float16, CPUContext>( ...@@ -138,10 +138,10 @@ void L1NormalizeGrad<float16, CPUContext>(
float16* dx, float16* dx,
CPUContext* ctx) { CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} // L1NormalizeGrad } // L1NormGrad
template <> template <>
void L2NormalizeGrad<float16, CPUContext>( void L2NormGrad<float16, CPUContext>(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -152,7 +152,7 @@ void L2NormalizeGrad<float16, CPUContext>( ...@@ -152,7 +152,7 @@ void L2NormalizeGrad<float16, CPUContext>(
float16* dx, float16* dx,
CPUContext* ctx) { CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} // L2NormalizeGrad } // L2NormGrad
#define DEFINE_KERNEL_LAUNCHER(name, T) \ #define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \ template <> \
...@@ -183,14 +183,14 @@ void L2NormalizeGrad<float16, CPUContext>( ...@@ -183,14 +183,14 @@ void L2NormalizeGrad<float16, CPUContext>(
_##name<T>(N, S, C, normalizer, eps, dy, x, dx); \ _##name<T>(N, S, C, normalizer, eps, dy, x, dx); \
} }
DEFINE_KERNEL_LAUNCHER(L1Normalize, float); DEFINE_KERNEL_LAUNCHER(L1Norm, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, double); DEFINE_KERNEL_LAUNCHER(L1Norm, double);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float); DEFINE_KERNEL_LAUNCHER(L2Norm, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, double); DEFINE_KERNEL_LAUNCHER(L2Norm, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, double); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormGrad, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float); DEFINE_GRAD_KERNEL_LAUNCHER(L2NormGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, double); DEFINE_GRAD_KERNEL_LAUNCHER(L2NormGrad, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -12,7 +12,7 @@ namespace kernels { ...@@ -12,7 +12,7 @@ namespace kernels {
namespace { namespace {
template <typename T, typename AccT> template <typename T, typename AccT>
__global__ void _L1Normalize( __global__ void _L1Norm(
const int NxS, const int NxS,
const int S, const int S,
const int C, const int C,
...@@ -41,7 +41,7 @@ __global__ void _L1Normalize( ...@@ -41,7 +41,7 @@ __global__ void _L1Normalize(
} }
template <typename T, typename AccT> template <typename T, typename AccT>
__global__ void _L2Normalize( __global__ void _L2Norm(
const int NxS, const int NxS,
const int S, const int S,
const int C, const int C,
...@@ -70,7 +70,7 @@ __global__ void _L2Normalize( ...@@ -70,7 +70,7 @@ __global__ void _L2Normalize(
} }
template <typename T, typename AccT> template <typename T, typename AccT>
__global__ void _L1NormalizeGrad( __global__ void _L1NormGrad(
const int NxS, const int NxS,
const int S, const int S,
const int C, const int C,
...@@ -107,7 +107,7 @@ __global__ void _L1NormalizeGrad( ...@@ -107,7 +107,7 @@ __global__ void _L1NormalizeGrad(
} }
template <typename T, typename AccT> template <typename T, typename AccT>
__global__ void _L2NormalizeGrad( __global__ void _L2NormGrad(
const int NxS, const int NxS,
const int S, const int S,
const int C, const int C,
...@@ -195,18 +195,18 @@ __global__ void _L2NormalizeGrad( ...@@ -195,18 +195,18 @@ __global__ void _L2NormalizeGrad(
reinterpret_cast<math::ScalarType<T>::type*>(dx)); \ reinterpret_cast<math::ScalarType<T>::type*>(dx)); \
} }
DEFINE_KERNEL_LAUNCHER(L1Normalize, float16, float); DEFINE_KERNEL_LAUNCHER(L1Norm, float16, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, float, float); DEFINE_KERNEL_LAUNCHER(L1Norm, float, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, double, double); DEFINE_KERNEL_LAUNCHER(L1Norm, double, double);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float16, float); DEFINE_KERNEL_LAUNCHER(L2Norm, float16, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float, float); DEFINE_KERNEL_LAUNCHER(L2Norm, float, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, double, double); DEFINE_KERNEL_LAUNCHER(L2Norm, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float16, float); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormGrad, float16, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float, float); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormGrad, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, double, double); DEFINE_GRAD_KERNEL_LAUNCHER(L1NormGrad, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16, float); DEFINE_GRAD_KERNEL_LAUNCHER(L2NormGrad, float16, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float, float); DEFINE_GRAD_KERNEL_LAUNCHER(L2NormGrad, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, double, double); DEFINE_GRAD_KERNEL_LAUNCHER(L2NormGrad, double, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
namespace kernels { namespace kernels {
template <> namespace {
void Adam<float, CPUContext>(
template <typename T, typename CopyT>
void _Adam(
const int N, const int N,
const float lr, const T lr,
const float beta1, const T beta1,
const float beta2, const T beta2,
const float eps, const T eps,
float* g, const T wd,
float* m, const T* x,
float* v, const T* g,
CPUContext* ctx) { T* m,
T* v,
T* y,
CopyT* y_copy) {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
float gi = g[i]; const T gi = wd > T(0) ? std::fma(wd, x[i], g[i]) : g[i];
float mi = m[i] = m[i] * beta1 + gi * (1 - beta1); const T mi = m[i] = std::fma(beta1, m[i], (T(1) - beta1) * gi);
float vi = v[i] = v[i] * beta2 + gi * gi * (1 - beta2); const T vi = v[i] = std::fma(beta2, v[i], (T(1) - beta2) * gi * gi);
g[i] = lr * mi / (std::sqrt(vi) + eps); y[i] -= lr * mi / (std::sqrt(vi) + eps);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
} }
} }
template <> template <typename T, typename CopyT>
void AdamW<float, CPUContext>( void _AdamW(
const int N, const int N,
const float lr, const T lr,
const float beta1, const T beta1,
const float beta2, const T beta2,
const float eps, const T eps,
const float wd, const T wd,
const float* x, const T* x,
float* g, const T* g,
float* m, T* m,
float* v, T* v,
CPUContext* ctx) { T* y,
CopyT* y_copy) {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
float gi = g[i]; const T gi = g[i];
float mi = m[i] = m[i] * beta1 + gi * (1 - beta1); const T mi = m[i] = std::fma(beta1, m[i], (T(1) - beta1) * gi);
float vi = v[i] = v[i] * beta2 + gi * gi * (1 - beta2); const T vi = v[i] = std::fma(beta2, v[i], (T(1) - beta2) * gi * gi);
g[i] = lr * mi / (std::sqrt(vi) + eps) + wd * x[i]; y[i] -= wd > T(0) ? std::fma(wd, x[i], lr * mi / (std::sqrt(vi) + eps))
: lr * mi / (std::sqrt(vi) + eps);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
} }
} }
} // namespace
#define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
template <> \
void name<T, CopyT, CPUContext>( \
const int N, \
const float lr, \
const float beta1, \
const float beta2, \
const float eps, \
const float wd, \
const T* x, \
const T* g, \
T* m, \
T* v, \
T* y, \
CopyT* y_copy, \
CPUContext* ctx) { \
_##name( \
N, \
convert::To<T>(lr), \
convert::To<T>(beta1), \
convert::To<T>(beta2), \
convert::To<T>(eps), \
convert::To<T>(wd), \
x, \
g, \
m, \
v, \
y, \
y_copy); \
}
DEFINE_KERNEL_LAUNCHER(Adam, float, float16);
DEFINE_KERNEL_LAUNCHER(Adam, float, float);
DEFINE_KERNEL_LAUNCHER(Adam, double, double);
DEFINE_KERNEL_LAUNCHER(AdamW, float, float16);
DEFINE_KERNEL_LAUNCHER(AdamW, float, float);
DEFINE_KERNEL_LAUNCHER(AdamW, double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels } // namespace kernels
} // namespace dragon } // namespace dragon
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -9,25 +10,32 @@ namespace kernels { ...@@ -9,25 +10,32 @@ namespace kernels {
namespace { namespace {
template <typename T> template <typename T, typename CopyT>
__global__ void _Adam( __global__ void _Adam(
const int N, const int N,
const T lr, const T lr,
const T beta1, const T beta1,
const T beta2, const T beta2,
const T eps, const T eps,
T* g, const T wd,
const T* x,
const T* g,
T* m, T* m,
T* v) { T* v,
T* y,
CopyT* y_copy) {
CUDA_1D_KERNEL_LOOP(i, N) { CUDA_1D_KERNEL_LOOP(i, N) {
T gi = g[i]; const T gi = wd > T(0) ? fma(wd, x[i], g[i]) : g[i];
T mi = m[i] = m[i] * beta1 + gi * (1 - beta1); const T mi = m[i] = fma(beta1, m[i], (T(1) - beta1) * gi);
T vi = v[i] = v[i] * beta2 + gi * gi * (1 - beta2); const T vi = v[i] = fma(beta2, v[i], (T(1) - beta2) * gi * gi);
g[i] = lr * mi / (sqrt(vi) + eps); y[i] -= lr * mi / (sqrt(vi) + eps);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
} }
} }
template <typename T> template <typename T, typename CopyT>
__global__ void _AdamW( __global__ void _AdamW(
const int N, const int N,
const T lr, const T lr,
...@@ -36,14 +44,20 @@ __global__ void _AdamW( ...@@ -36,14 +44,20 @@ __global__ void _AdamW(
const T eps, const T eps,
const T wd, const T wd,
const T* x, const T* x,
T* g, const T* g,
T* m, T* m,
T* v) { T* v,
T* y,
CopyT* y_copy) {
CUDA_1D_KERNEL_LOOP(i, N) { CUDA_1D_KERNEL_LOOP(i, N) {
T gi = g[i]; const T gi = g[i];
T mi = m[i] = m[i] * beta1 + gi * (1 - beta1); const T mi = m[i] = fma(beta1, m[i], (T(1) - beta1) * gi);
T vi = v[i] = v[i] * beta2 + gi * gi * (1 - beta2); const T vi = v[i] = fma(beta2, v[i], (T(1) - beta2) * gi * gi);
g[i] = lr * mi / (sqrt(vi) + eps) + wd * x[i]; y[i] -= wd > T(0) ? fma(wd, x[i], lr * mi / (sqrt(vi) + eps))
: lr * mi / (sqrt(vi) + eps);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
} }
} }
...@@ -51,37 +65,44 @@ __global__ void _AdamW( ...@@ -51,37 +65,44 @@ __global__ void _AdamW(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> #define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
void Adam<float, CUDAContext>( template <> \
const int N, void name<T, CopyT, CUDAContext>( \
const float lr, const int N, \
const float beta1, const float lr, \
const float beta2, const float beta1, \
const float eps, const float beta2, \
float* g, const float eps, \
float* m, const float wd, \
float* v, const T* x, \
CUDAContext* ctx) { const T* g, \
_Adam<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( T* m, \
N, lr, beta1, beta2, eps, g, m, v); T* v, \
} T* y, \
CopyT* y_copy, \
CUDAContext* ctx) { \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
convert::To<T>(lr), \
convert::To<T>(beta1), \
convert::To<T>(beta2), \
convert::To<T>(eps), \
convert::To<T>(wd), \
x, \
g, \
m, \
v, \
y, \
reinterpret_cast<math::ScalarType<CopyT>::type*>(y_copy)); \
}
template <> DEFINE_KERNEL_LAUNCHER(Adam, float, float16);
void AdamW<float, CUDAContext>( DEFINE_KERNEL_LAUNCHER(Adam, float, float);
const int N, DEFINE_KERNEL_LAUNCHER(Adam, double, double);
const float lr, DEFINE_KERNEL_LAUNCHER(AdamW, float, float16);
const float beta1, DEFINE_KERNEL_LAUNCHER(AdamW, float, float);
const float beta2, DEFINE_KERNEL_LAUNCHER(AdamW, double, double);
const float eps, #undef DEFINE_KERNEL_LAUNCHER
const float wd,
const float* x,
float* g,
float* m,
float* v,
CUDAContext* ctx) {
_AdamW<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, lr, beta1, beta2, eps, wd, x, g, m, v);
}
} // namespace kernels } // namespace kernels
......
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
namespace kernels { namespace kernels {
template <> namespace {
void RMSprop<float, CPUContext>(
template <typename T, typename CopyT>
void _RMSprop(
const int N, const int N,
const float lr, const T lr,
const float momentum, const T momentum,
const float decay, const T alpha,
const float eps, const T eps,
float* g, const T wd,
float* m, const T* x,
float* v, const T* g,
CPUContext* ctx) { T* m,
T* v,
T* y,
CopyT* y_copy) {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
float gi = g[i]; const T gi = wd > T(0) ? std::fma(wd, x[i], g[i]) : g[i];
float vi = v[i] = decay * v[i] + (1 - decay) * gi * gi; const T vi = v[i] = std::fma(alpha, v[i], (T(1) - alpha) * gi * gi);
float mi = m[i] = std::fma(momentum, m[i], gi / (std::sqrt(vi) + eps)); const T mi = m[i] = std::fma(momentum, m[i], gi / (std::sqrt(vi) + eps));
g[i] = lr * mi; y[i] -= lr * mi;
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
} }
} }
} // namespace
#define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
template <> \
void name<T, CopyT, CPUContext>( \
const int N, \
const float lr, \
const float momentum, \
const float alpha, \
const float eps, \
const float wd, \
const T* x, \
const T* g, \
T* m, \
T* v, \
T* y, \
CopyT* y_copy, \
CPUContext* ctx) { \
_##name( \
N, \
convert::To<T>(lr), \
convert::To<T>(momentum), \
convert::To<T>(alpha), \
convert::To<T>(eps), \
convert::To<T>(wd), \
x, \
g, \
m, \
v, \
y, \
y_copy); \
}
DEFINE_KERNEL_LAUNCHER(RMSprop, float, float16);
DEFINE_KERNEL_LAUNCHER(RMSprop, float, float);
DEFINE_KERNEL_LAUNCHER(RMSprop, double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels } // namespace kernels
} // namespace dragon } // namespace dragon
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -9,21 +10,28 @@ namespace kernels { ...@@ -9,21 +10,28 @@ namespace kernels {
namespace { namespace {
template <typename T> template <typename T, typename CopyT>
__global__ void _RMSprop( __global__ void _RMSprop(
const int N, const int N,
const T lr, const T lr,
const T momentum, const T momentum,
const T decay, const T alpha,
const T eps, const T eps,
T* g, const T wd,
const T* x,
const T* g,
T* m, T* m,
T* v) { T* v,
T* y,
CopyT* y_copy) {
CUDA_1D_KERNEL_LOOP(i, N) { CUDA_1D_KERNEL_LOOP(i, N) {
T gi = g[i]; const T gi = wd > T(0) ? fma(wd, x[i], g[i]) : g[i];
T vi = v[i] = decay * v[i] + (1 - decay) * gi * gi; const T vi = v[i] = fma(alpha, v[i], (T(1) - alpha) * gi * gi);
T mi = m[i] = fma(momentum, m[i], gi / (sqrt(vi) + eps)); const T mi = m[i] = fma(momentum, m[i], gi / (std::sqrt(vi) + eps));
g[i] = lr * mi; y[i] -= lr * mi;
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
} }
} }
...@@ -31,20 +39,41 @@ __global__ void _RMSprop( ...@@ -31,20 +39,41 @@ __global__ void _RMSprop(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> #define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
void RMSprop<float, CUDAContext>( template <> \
const int N, void name<T, CopyT, CUDAContext>( \
const float lr, const int N, \
const float momentum, const float lr, \
const float decay, const float momentum, \
const float eps, const float alpha, \
float* g, const float eps, \
float* m, const float wd, \
float* v, const T* x, \
CUDAContext* ctx) { const T* g, \
_RMSprop<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( T* m, \
N, lr, momentum, decay, eps, g, m, v); T* v, \
} T* y, \
CopyT* y_copy, \
CUDAContext* ctx) { \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
convert::To<T>(lr), \
convert::To<T>(momentum), \
convert::To<T>(alpha), \
convert::To<T>(eps), \
convert::To<T>(wd), \
x, \
g, \
m, \
v, \
y, \
reinterpret_cast<math::ScalarType<CopyT>::type*>(y_copy)); \
}
DEFINE_KERNEL_LAUNCHER(RMSprop, float, float16);
DEFINE_KERNEL_LAUNCHER(RMSprop, float, float);
DEFINE_KERNEL_LAUNCHER(RMSprop, double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels } // namespace kernels
......
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
namespace kernels { namespace kernels {
template <> namespace {
void MomentumSGD<float, CPUContext>(
template <typename T, typename CopyT>
void _MomentumSGD(
const int N, const int N,
const float lr, const T lr,
const float momentum, const T momentum,
float* g, const T wd,
float* m, const T* x,
CPUContext* ctx) { const T* g,
T* m,
T* y,
CopyT* y_copy) {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
float mi = m[i] = std::fma(momentum, m[i], g[i]); const T gi = wd > T(0) ? std::fma(wd, x[i], g[i]) : g[i];
g[i] = lr * mi; const T mi = m[i] = std::fma(momentum, m[i], gi);
y[i] -= lr * mi;
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
} }
} }
template <> template <typename T, typename CopyT>
void NesterovSGD<float, CPUContext>( void _NesterovSGD(
const int N, const int N,
const float lr, const T lr,
const float momentum, const T momentum,
float* g, const T wd,
float* m, const T* x,
CPUContext* ctx) { const T* g,
T* m,
T* y,
CopyT* y_copy) {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
float gi = g[i]; const T gi = wd > T(0) ? std::fma(wd, x[i], g[i]) : g[i];
float mi = m[i] = std::fma(momentum, m[i], gi); const T mi = m[i] = std::fma(momentum, m[i], gi);
g[i] = lr * std::fma(momentum, mi, gi); y[i] -= lr * std::fma(momentum, mi, gi);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
} }
} }
} // namespace
#define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
template <> \
void name<T, CopyT, CPUContext>( \
const int N, \
const float lr, \
const float momentum, \
const float wd, \
const T* x, \
const T* g, \
T* m, \
T* y, \
CopyT* y_copy, \
CPUContext* ctx) { \
_##name( \
N, \
convert::To<T>(lr), \
convert::To<T>(momentum), \
convert::To<T>(wd), \
x, \
g, \
m, \
y, \
y_copy); \
}
DEFINE_KERNEL_LAUNCHER(MomentumSGD, float, float16);
DEFINE_KERNEL_LAUNCHER(MomentumSGD, float, float);
DEFINE_KERNEL_LAUNCHER(MomentumSGD, double, double);
DEFINE_KERNEL_LAUNCHER(NesterovSGD, float, float16);
DEFINE_KERNEL_LAUNCHER(NesterovSGD, float, float);
DEFINE_KERNEL_LAUNCHER(NesterovSGD, double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels } // namespace kernels
} // namespace dragon } // namespace dragon
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -9,22 +10,45 @@ namespace kernels { ...@@ -9,22 +10,45 @@ namespace kernels {
namespace { namespace {
template <typename T> template <typename T, typename CopyT>
__global__ void __global__ void _MomentumSGD(
_MomentumSGD(const int N, const T lr, const T momentum, T* g, T* m) { const int N,
const T lr,
const T momentum,
const T wd,
const T* x,
const T* g,
T* m,
T* y,
CopyT* y_copy) {
CUDA_1D_KERNEL_LOOP(i, N) { CUDA_1D_KERNEL_LOOP(i, N) {
T mi = m[i] = fma(momentum, m[i], g[i]); const T gi = wd > T(0) ? fma(wd, x[i], g[i]) : g[i];
g[i] = lr * mi; const T mi = m[i] = fma(momentum, m[i], gi);
y[i] -= lr * mi;
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
} }
} }
template <typename T> template <typename T, typename CopyT>
__global__ void __global__ void _NesterovSGD(
_NesterovSGD(const int N, const T lr, const T momentum, T* g, T* m) { const int N,
const T lr,
const T momentum,
const T wd,
const T* x,
const T* g,
T* m,
T* y,
CopyT* y_copy) {
CUDA_1D_KERNEL_LOOP(i, N) { CUDA_1D_KERNEL_LOOP(i, N) {
T gi = g[i]; const T gi = wd > T(0) ? fma(wd, x[i], g[i]) : g[i];
T mi = m[i] = fma(momentum, m[i], gi); const T mi = m[i] = fma(momentum, m[i], gi);
g[i] = lr * fma(momentum, mi, gi); y[i] -= lr * fma(momentum, mi, gi);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
} }
} }
...@@ -32,29 +56,38 @@ _NesterovSGD(const int N, const T lr, const T momentum, T* g, T* m) { ...@@ -32,29 +56,38 @@ _NesterovSGD(const int N, const T lr, const T momentum, T* g, T* m) {
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> #define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
void MomentumSGD<float, CUDAContext>( template <> \
const int N, void name<T, CopyT, CUDAContext>( \
const float lr, const int N, \
const float momentum, const float lr, \
float* g, const float momentum, \
float* m, const float wd, \
CUDAContext* ctx) { const T* x, \
_MomentumSGD<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( const T* g, \
N, lr, momentum, g, m); T* m, \
} T* y, \
CopyT* y_copy, \
CUDAContext* ctx) { \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
convert::To<T>(lr), \
convert::To<T>(momentum), \
convert::To<T>(wd), \
x, \
g, \
m, \
y, \
reinterpret_cast<math::ScalarType<CopyT>::type*>(y_copy)); \
}
template <> DEFINE_KERNEL_LAUNCHER(MomentumSGD, float, float16);
void NesterovSGD<float, CUDAContext>( DEFINE_KERNEL_LAUNCHER(MomentumSGD, float, float);
const int N, DEFINE_KERNEL_LAUNCHER(MomentumSGD, double, double);
const float lr, DEFINE_KERNEL_LAUNCHER(NesterovSGD, float, float16);
const float momentum, DEFINE_KERNEL_LAUNCHER(NesterovSGD, float, float);
float* g, DEFINE_KERNEL_LAUNCHER(NesterovSGD, double, double);
float* m, #undef DEFINE_KERNEL_LAUNCHER
CUDAContext* ctx) {
_NesterovSGD<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, lr, momentum, g, m);
}
} // namespace kernels } // namespace kernels
......
...@@ -91,16 +91,24 @@ void RegisterModule_cuda(py::module& m) { ...@@ -91,16 +91,24 @@ void RegisterModule_cuda(py::module& m) {
#endif #endif
}); });
/*! \brief Set the flags of cuBLAS library */
m.def("cublasSetFlags", [](int allow_tf32) {
#ifdef USE_CUDA
auto& ctx = CUDAContext::objects();
if (allow_tf32 >= 0) ctx.cublas_allow_tf32_ = allow_tf32;
#endif
});
/*! \brief Set the flags of cuDNN library */ /*! \brief Set the flags of cuDNN library */
m.def( m.def(
"cudnnSetFlags", "cudnnSetFlags",
[](bool enabled, bool benchmark, bool deterministic, bool allow_tf32) { [](int enabled, int benchmark, int deterministic, int allow_tf32) {
#ifdef USE_CUDA #ifdef USE_CUDA
auto& cuda_objects = CUDAContext::objects(); auto& ctx = CUDAContext::objects();
cuda_objects.cudnn_enabled_ = enabled; if (enabled >= 0) ctx.cudnn_enabled_ = enabled;
cuda_objects.cudnn_deterministic_ = deterministic; if (benchmark >= 0) ctx.cudnn_benchmark_ = benchmark;
cuda_objects.cudnn_benchmark_ = benchmark; if (deterministic >= 0) ctx.cudnn_deterministic_ = deterministic;
cuda_objects.cudnn_allow_tf32_ = allow_tf32; if (allow_tf32 >= 0) ctx.cudnn_allow_tf32_ = allow_tf32;
#endif #endif
}); });
......
...@@ -132,8 +132,8 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -132,8 +132,8 @@ PYBIND11_MODULE(libdragon_python, m) {
PRINT(INFO) << GetVerboseDef(def.DebugString(), "graph"); PRINT(INFO) << GetVerboseDef(def.DebugString(), "graph");
} }
} }
// Return the graph name may be different from the def // Return the graph name may be different from the def.
// We will make a unique dummy name on creating the graph // We will make a unique dummy name on creating the graph.
return graph->name(); return graph->name();
}) })
...@@ -175,8 +175,8 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -175,8 +175,8 @@ PYBIND11_MODULE(libdragon_python, m) {
GraphDef init_graph, pred_graph; GraphDef init_graph, pred_graph;
onnx::ONNXBackend onnx_backend; onnx::ONNXBackend onnx_backend;
onnx_backend.Prepare(model_path, &init_graph, &pred_graph); onnx_backend.Prepare(model_path, &init_graph, &pred_graph);
// Serializing to Python is intractable // Serializing to Python is intractable.
// We should apply the initializer immediately // We should apply the initializer immediately.
self->RunGraph(self->CreateGraph(init_graph)->name()); self->RunGraph(self->CreateGraph(init_graph)->name());
return py::bytes(pred_graph.SerializeAsString()); return py::bytes(pred_graph.SerializeAsString());
}); });
......
...@@ -24,14 +24,14 @@ PythonPluginOp<Context>::PythonPluginOp(const OperatorDef& def, Workspace* ws) ...@@ -24,14 +24,14 @@ PythonPluginOp<Context>::PythonPluginOp(const OperatorDef& def, Workspace* ws)
Py_Initialize(); Py_Initialize();
auto* module = PyImport_ImportModule(module_name_.c_str()); auto* module = PyImport_ImportModule(module_name_.c_str());
CHECK(module) << "\nFailed to import module: " << module; CHECK(module) << "\nFailed to import module: " << module;
auto* module_dict = PyModule_GetDict(module); auto* module_dict = PyModule_GetDict(module);
auto* op_class = PyDict_GetItemString(module_dict, class_name_.c_str()); auto* op_class = PyDict_GetItemString(module_dict, class_name_.c_str());
CHECK(op_class) << "\nFailed to import class: " << class_name_ CHECK(op_class) << "\nFailed to import class: " << class_name_
<< " from module: " << module_name_; << " from module: " << module_name_;
self_ = PyObject_CallObject(op_class, NULL); self_ = PyObject_CallObject(op_class, NULL);
// Project inputs and outputs. // Set inputs and outputs.
inputs_ = PyList_New(InputSize()); inputs_ = PyList_New(InputSize());
outputs_ = PyList_New(OutputSize()); outputs_ = PyList_New(OutputSize());
for (int i = 0; i < InputSize(); i++) { for (int i = 0; i < InputSize(); i++) {
...@@ -41,16 +41,15 @@ PythonPluginOp<Context>::PythonPluginOp(const OperatorDef& def, Workspace* ws) ...@@ -41,16 +41,15 @@ PythonPluginOp<Context>::PythonPluginOp(const OperatorDef& def, Workspace* ws)
PyList_SetItem(outputs_, i, PyBytes_FromStdString(Output(i)->name())); PyList_SetItem(outputs_, i, PyBytes_FromStdString(Output(i)->name()));
} }
// Set: self.kwargs_str // Attr: "kwargs_str"
PyObject_SetAttr( PyObject_SetAttr(
self_, self_,
PyBytes_FromRawString("kwargs_str"), PyBytes_FromRawString("kwargs_str"),
PyBytes_FromStdString(kwargs_str_)); PyBytes_FromStdString(kwargs_str_));
// Method: self.setup(inputs, outputs)
if (PyObject_HasAttr(self_, PyBytes_FromRawString("setup"))) { if (PyObject_HasAttr(self_, PyBytes_FromRawString("setup"))) {
CHECK(PyObject_CallMethod(self_, "setup", "OO", inputs_, outputs_)) CHECK(PyObject_CallMethod(self_, "setup", "OO", inputs_, outputs_))
<< CallMethodHelper("setup"); << CallMethodHelper("setup"); // Method: setup(inputs, outputs)
} }
} }
...@@ -67,27 +66,24 @@ string PythonPluginOp<Context>::CallMethodHelper(const string& method_name) { ...@@ -67,27 +66,24 @@ string PythonPluginOp<Context>::CallMethodHelper(const string& method_name) {
template <class Context> template <class Context>
void PythonPluginOp<Context>::RunOnDevice() { void PythonPluginOp<Context>::RunOnDevice() {
// GIL may have been released // GIL may have been released.
pybind11::gil_scoped_acquire g; pybind11::gil_scoped_acquire g;
// Atrribute: self.phase // Attr: phase
PyObject_SetAttr( PyObject_SetAttr(
self_, PyBytes_FromRawString("phase"), PyBytes_FromStdString(phase())); self_, PyBytes_FromRawString("phase"), PyBytes_FromStdString(phase()));
// Method: self.reshape(input, outputs)
if (PyObject_HasAttr(self_, PyBytes_FromRawString("reshape"))) { if (PyObject_HasAttr(self_, PyBytes_FromRawString("reshape"))) {
CHECK(PyObject_CallMethod(self_, "reshape", "OO", inputs_, outputs_)) CHECK(PyObject_CallMethod(self_, "reshape", "OO", inputs_, outputs_))
<< CallMethodHelper("reshape"); << CallMethodHelper("reshape"); // Method: reshape(input, outputs)
} }
// Method: self.run(input, outputs)
// Method: self.forward(input, outputs)
if (PyObject_HasAttr(self_, PyBytes_FromRawString("forward"))) { if (PyObject_HasAttr(self_, PyBytes_FromRawString("forward"))) {
CHECK(PyObject_CallMethod(self_, "forward", "OO", inputs_, outputs_)) CHECK(PyObject_CallMethod(self_, "forward", "OO", inputs_, outputs_))
<< CallMethodHelper("forward"); << CallMethodHelper("forward"); // Method: run(input, outputs)
} else if (PyObject_HasAttr(self_, PyBytes_FromRawString("run"))) { } else if (PyObject_HasAttr(self_, PyBytes_FromRawString("run"))) {
CHECK(PyObject_CallMethod(self_, "run", "OO", inputs_, outputs_)) CHECK(PyObject_CallMethod(self_, "run", "OO", inputs_, outputs_))
<< CallMethodHelper("run"); << CallMethodHelper("run"); // Method: forward(input, outputs)
} }
} }
......
...@@ -13,7 +13,6 @@ void ONNXBackend::Prepare( ...@@ -13,7 +13,6 @@ void ONNXBackend::Prepare(
ModelProto onnx_model; ModelProto onnx_model;
CHECK(ReadProtoFromBinaryFile(onnx_model_path.c_str(), &onnx_model)) CHECK(ReadProtoFromBinaryFile(onnx_model_path.c_str(), &onnx_model))
<< "\nFailed to parse the onnx model."; << "\nFailed to parse the onnx model.";
int opset_version = -1; int opset_version = -1;
for (const auto& imp : onnx_model.opset_import()) { for (const auto& imp : onnx_model.opset_import()) {
if ((!imp.has_domain()) || imp.domain().empty()) { if ((!imp.has_domain()) || imp.domain().empty()) {
...@@ -31,7 +30,6 @@ void ONNXBackend::Prepare( ...@@ -31,7 +30,6 @@ void ONNXBackend::Prepare(
std::cout << "Unrecognized operator set " << opset_version << std::endl; std::cout << "Unrecognized operator set " << opset_version << std::endl;
} }
} }
if (opset_version < 0) { if (opset_version < 0) {
if (onnx_model.ir_version() >= 0x00000003) { if (onnx_model.ir_version() >= 0x00000003) {
LOG(FATAL) << "Model with IR version >= 3 " LOG(FATAL) << "Model with IR version >= 3 "
...@@ -40,7 +38,6 @@ void ONNXBackend::Prepare( ...@@ -40,7 +38,6 @@ void ONNXBackend::Prepare(
opset_version = 1; opset_version = 1;
} }
} }
ONNXToDragon(onnx_model, opset_version, true, init_graph, pred_graph); ONNXToDragon(onnx_model, opset_version, true, init_graph, pred_graph);
} }
...@@ -52,22 +49,23 @@ void ONNXBackend::ONNXToDragon( ...@@ -52,22 +49,23 @@ void ONNXBackend::ONNXToDragon(
GraphDef* pred_graph) { GraphDef* pred_graph) {
ModelProto init_model = ModelProto(); ModelProto init_model = ModelProto();
ModelProto pred_model = onnx_model; ModelProto pred_model = onnx_model;
pred_graph->set_name(onnx_model.graph().name()); pred_graph->set_name(onnx_model.graph().name());
init_graph->set_name(onnx_model.graph().name() + "/init"); init_graph->set_name(onnx_model.graph().name() + "/init");
ValueInfoMap graph_value_infos{}; ValueInfoMap graph_value_infos{};
InitializerMap graph_initializer{}; InitializerMap graph_initializer{};
// Collect graph inputs.
for (const auto& vi : onnx_model.graph().input()) for (const auto& v : onnx_model.graph().input()) {
graph_value_infos[vi.name()].CopyFrom(vi); graph_value_infos[v.name()].CopyFrom(v);
}
for (const auto& vi : onnx_model.graph().output()) // Collect graph outputs.
graph_value_infos[vi.name()].CopyFrom(vi); for (const auto& v : onnx_model.graph().output()) {
graph_value_infos[v.name()].CopyFrom(v);
for (const auto& vi : onnx_model.graph().value_info()) }
graph_value_infos[vi.name()].CopyFrom(vi); // Collect graph values.
for (const auto& v : onnx_model.graph().value_info()) {
graph_value_infos[v.name()].CopyFrom(v);
}
// Collect graph initializers.
for (const auto& tensor : onnx_model.graph().initializer()) { for (const auto& tensor : onnx_model.graph().initializer()) {
if (include_initializers) { if (include_initializers) {
auto* op_def = init_graph->add_op(); auto* op_def = init_graph->add_op();
...@@ -76,17 +74,19 @@ void ONNXBackend::ONNXToDragon( ...@@ -76,17 +74,19 @@ void ONNXBackend::ONNXToDragon(
} }
graph_initializer[tensor.name()] = &tensor; graph_initializer[tensor.name()] = &tensor;
} }
// Convert to graph defs.
auto converter = [&](const ModelProto& model, GraphDef* graph) mutable { auto converter = [&](const ModelProto& model, GraphDef* graph) mutable {
for (const auto& node : model.graph().node()) { for (const auto& node : model.graph().node()) {
ValueInfoMap value_infos{}; ValueInfoMap value_infos{};
InitializerMap initializer{}; InitializerMap initializer{};
for (const auto& name : node.input()) { for (const auto& name : node.input()) {
if (graph_value_infos.count(name)) if (graph_value_infos.count(name)) {
value_infos[name].CopyFrom(graph_value_infos[name]); value_infos[name].CopyFrom(graph_value_infos[name]);
if (graph_initializer.count(name)) }
if (graph_initializer.count(name)) {
initializer[name] = graph_initializer[name]; initializer[name] = graph_initializer[name];
} }
}
auto onnx_node = ONNXNode(node); auto onnx_node = ONNXNode(node);
auto returns = ONNXNodeToOps( auto returns = ONNXNodeToOps(
init_model, init_model,
...@@ -98,23 +98,18 @@ void ONNXBackend::ONNXToDragon( ...@@ -98,23 +98,18 @@ void ONNXBackend::ONNXToDragon(
} }
} }
}; };
converter(pred_model, pred_graph); converter(pred_model, pred_graph);
// Add external inputs.
// Set(Initializer) + Set(Placehoders) = Set(Inputs)
Set<string> initializer; Set<string> initializer;
for (const auto& e : onnx_model.graph().initializer()) { for (const auto& v : onnx_model.graph().initializer()) {
initializer.insert(e.name()); initializer.insert(v.name());
} }
// Add External Inputs
for (const auto& e : onnx_model.graph().input()) { for (const auto& e : onnx_model.graph().input()) {
if (initializer.count(e.name()) == 0) { if (initializer.count(e.name()) == 0) {
pred_graph->add_input(e.name()); pred_graph->add_input(e.name());
} }
} }
// Add external outputs.
// Add External Outputs
for (const auto& e : onnx_model.graph().output()) { for (const auto& e : onnx_model.graph().output()) {
pred_graph->add_output(e.name()); pred_graph->add_output(e.name());
} }
......
#include "dragon/operators/array/channel_shuffle_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void ChannelShuffleOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
CHECK_EQ(X.dim(axis) % group_, 0)
<< "\nThe " << X.dim(axis) << " channels "
<< "can not be split into " << group_ << " groups.";
kernels::ChannelShuffle(
X.count(0, axis),
X.count(axis + 1),
X.dim(axis),
group_,
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void ChannelShuffleOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void ChannelShuffleGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0);
GET_OP_AXIS_ARG(axis, dY.ndim(), -1);
kernels::ChannelShuffle(
dY.count(0, axis),
dY.count(axis + 1),
dY.dim(axis),
dY.dim(axis) / group_,
dY.template data<T, Context>(),
dX->ReshapeLike(dY)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void ChannelShuffleGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(ChannelShuffle);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelShuffle);
#endif
DEPLOY_CPU_OPERATOR(ChannelShuffleGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelShuffleGradient);
#endif
OPERATOR_SCHEMA(ChannelShuffle)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(ChannelShuffleGradient)
/* dY */
.NumInputs(1)
/* dX */
.NumOutputs(1);
REGISTER_GRADIENT(ChannelShuffle, SimpleGradientMaker);
} // namespace dragon
#include "dragon/operators/array/shuffle_op.h"
#include "dragon/utils/math_functions.h"
namespace dragon {
template <class Context>
template <typename T>
void ChannelShuffleOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
CHECK_EQ(X.dim(axis) % group_, 0)
<< "\nThe " << X.dim(axis) << " channels "
<< "can not be split into " << group_ << " groups.";
auto G = group_, K = X.dim(axis) / group_;
if (def().type() == "ChannelShuffleGradient") std::swap(G, K);
math::Transpose(
4,
vec64_t({X.count(0, axis), G, K, X.count(axis + 1)}).data(),
vec64_t({0, 2, 1, 3}).data(),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
DEPLOY_CPU_OPERATOR(ChannelShuffle);
REGISTER_CPU_OPERATOR(ChannelShuffleGradient, ChannelShuffleOp<CPUContext>);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelShuffle);
REGISTER_CUDA_OPERATOR(ChannelShuffleGradient, ChannelShuffleOp<CUDAContext>);
#endif
OPERATOR_SCHEMA(ChannelShuffle).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(ChannelShuffleGradient).NumInputs(1).NumOutputs(1);
REGISTER_GRADIENT(ChannelShuffle, SimpleGradientMaker);
} // namespace dragon
...@@ -10,8 +10,8 @@ ...@@ -10,8 +10,8 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_ #ifndef DRAGON_OPERATORS_ARRAY_SHUFFLE_OP_H_
#define DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_ #define DRAGON_OPERATORS_ARRAY_SHUFFLE_OP_H_
#include "dragon/core/operator.h" #include "dragon/core/operator.h"
...@@ -25,7 +25,9 @@ class ChannelShuffleOp final : public Operator<Context> { ...@@ -25,7 +25,9 @@ class ChannelShuffleOp final : public Operator<Context> {
group_(OP_SINGLE_ARG(int64_t, "group", 1)) {} group_(OP_SINGLE_ARG(int64_t, "group", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
...@@ -34,22 +36,6 @@ class ChannelShuffleOp final : public Operator<Context> { ...@@ -34,22 +36,6 @@ class ChannelShuffleOp final : public Operator<Context> {
int64_t group_; int64_t group_;
}; };
template <class Context>
class ChannelShuffleGradientOp final : public Operator<Context> {
public:
ChannelShuffleGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
group_(OP_SINGLE_ARG(int64_t, "group", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
int64_t group_;
};
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_ #endif // DRAGON_OPERATORS_ARRAY_SHUFFLE_OP_H_
#include "dragon/operators/array/channel_affine_op.h" #include "dragon/operators/math/affine_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h" #include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void ChannelAffineOp<Context>::DoRunWithType() { void AffineOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0, {0}); auto &X = Input(0), &W = Input(1), *Y = Output(0, {0});
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
GET_OP_AXIS_ARG(end_axis, X.ndim(), axis);
vec64_t affine_dims( // Compute affine dimensions.
{X.dims().begin() + axis, X.dims().begin() + end_axis + 1}); vec64_t affine_dims;
for (auto axis : axes_) {
axis = axis < 0 ? axis + X.ndim() : axis;
CHECK(axis >= 0 && axis < X.ndim())
<< "\nExcepted the axis in [-" << X.ndim() << ", " << X.ndim()
<< "), got " << axis << ".";
affine_dims.push_back(X.dim(axis));
}
CHECK(W.dims() == affine_dims) CHECK(W.dims() == affine_dims)
<< "\nExcepted the weight shape is " << Tensor::DimString(affine_dims) << "\nExcepted the weight shape is " << Tensor::DimString(affine_dims)
<< ", got " << W.DimString() << "."; << ", got " << W.DimString() << ".";
...@@ -23,10 +27,11 @@ void ChannelAffineOp<Context>::DoRunWithType() { ...@@ -23,10 +27,11 @@ void ChannelAffineOp<Context>::DoRunWithType() {
<< ", got " << Input(2).DimString() << "."; << ", got " << Input(2).DimString() << ".";
} }
kernels::ChannelAffine( math::Affine(
X.count(0, axis), X.ndim(),
X.count(end_axis + 1), X.dims().data(),
X.count(axis, end_axis + 1), axes_.size(),
axes_.data(),
X.template data<T, Context>(), X.template data<T, Context>(),
W.template data<T, Context>(), W.template data<T, Context>(),
InputSize() <= 2 ? nullptr : Input(2).template data<T, Context>(), InputSize() <= 2 ? nullptr : Input(2).template data<T, Context>(),
...@@ -35,28 +40,30 @@ void ChannelAffineOp<Context>::DoRunWithType() { ...@@ -35,28 +40,30 @@ void ChannelAffineOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void ChannelAffineOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Numerical>::Call(this, Input(0));
}
template <class Context>
template <typename T> template <typename T>
void ChannelAffineGradientOp<Context>::DoRunWithType() { void AffineGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2); auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
GET_OP_AXIS_ARG(end_axis, X.ndim(), axis);
vec64_t affine_dims = {X.count(0, axis), // Compute reduce axes.
X.count(axis, end_axis + 1), vec64_t reduce_axes;
X.count(end_axis + 1)}, for (int i = 0; i < X.ndim(); ++i) {
affine_axes = {0, 2}; bool keep = true;
for (auto axis : axes_) {
axis = axis < 0 ? axis + X.ndim() : axis;
if (axis == i) keep = false;
}
if (keep) reduce_axes.push_back(i);
}
// Scratch to save the intermediates.
T* data = nullptr;
if (dW->has_name() && X.count() != W.count()) {
data = ctx()->workspace()->template data<T, Context>(X.count());
}
// dW = dY * X // dW = dY * X
if (dW->has_name()) { if (dW->has_name()) {
Output(1)->ReshapeLike(Input(1));
auto* x = Input(0).template data<T, Context>();
auto* dw = Output(1)->template mutable_data<T, Context>();
if (X.count() == W.count()) { if (X.count() == W.count()) {
math::Mul( math::Mul(
X.count(), X.count(),
...@@ -65,20 +72,19 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() { ...@@ -65,20 +72,19 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
dW->ReshapeLike(W)->template mutable_data<T, Context>(), dW->ReshapeLike(W)->template mutable_data<T, Context>(),
ctx()); ctx());
} else { } else {
T* scratch = ctx()->workspace()->template data<T, Context>(X.count());
math::Mul( math::Mul(
X.count(), X.count(),
dY.template data<T, Context>(), dY.template data<T, Context>(),
X.template data<T, Context>(), X.template data<T, Context>(),
scratch, data,
ctx()); ctx());
math::ReduceSum( math::ReduceSum(
3, X.ndim(),
affine_dims.data(), X.dims().data(),
2, reduce_axes.size(),
affine_axes.data(), reduce_axes.data(),
1.f, 1.f,
scratch, data,
dW->ReshapeLike(W)->template mutable_data<T, Context>(), dW->ReshapeLike(W)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -90,10 +96,10 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() { ...@@ -90,10 +96,10 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
dB->ReshapeLike(W)->CopyFrom(dY, ctx()); dB->ReshapeLike(W)->CopyFrom(dY, ctx());
} else { } else {
math::ReduceSum( math::ReduceSum(
3, X.ndim(),
affine_dims.data(), X.dims().data(),
2, reduce_axes.size(),
affine_axes.data(), reduce_axes.data(),
1.f, 1.f,
dY.template data<T, Context>(), dY.template data<T, Context>(),
dB->ReshapeLike(W)->template mutable_data<T, Context>(), dB->ReshapeLike(W)->template mutable_data<T, Context>(),
...@@ -103,11 +109,11 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() { ...@@ -103,11 +109,11 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
// dX = dY * W // dX = dY * W
if (dX->has_name()) { if (dX->has_name()) {
Output(0)->ReshapeLike(Input(-1)); math::Affine(
kernels::ChannelAffine( X.ndim(),
X.count(0, axis), X.dims().data(),
X.count(end_axis + 1), axes_.size(),
X.count(axis, end_axis + 1), axes_.data(),
dY.template data<T, Context>(), dY.template data<T, Context>(),
W.template data<T, Context>(), W.template data<T, Context>(),
(const T*)nullptr, (const T*)nullptr,
...@@ -116,22 +122,17 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() { ...@@ -116,22 +122,17 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
} }
} }
template <class Context> DEPLOY_CPU_OPERATOR(Affine);
void ChannelAffineGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(ChannelAffine);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelAffine); DEPLOY_CUDA_OPERATOR(Affine);
#endif #endif
DEPLOY_CPU_OPERATOR(ChannelAffineGradient); DEPLOY_CPU_OPERATOR(AffineGradient);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelAffineGradient); DEPLOY_CUDA_OPERATOR(AffineGradient);
#endif #endif
OPERATOR_SCHEMA(ChannelAffine) OPERATOR_SCHEMA(Affine)
/* X, W, B */ /* X, W, B */
.NumInputs(2, 3) .NumInputs(2, 3)
/* Y */ /* Y */
...@@ -139,7 +140,7 @@ OPERATOR_SCHEMA(ChannelAffine) ...@@ -139,7 +140,7 @@ OPERATOR_SCHEMA(ChannelAffine)
/* X => Y */ /* X => Y */
.AllowInplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ChannelAffineGradient) OPERATOR_SCHEMA(AffineGradient)
/* X, W, dY */ /* X, W, dY */
.NumInputs(3) .NumInputs(3)
/* dX, dW, dB */ /* dX, dW, dB */
...@@ -163,6 +164,6 @@ class GradientMaker final : public GradientMakerBase { ...@@ -163,6 +164,6 @@ class GradientMaker final : public GradientMakerBase {
} // namespace } // namespace
REGISTER_GRADIENT(ChannelAffine, GradientMaker); REGISTER_GRADIENT(Affine, GradientMaker);
} // namespace dragon } // namespace dragon
...@@ -10,37 +10,49 @@ ...@@ -10,37 +10,49 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_OPERATORS_ARRAY_CHANNEL_AFFINE_OP_H_ #ifndef DRAGON_OPERATORS_MATH_AFFINE_OP_H_
#define DRAGON_OPERATORS_ARRAY_CHANNEL_AFFINE_OP_H_ #define DRAGON_OPERATORS_MATH_AFFINE_OP_H_
#include "dragon/core/operator.h" #include "dragon/core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class ChannelAffineOp final : public Operator<Context> { class AffineOp final : public Operator<Context> {
public: public:
SIMPLE_CTOR_DTOR(ChannelAffineOp); AffineOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), axes_(OP_REPEATED_ARG(int64_t, "axes")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected:
vec64_t axes_;
}; };
template <class Context> template <class Context>
class ChannelAffineGradientOp final : public Operator<Context> { class AffineGradientOp final : public Operator<Context> {
public: public:
SIMPLE_CTOR_DTOR(ChannelAffineGradientOp); AffineGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), axes_(OP_REPEATED_ARG(int64_t, "axes")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected:
vec64_t axes_;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_CHANNEL_AFFINE_OP_H_ #endif // DRAGON_OPERATORS_MATH_AFFINE_OP_H_
...@@ -26,6 +26,7 @@ DISPATCH_WITH_TENSOR_TYPES(IsInf, dtypes::Floating, Input(0)); ...@@ -26,6 +26,7 @@ DISPATCH_WITH_TENSOR_TYPES(IsInf, dtypes::Floating, Input(0));
DISPATCH_WITH_TENSOR_TYPES(IsNaN, dtypes::Floating, Input(0)); DISPATCH_WITH_TENSOR_TYPES(IsNaN, dtypes::Floating, Input(0));
DISPATCH_WITH_TENSOR_TYPES(IsFinite, dtypes::Floating, Input(0)); DISPATCH_WITH_TENSOR_TYPES(IsFinite, dtypes::Floating, Input(0));
DISPATCH_WITH_TENSOR_TYPES(Pow, dtypes::Floating, Input(0)); DISPATCH_WITH_TENSOR_TYPES(Pow, dtypes::Floating, Input(0));
DISPATCH_WITH_TENSOR_TYPES(Atan2, dtypes::Floating, Input(0));
DISPATCH_WITH_TENSOR_TYPES(Minimum, dtypes::Numerical, Input(0)); DISPATCH_WITH_TENSOR_TYPES(Minimum, dtypes::Numerical, Input(0));
DISPATCH_WITH_TENSOR_TYPES(Maximum, dtypes::Numerical, Input(0)); DISPATCH_WITH_TENSOR_TYPES(Maximum, dtypes::Numerical, Input(0));
DISPATCH_WITH_TENSOR_TYPES(BitwiseNot, dtypes::Bitwise, Input(0)); DISPATCH_WITH_TENSOR_TYPES(BitwiseNot, dtypes::Bitwise, Input(0));
...@@ -120,6 +121,7 @@ DEFINE_INPLACE_UNARY_OP_IMPL(BitwiseNot, T); ...@@ -120,6 +121,7 @@ DEFINE_INPLACE_UNARY_OP_IMPL(BitwiseNot, T);
} }
DEFINE_SIMPLE_BINARY_OP_IMPL(Pow, T); DEFINE_SIMPLE_BINARY_OP_IMPL(Pow, T);
DEFINE_SIMPLE_BINARY_OP_IMPL(Atan2, T);
DEFINE_SIMPLE_BINARY_OP_IMPL(Minimum, T); DEFINE_SIMPLE_BINARY_OP_IMPL(Minimum, T);
DEFINE_SIMPLE_BINARY_OP_IMPL(Maximum, T); DEFINE_SIMPLE_BINARY_OP_IMPL(Maximum, T);
DEFINE_SIMPLE_BINARY_OP_IMPL(BitwiseAnd, T); DEFINE_SIMPLE_BINARY_OP_IMPL(BitwiseAnd, T);
...@@ -152,6 +154,7 @@ DEPLOY_CPU_OPERATOR(IsInf); ...@@ -152,6 +154,7 @@ DEPLOY_CPU_OPERATOR(IsInf);
DEPLOY_CPU_OPERATOR(IsNaN); DEPLOY_CPU_OPERATOR(IsNaN);
DEPLOY_CPU_OPERATOR(IsFinite); DEPLOY_CPU_OPERATOR(IsFinite);
DEPLOY_CPU_OPERATOR(Pow); DEPLOY_CPU_OPERATOR(Pow);
DEPLOY_CPU_OPERATOR(Atan2);
DEPLOY_CPU_OPERATOR(Minimum); DEPLOY_CPU_OPERATOR(Minimum);
DEPLOY_CPU_OPERATOR(Maximum); DEPLOY_CPU_OPERATOR(Maximum);
DEPLOY_CPU_OPERATOR(BitwiseNot); DEPLOY_CPU_OPERATOR(BitwiseNot);
...@@ -186,6 +189,7 @@ DEPLOY_CUDA_OPERATOR(IsInf); ...@@ -186,6 +189,7 @@ DEPLOY_CUDA_OPERATOR(IsInf);
DEPLOY_CUDA_OPERATOR(IsNaN); DEPLOY_CUDA_OPERATOR(IsNaN);
DEPLOY_CUDA_OPERATOR(IsFinite); DEPLOY_CUDA_OPERATOR(IsFinite);
DEPLOY_CUDA_OPERATOR(Pow); DEPLOY_CUDA_OPERATOR(Pow);
DEPLOY_CUDA_OPERATOR(Atan2);
DEPLOY_CUDA_OPERATOR(Minimum); DEPLOY_CUDA_OPERATOR(Minimum);
DEPLOY_CUDA_OPERATOR(Maximum); DEPLOY_CUDA_OPERATOR(Maximum);
DEPLOY_CUDA_OPERATOR(BitwiseNot); DEPLOY_CUDA_OPERATOR(BitwiseNot);
...@@ -222,6 +226,7 @@ OPERATOR_SCHEMA(IsNaN).NumInputs(1).NumOutputs(1); ...@@ -222,6 +226,7 @@ OPERATOR_SCHEMA(IsNaN).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(IsFinite).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(IsFinite).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Not).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(Not).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Pow).NumInputs(2).NumOutputs(1); OPERATOR_SCHEMA(Pow).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(Atan2).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(Minimum).NumInputs(2).NumOutputs(1); OPERATOR_SCHEMA(Minimum).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(Maximum).NumInputs(2).NumOutputs(1); OPERATOR_SCHEMA(Maximum).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(BitwiseAnd) OPERATOR_SCHEMA(BitwiseAnd)
...@@ -250,6 +255,7 @@ NO_GRADIENT(Round); ...@@ -250,6 +255,7 @@ NO_GRADIENT(Round);
NO_GRADIENT(IsInf); NO_GRADIENT(IsInf);
NO_GRADIENT(IsNaN); NO_GRADIENT(IsNaN);
NO_GRADIENT(IsFinite); NO_GRADIENT(IsFinite);
NO_GRADIENT(Atan2);
NO_GRADIENT(BitwiseNot); NO_GRADIENT(BitwiseNot);
NO_GRADIENT(BitwiseAnd); NO_GRADIENT(BitwiseAnd);
NO_GRADIENT(BitwiseOr); NO_GRADIENT(BitwiseOr);
......
...@@ -70,7 +70,7 @@ inline vec32_t CheckOutputAliases( ...@@ -70,7 +70,7 @@ inline vec32_t CheckOutputAliases(
return available_aliases; return available_aliases;
} }
// Unary ElementwiseOp // Unary ElementwiseOp.
DECLARE_ELEMENTWISE_OP(Abs); DECLARE_ELEMENTWISE_OP(Abs);
DECLARE_ELEMENTWISE_OP(Ceil); DECLARE_ELEMENTWISE_OP(Ceil);
DECLARE_ELEMENTWISE_OP(Cos); DECLARE_ELEMENTWISE_OP(Cos);
...@@ -101,12 +101,13 @@ DECLARE_ELEMENTWISE_OP(SignGradient); ...@@ -101,12 +101,13 @@ DECLARE_ELEMENTWISE_OP(SignGradient);
DECLARE_ELEMENTWISE_OP(SinGradient); DECLARE_ELEMENTWISE_OP(SinGradient);
DECLARE_ELEMENTWISE_OP(SqrtGradient); DECLARE_ELEMENTWISE_OP(SqrtGradient);
DECLARE_ELEMENTWISE_OP(SquareGradient); DECLARE_ELEMENTWISE_OP(SquareGradient);
// Binary ElementwiseOp // Binary ElementwiseOp.
DECLARE_ELEMENTWISE_OP(Add); DECLARE_ELEMENTWISE_OP(Add);
DECLARE_ELEMENTWISE_OP(Sub); DECLARE_ELEMENTWISE_OP(Sub);
DECLARE_ELEMENTWISE_OP(Mul); DECLARE_ELEMENTWISE_OP(Mul);
DECLARE_ELEMENTWISE_OP(Div); DECLARE_ELEMENTWISE_OP(Div);
DECLARE_ELEMENTWISE_OP(Pow); DECLARE_ELEMENTWISE_OP(Pow);
DECLARE_ELEMENTWISE_OP(Atan2);
DECLARE_ELEMENTWISE_OP(Minimum); DECLARE_ELEMENTWISE_OP(Minimum);
DECLARE_ELEMENTWISE_OP(Maximum); DECLARE_ELEMENTWISE_OP(Maximum);
DECLARE_ELEMENTWISE_OP(BitwiseAnd); DECLARE_ELEMENTWISE_OP(BitwiseAnd);
...@@ -128,7 +129,7 @@ DECLARE_ELEMENTWISE_OP(DivGradient); ...@@ -128,7 +129,7 @@ DECLARE_ELEMENTWISE_OP(DivGradient);
DECLARE_ELEMENTWISE_OP(PowGradient); DECLARE_ELEMENTWISE_OP(PowGradient);
DECLARE_ELEMENTWISE_OP(MinimumGradient); DECLARE_ELEMENTWISE_OP(MinimumGradient);
DECLARE_ELEMENTWISE_OP(MaximumGradient); DECLARE_ELEMENTWISE_OP(MaximumGradient);
// Trinary ElementwiseOp // Trinary ElementwiseOp.
DECLARE_ELEMENTWISE_OP(Where); DECLARE_ELEMENTWISE_OP(Where);
DECLARE_ELEMENTWISE_OP(WhereGradient); DECLARE_ELEMENTWISE_OP(WhereGradient);
#undef DECLARE_ELEMENTWISE_OP #undef DECLARE_ELEMENTWISE_OP
......
...@@ -199,11 +199,6 @@ void MatMulOp<Context>::DoRunWithType() { ...@@ -199,11 +199,6 @@ void MatMulOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void MatMulOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T> template <typename T>
void MatMulGradientOp<Context>::DoRunWithType() { void MatMulGradientOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), &dY = Input(2); auto &A = Input(0), &B = Input(1), &dY = Input(2);
...@@ -590,11 +585,6 @@ void MatMulGradientOp<Context>::DoRunWithType() { ...@@ -590,11 +585,6 @@ void MatMulGradientOp<Context>::DoRunWithType() {
} }
} }
template <class Context>
void MatMulGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(MatMul); DEPLOY_CPU_OPERATOR(MatMul);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(MatMul); DEPLOY_CUDA_OPERATOR(MatMul);
......
...@@ -23,7 +23,9 @@ class MatMulOp final : public Operator<Context> { ...@@ -23,7 +23,9 @@ class MatMulOp final : public Operator<Context> {
SIMPLE_CTOR_DTOR(MatMulOp); SIMPLE_CTOR_DTOR(MatMulOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
...@@ -35,7 +37,9 @@ class MatMulGradientOp final : public Operator<Context> { ...@@ -35,7 +37,9 @@ class MatMulGradientOp final : public Operator<Context> {
SIMPLE_CTOR_DTOR(MatMulGradientOp); SIMPLE_CTOR_DTOR(MatMulGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
......
#include "dragon/operators/array/channel_normalize_op.h" #include "dragon/operators/normalization/channel_norm_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
...@@ -6,7 +6,7 @@ namespace dragon { ...@@ -6,7 +6,7 @@ namespace dragon {
template <class Context> template <class Context>
template <typename InputT, typename OutputT> template <typename InputT, typename OutputT>
void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() { void ChannelNormOp<Context>::DoRunWithTypeAndCast() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1); GET_OP_AXIS_ARG(axis, X.ndim(), -1);
...@@ -30,7 +30,7 @@ void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() { ...@@ -30,7 +30,7 @@ void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() {
<< "\nProviding " << X_mean_.count() << " values to normalize Dimension(" << "\nProviding " << X_mean_.count() << " values to normalize Dimension("
<< Y_dims[axis] << ")."; << Y_dims[axis] << ").";
kernels::ChannelNormalize( kernels::ChannelNorm(
axis, axis,
num_dims, num_dims,
X_strides.data(), X_strides.data(),
...@@ -44,7 +44,7 @@ void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() { ...@@ -44,7 +44,7 @@ void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() {
template <class Context> template <class Context>
template <typename T> template <typename T>
void ChannelNormalizeOp<Context>::DoRunWithType() { void ChannelNormOp<Context>::DoRunWithType() {
if (data_type() == "float16") { if (data_type() == "float16") {
DoRunWithTypeAndCast<T, float16>(); DoRunWithTypeAndCast<T, float16>();
} else if (data_type() == "float32") { } else if (data_type() == "float32") {
...@@ -58,21 +58,21 @@ void ChannelNormalizeOp<Context>::DoRunWithType() { ...@@ -58,21 +58,21 @@ void ChannelNormalizeOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void ChannelNormalizeOp<Context>::RunOnDevice() { void ChannelNormOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Numerical>::Call(this, Input(0)); DispatchHelper<dtypes::Numerical>::Call(this, Input(0));
} }
DEPLOY_CPU_OPERATOR(ChannelNormalize); DEPLOY_CPU_OPERATOR(ChannelNorm);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelNormalize); DEPLOY_CUDA_OPERATOR(ChannelNorm);
#endif #endif
OPERATOR_SCHEMA(ChannelNormalize) OPERATOR_SCHEMA(ChannelNorm)
/* X */ /* X */
.NumInputs(1) .NumInputs(1)
/* Y */ /* Y */
.NumOutputs(1); .NumOutputs(1);
NO_GRADIENT(ChannelNormalize); NO_GRADIENT(ChannelNorm);
} // namespace dragon } // namespace dragon
...@@ -10,17 +10,17 @@ ...@@ -10,17 +10,17 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_OPERATORS_ARRAY_CHANNEL_NORMALIZE_OP_H_ #ifndef DRAGON_OPERATORS_NORMALIZATION_CHANNEL_NORM_OP_H_
#define DRAGON_OPERATORS_ARRAY_CHANNEL_NORMALIZE_OP_H_ #define DRAGON_OPERATORS_NORMALIZATION_CHANNEL_NORM_OP_H_
#include "dragon/core/operator.h" #include "dragon/core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class ChannelNormalizeOp final : public Operator<Context> { class ChannelNormOp final : public Operator<Context> {
public: public:
ChannelNormalizeOp(const OperatorDef& def, Workspace* ws) ChannelNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) { : Operator<Context>(def, ws) {
INITIALIZE_OP_REPEATED_ARG(int64_t, perm); INITIALIZE_OP_REPEATED_ARG(int64_t, perm);
auto mean = OP_REPEATED_ARG(float, "mean"); auto mean = OP_REPEATED_ARG(float, "mean");
...@@ -50,8 +50,8 @@ class ChannelNormalizeOp final : public Operator<Context> { ...@@ -50,8 +50,8 @@ class ChannelNormalizeOp final : public Operator<Context> {
DECLARE_OP_REPEATED_ARG(int64_t, perm); DECLARE_OP_REPEATED_ARG(int64_t, perm);
}; };
DEFINE_OP_REPEATED_ARG(int64_t, ChannelNormalizeOp, perm); DEFINE_OP_REPEATED_ARG(int64_t, ChannelNormOp, perm);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_CHANNEL_NORMALIZE_OP_H_ #endif // DRAGON_OPERATORS_NORMALIZATION_CHANNEL_NORM_OP_H_
#include "dragon/operators/normalization/lp_normalize_op.h" #include "dragon/operators/normalization/lp_norm_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void LpNormalizeOp<Context>::DoRunWithType() { void LpNormOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1); GET_OP_AXIS_ARG(axis, X.ndim(), -1);
GET_OP_AXIS_ARG(end_axis, X.ndim(), axis); GET_OP_AXIS_ARG(end_axis, X.ndim(), axis);
auto reduce_dim = X.count(axis, end_axis + 1); auto reduce_dim = X.count(axis, end_axis + 1);
// Normalize input with a scaled Lp-norm
if (p_ == 1) { if (p_ == 1) {
kernels::L1Normalize( kernels::L1Norm(
X.count(0, axis), X.count(0, axis),
X.count(end_axis + 1), X.count(end_axis + 1),
reduce_dim, reduce_dim,
...@@ -25,7 +22,7 @@ void LpNormalizeOp<Context>::DoRunWithType() { ...@@ -25,7 +22,7 @@ void LpNormalizeOp<Context>::DoRunWithType() {
Y->ReshapeLike(X)->template mutable_data<T, Context>(), Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx()); ctx());
} else if (p_ == 2) { } else if (p_ == 2) {
kernels::L2Normalize( kernels::L2Norm(
X.count(0, axis), X.count(0, axis),
X.count(end_axis + 1), X.count(end_axis + 1),
reduce_dim, reduce_dim,
...@@ -40,20 +37,15 @@ void LpNormalizeOp<Context>::DoRunWithType() { ...@@ -40,20 +37,15 @@ void LpNormalizeOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void LpNormalizeOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T> template <typename T>
void LpNormalizeGradientOp<Context>::DoRunWithType() { void LpNormGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &dY = Input(1), *dX = Output(0); auto &X = Input(0), &dY = Input(1), *dX = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1); GET_OP_AXIS_ARG(axis, X.ndim(), -1);
GET_OP_AXIS_ARG(end_axis, X.ndim(), axis); GET_OP_AXIS_ARG(end_axis, X.ndim(), axis);
auto reduce_dim = X.count(axis, end_axis + 1); auto reduce_dim = X.count(axis, end_axis + 1);
if (p_ == 1) { if (p_ == 1) {
kernels::L1NormalizeGrad( kernels::L1NormGrad(
X.count(0, axis), X.count(0, axis),
X.count(end_axis + 1), X.count(end_axis + 1),
reduce_dim, reduce_dim,
...@@ -64,7 +56,7 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() { ...@@ -64,7 +56,7 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
dX->ReshapeLike(X)->template mutable_data<T, Context>(), dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx()); ctx());
} else if (p_ == 2) { } else if (p_ == 2) {
kernels::L2NormalizeGrad( kernels::L2NormGrad(
X.count(0, axis), X.count(0, axis),
X.count(end_axis + 1), X.count(end_axis + 1),
reduce_dim, reduce_dim,
...@@ -79,33 +71,28 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() { ...@@ -79,33 +71,28 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
} }
} }
template <class Context> DEPLOY_CPU_OPERATOR(LpNorm);
void LpNormalizeGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(LpNormalize);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(LpNormalize); DEPLOY_CUDA_OPERATOR(LpNorm);
#endif #endif
DEPLOY_CPU_OPERATOR(LpNormalizeGradient); DEPLOY_CPU_OPERATOR(LpNormGradient);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(LpNormalizeGradient); DEPLOY_CUDA_OPERATOR(LpNormGradient);
#endif #endif
OPERATOR_SCHEMA(LpNormalize) OPERATOR_SCHEMA(LpNorm)
/* X */ /* X */
.NumInputs(1) .NumInputs(1)
/* Y */ /* Y */
.NumOutputs(1); .NumOutputs(1);
OPERATOR_SCHEMA(LpNormalizeGradient) OPERATOR_SCHEMA(LpNormGradient)
/* X, dY */ /* X, dY */
.NumInputs(2) .NumInputs(2)
/* dX */ /* dX */
.NumOutputs(1); .NumOutputs(1);
REGISTER_GRADIENT(LpNormalize, GenericGradientMaker); REGISTER_GRADIENT(LpNorm, GenericGradientMaker);
} // namespace dragon } // namespace dragon
...@@ -10,24 +10,26 @@ ...@@ -10,24 +10,26 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_OPERATORS_NORMALIZATION_LP_NORMALIZE_OP_H_ #ifndef DRAGON_OPERATORS_NORMALIZATION_LP_NORM_OP_H_
#define DRAGON_OPERATORS_NORMALIZATION_LP_NORMALIZE_OP_H_ #define DRAGON_OPERATORS_NORMALIZATION_LP_NORM_OP_H_
#include "dragon/core/operator.h" #include "dragon/core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class LpNormalizeOp final : public Operator<Context> { class LpNormOp final : public Operator<Context> {
public: public:
LpNormalizeOp(const OperatorDef& def, Workspace* ws) LpNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
p_(OP_SINGLE_ARG(int64_t, "p", 2)), p_(OP_SINGLE_ARG(int64_t, "p", 2)),
epsilon_(OP_SINGLE_ARG(double, "epsilon", 1e-12)), epsilon_(OP_SINGLE_ARG(double, "epsilon", 1e-12)),
reduction_(OP_SINGLE_ARG(string, "reduction", "SUM")) {} reduction_(OP_SINGLE_ARG(string, "reduction", "SUM")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
...@@ -39,16 +41,18 @@ class LpNormalizeOp final : public Operator<Context> { ...@@ -39,16 +41,18 @@ class LpNormalizeOp final : public Operator<Context> {
}; };
template <class Context> template <class Context>
class LpNormalizeGradientOp final : public Operator<Context> { class LpNormGradientOp final : public Operator<Context> {
public: public:
LpNormalizeGradientOp(const OperatorDef& def, Workspace* ws) LpNormGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
p_(OP_SINGLE_ARG(int64_t, "p", 2)), p_(OP_SINGLE_ARG(int64_t, "p", 2)),
epsilon_(OP_SINGLE_ARG(double, "epsilon", 1e-12)), epsilon_(OP_SINGLE_ARG(double, "epsilon", 1e-12)),
reduction_(OP_SINGLE_ARG(string, "reduction", "SUM")) {} reduction_(OP_SINGLE_ARG(string, "reduction", "SUM")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
...@@ -61,4 +65,4 @@ class LpNormalizeGradientOp final : public Operator<Context> { ...@@ -61,4 +65,4 @@ class LpNormalizeGradientOp final : public Operator<Context> {
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NORMALIZATION_LP_NORMALIZE_OP_H_ #endif // DRAGON_OPERATORS_NORMALIZATION_LP_NORM_OP_H_
...@@ -5,46 +5,41 @@ ...@@ -5,46 +5,41 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
void AdamOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) { template <typename T, typename CopyT>
void AdamOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
kernels::Adam( kernels::Adam(
dX->count(), dX->count(),
lr_ * correction_, lr_ * correction_,
beta1_, beta1_,
beta2_, beta2_,
eps_, eps_,
dX->template mutable_data<float, Context>(), this->weight_decay_,
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(), X->template data<T, Context>(),
Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(), dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
GetState("v")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx()); ctx());
} }
template <class Context> template <class Context>
void AdamWOp<Context>::ComputeUpdate(Tensor* dX, Tensor* X) { template <typename T, typename CopyT>
if (lambda_ > 0.f) { void AdamWOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
kernels::AdamW( kernels::AdamW(
dX->count(), dX->count(),
lr_ * correction_, lr_ * correction_,
beta1_, beta1_,
beta2_, beta2_,
eps_, eps_,
this->lr_ * lambda_, lr_ * this->weight_decay_,
X->template data<float, Context>(), X->template data<T, Context>(),
dX->template mutable_data<float, Context>(), dX->template mutable_data<T, Context>(),
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(), GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(), GetState("v")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx()); ctx());
} else {
kernels::Adam(
dX->count(),
lr_ * correction_,
beta1_,
beta2_,
eps_,
dX->template mutable_data<float, Context>(),
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx());
}
} }
DEPLOY_CPU_OPERATOR(Adam); DEPLOY_CPU_OPERATOR(Adam);
......
...@@ -5,16 +5,21 @@ ...@@ -5,16 +5,21 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
void RMSpropOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) { template <typename T, typename CopyT>
void RMSpropOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
kernels::RMSprop( kernels::RMSprop(
dX->count(), dX->count(),
lr_, lr_,
momentum_, momentum_,
decay_, alpha_,
eps_, eps_,
dX->template mutable_data<float, Context>(), this->weight_decay_,
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(), X->template data<T, Context>(),
Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(), dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
GetState("v")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx()); ctx());
} }
......
...@@ -6,33 +6,44 @@ ...@@ -6,33 +6,44 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
void MomentumSGDOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) { template <typename T, typename CopyT>
void MomentumSGDOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
kernels::MomentumSGD( kernels::MomentumSGD(
dX->count(), dX->count(),
lr_, lr_,
momentum_, momentum_,
dX->template mutable_data<float, Context>(), this->weight_decay_,
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(), X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx()); ctx());
} }
template <class Context> template <class Context>
void NesterovSGDOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) { template <typename T, typename CopyT>
void NesterovSGDOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
kernels::NesterovSGD( kernels::NesterovSGD(
dX->count(), dX->count(),
lr_, lr_,
momentum_, momentum_,
dX->template mutable_data<float, Context>(), this->weight_decay_,
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(), X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx()); ctx());
} }
template <class Context> template <class Context>
void LARSOp<Context>::ComputeUpdate(Tensor* dX, Tensor* X) { template <typename T, typename CopyT>
void LARSOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
float trust_ratio = 0.f; float trust_ratio = 0.f;
if (trust_coef_ > 0.f) { if (trust_coef_ > 0.f) {
auto* x = X->template data<float, Context>(); auto* x = X->template data<T, Context>();
auto* dx = dX->template mutable_data<float, Context>(); auto* dx = dX->template mutable_data<T, Context>();
float x_norm = std::sqrt(math::Dot(X->count(), x, x, ctx())); float x_norm = std::sqrt(math::Dot(X->count(), x, x, ctx()));
float dx_norm = std::sqrt(math::Dot(dX->count(), dx, dx, ctx())); float dx_norm = std::sqrt(math::Dot(dX->count(), dx, dx, ctx()));
if (x_norm > 0.f && dx_norm > 0.f) { if (x_norm > 0.f && dx_norm > 0.f) {
...@@ -43,16 +54,20 @@ void LARSOp<Context>::ComputeUpdate(Tensor* dX, Tensor* X) { ...@@ -43,16 +54,20 @@ void LARSOp<Context>::ComputeUpdate(Tensor* dX, Tensor* X) {
math::Scale( math::Scale(
dX->count(), dX->count(),
trust_ratio, trust_ratio,
dX->template data<float, Context>(), dX->template data<T, Context>(),
dX->template mutable_data<float, Context>(), dX->template mutable_data<T, Context>(),
ctx()); ctx());
} }
kernels::MomentumSGD( kernels::MomentumSGD(
dX->count(), dX->count(),
lr_, lr_,
momentum_, momentum_,
dX->template mutable_data<float, Context>(), this->weight_decay_,
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(), X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx()); ctx());
} }
......
...@@ -37,17 +37,14 @@ class UpdateOpBase : public Operator<Context> { ...@@ -37,17 +37,14 @@ class UpdateOpBase : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> template <typename T>
void TransformGrad(Tensor* dX, Tensor* X); void TransformGrad(Tensor* dX);
virtual void ComputeUpdate(Tensor* dX, Tensor* X) = 0; virtual void ApplyUpdate(Tensor* dX, Tensor* X, Tensor* Y) = 0;
template <typename T>
void ApplyUpdate(Tensor* dX, Tensor* X);
template <typename T> template <typename T>
T GetHyper(const string& key); T GetHyper(const string& key);
Tensor* Slot(const string& key); Tensor* GetState(const string& key);
protected: protected:
int weight_index_; int weight_index_;
...@@ -57,7 +54,24 @@ class UpdateOpBase : public Operator<Context> { ...@@ -57,7 +54,24 @@ class UpdateOpBase : public Operator<Context> {
#define USE_UPDATE_FUNCTIONS \ #define USE_UPDATE_FUNCTIONS \
using UpdateOpBase<Context>::GetHyper; \ using UpdateOpBase<Context>::GetHyper; \
using UpdateOpBase<Context>::Slot using UpdateOpBase<Context>::GetState; \
void ApplyUpdate(Tensor* dX, Tensor* X, Tensor* Y) override { \
if (dX->template IsType<float>()) { \
if (Y == nullptr) { \
DoRunWithType<float, float>(dX, X, Y); \
} else if (Y->template IsType<float16>()) { \
DoRunWithType<float, float16>(dX, X, Y); \
} else { \
LOG(FATAL) << MessageForUnsupported( \
dtypes::to_string(Y->meta()), {"float16", "float32"}); \
} \
} else if (dX->template IsType<double>()) { \
DoRunWithType<double, double>(dX, X, Y); \
} else { \
LOG(FATAL) << MessageForUnsupported( \
dtypes::to_string(dX->meta()), {"float32", "float64"}); \
} \
}
template <class Context> template <class Context>
class MomentumSGDOp final : public UpdateOpBase<Context> { class MomentumSGDOp final : public UpdateOpBase<Context> {
...@@ -73,7 +87,8 @@ class MomentumSGDOp final : public UpdateOpBase<Context> { ...@@ -73,7 +87,8 @@ class MomentumSGDOp final : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments(); UpdateOpBase<Context>::GetArguments();
} }
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override; template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected: protected:
float lr_, momentum_; float lr_, momentum_;
...@@ -93,7 +108,8 @@ class NesterovSGDOp final : public UpdateOpBase<Context> { ...@@ -93,7 +108,8 @@ class NesterovSGDOp final : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments(); UpdateOpBase<Context>::GetArguments();
} }
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override; template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected: protected:
float lr_, momentum_; float lr_, momentum_;
...@@ -110,15 +126,16 @@ class RMSpropOp final : public UpdateOpBase<Context> { ...@@ -110,15 +126,16 @@ class RMSpropOp final : public UpdateOpBase<Context> {
void GetArguments() override { void GetArguments() override {
lr_ = this->template GetHyper<float>("lr"); lr_ = this->template GetHyper<float>("lr");
momentum_ = this->template GetHyper<float>("momentum"); momentum_ = this->template GetHyper<float>("momentum");
decay_ = this->template GetHyper<float>("decay"); alpha_ = this->template GetHyper<float>("alpha");
eps_ = this->template GetHyper<float>("eps"); eps_ = this->template GetHyper<float>("eps");
UpdateOpBase<Context>::GetArguments(); UpdateOpBase<Context>::GetArguments();
} }
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override; template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected: protected:
float lr_, momentum_, decay_, eps_; float lr_, momentum_, alpha_, eps_;
}; };
template <class Context> template <class Context>
...@@ -139,7 +156,8 @@ class AdamOp : public UpdateOpBase<Context> { ...@@ -139,7 +156,8 @@ class AdamOp : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments(); UpdateOpBase<Context>::GetArguments();
} }
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override; template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected: protected:
int64_t t_; int64_t t_;
...@@ -163,16 +181,15 @@ class AdamWOp final : public UpdateOpBase<Context> { ...@@ -163,16 +181,15 @@ class AdamWOp final : public UpdateOpBase<Context> {
t_++; t_++;
correction_ = sqrt(1.f - pow(beta2_, t_)) / (1.f - pow(beta1_, t_)); correction_ = sqrt(1.f - pow(beta2_, t_)) / (1.f - pow(beta1_, t_));
UpdateOpBase<Context>::GetArguments(); UpdateOpBase<Context>::GetArguments();
lambda_ = this->weight_decay_;
this->weight_decay_ = 0.f;
} }
void ComputeUpdate(Tensor* dX, Tensor* X) override; template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected: protected:
int64_t t_; int64_t t_;
float lr_, beta1_, beta2_; float lr_, beta1_, beta2_;
float eps_, correction_, lambda_; float eps_, correction_;
}; };
template <class Context> template <class Context>
...@@ -190,14 +207,13 @@ class LARSOp final : public UpdateOpBase<Context> { ...@@ -190,14 +207,13 @@ class LARSOp final : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments(); UpdateOpBase<Context>::GetArguments();
} }
void ComputeUpdate(Tensor* dX, Tensor* X) override; template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected: protected:
float lr_, momentum_, trust_coef_; float lr_, momentum_, trust_coef_;
}; };
#undef USE_UPDATE_FUNCTIONS
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_UPDATE_OP_H_ #endif // DRAGON_OPERATORS_TRAINING_UPDATE_OP_H_
...@@ -13,67 +13,40 @@ T UpdateOpBase<Context>::GetHyper(const string& key) { ...@@ -13,67 +13,40 @@ T UpdateOpBase<Context>::GetHyper(const string& key) {
} }
template <class Context> template <class Context>
Tensor* UpdateOpBase<Context>::Slot(const string& key) { Tensor* UpdateOpBase<Context>::GetState(const string& key) {
const string& weight_name = Output(weight_index_)->name(); const string& weight_name = Output(weight_index_)->name();
return workspace()->CreateTensor(name() + "/" + weight_name + "/" + key); return workspace()->CreateTensor(name() + "/" + weight_name + "/" + key);
} }
template <class Context> template <class Context>
template <typename T> template <typename T>
void UpdateOpBase<Context>::TransformGrad(Tensor* dX, Tensor* X) { void UpdateOpBase<Context>::TransformGrad(Tensor* dX) {
// Scale.
if (grad_scale_ != 1.f) { if (grad_scale_ != 1.f) {
auto* dx = dX->template mutable_data<T, Context>(); auto* dx = dX->template mutable_data<T, Context>();
math::Scale(dX->count(), grad_scale_, dx, dx, ctx()); math::Scale(dX->count(), grad_scale_, dx, dx, ctx());
} }
// Clip.
if (clip_norm_ > 0.f) { if (clip_norm_ > 0.f) {
auto* dx = dX->template mutable_data<T, Context>(); auto* dx = dX->template mutable_data<T, Context>();
float grad_norm = std::sqrt(math::Dot(dX->count(), dx, dx, ctx())); float norm = std::sqrt(math::Dot(dX->count(), dx, dx, ctx()));
if (grad_norm > clip_norm_) { if (norm > clip_norm_) {
math::Scale(dX->count(), clip_norm_ / grad_norm, dx, dx, ctx()); math::Scale(dX->count(), clip_norm_ / norm, dx, dx, ctx());
} }
} else if (clip_value_ > 0.f) { } else if (clip_value_ > 0.f) {
auto* dx = dX->template mutable_data<T, Context>(); auto* dx = dX->template mutable_data<T, Context>();
kernels::Clip(dX->count(), -clip_value_, clip_value_, dx, dx, ctx()); kernels::Clip(dX->count(), -clip_value_, clip_value_, dx, dx, ctx());
} }
// Penalty.
if (weight_decay_ > 0.f) {
math::Axpy(
X->count(),
weight_decay_,
X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
ctx());
}
}
template <class Context>
template <typename T>
void UpdateOpBase<Context>::ApplyUpdate(Tensor* dX, Tensor* X) {
math::Sub(
X->count(),
X->template data<T, Context>(),
dX->template data<T, Context>(),
X->template mutable_data<T, Context>(),
ctx());
} }
template <class Context> template <class Context>
void UpdateOpBase<Context>::RunOnDevice() { void UpdateOpBase<Context>::RunOnDevice() {
GetArguments(); GetArguments();
for (int i = 0; i < InputSize(); ++i) { for (weight_index_ = 0; weight_index_ < InputSize(); ++weight_index_) {
weight_index_ = i; auto &dX = Input(weight_index_), *X = Output(weight_index_);
auto &dX = Input(i), *X = Output(i); if (dX.count() == 0 || X->count() == 0) continue;
if (dX.count() == 0 || X->count() == 0) return;
CHECK(dX.dims() == X->dims()) CHECK(dX.dims() == X->dims())
<< "\nWeight and grad should have the same dimensions." << "\nWeight and grad should have the same dimensions."
<< "\nGot" << X->DimString() << " and " << dX.DimString(); << "\nGot" << X->DimString() << " and " << dX.DimString();
if (dX.template IsType<float>()) { if (dX.template IsType<float16>()) {
TransformGrad<float>(&dX, X);
ComputeUpdate(&dX, X);
ApplyUpdate<float>(&dX, X);
} else if (dX.template IsType<float16>()) {
auto* X_master = workspace()->CreateTensor(X->name() + "_master"); auto* X_master = workspace()->CreateTensor(X->name() + "_master");
auto* X_grad = ctx()->workspace()->CreateTensor("BufferShared"); auto* X_grad = ctx()->workspace()->CreateTensor("BufferShared");
if (X_master->count() != X->count()) { if (X_master->count() != X->count()) {
...@@ -88,17 +61,17 @@ void UpdateOpBase<Context>::RunOnDevice() { ...@@ -88,17 +61,17 @@ void UpdateOpBase<Context>::RunOnDevice() {
dX.template data<float16, Context>(), dX.template data<float16, Context>(),
X_grad->ReshapeLike(dX)->template mutable_data<float, Context>(), X_grad->ReshapeLike(dX)->template mutable_data<float, Context>(),
ctx()); ctx());
TransformGrad<float>(X_grad, X_master); TransformGrad<float>(X_grad);
ComputeUpdate(X_grad, X_master); ApplyUpdate(X_grad, X_master, X);
ApplyUpdate<float>(X_grad, X_master); } else if (dX.template IsType<float>()) {
math::Cast( TransformGrad<float>(&dX);
X->count(), ApplyUpdate(&dX, X, nullptr);
X_master->template data<float, Context>(), } else if (dX.template IsType<double>()) {
X->template mutable_data<float16, Context>(), TransformGrad<double>(&dX);
ctx()); ApplyUpdate(&dX, X, nullptr);
} else { } else {
LOG(FATAL) << MessageForUnsupported( LOG(FATAL) << MessageForUnsupported(
dtypes::to_string(dX.meta()), {"float16", "float32"}); dtypes::to_string(dX.meta()), {"float16", "float32", "float64"});
} }
} }
} }
......
...@@ -58,9 +58,6 @@ from dragon.core.ops import tensor_ops as _ ...@@ -58,9 +58,6 @@ from dragon.core.ops import tensor_ops as _
from dragon.core.ops.array_ops import assign from dragon.core.ops.array_ops import assign
from dragon.core.ops.array_ops import boolean_mask from dragon.core.ops.array_ops import boolean_mask
from dragon.core.ops.array_ops import broadcast_to from dragon.core.ops.array_ops import broadcast_to
from dragon.core.ops.array_ops import channel_affine
from dragon.core.ops.array_ops import channel_normalize
from dragon.core.ops.array_ops import channel_shuffle
from dragon.core.ops.array_ops import concat from dragon.core.ops.array_ops import concat
from dragon.core.ops.array_ops import expand_dims from dragon.core.ops.array_ops import expand_dims
from dragon.core.ops.array_ops import flatten from dragon.core.ops.array_ops import flatten
......
...@@ -21,6 +21,7 @@ from dragon.core.device.cuda import current_device ...@@ -21,6 +21,7 @@ from dragon.core.device.cuda import current_device
from dragon.core.device.cuda import get_device_capability from dragon.core.device.cuda import get_device_capability
from dragon.core.device.cuda import is_available from dragon.core.device.cuda import is_available
from dragon.core.device.cuda import memory_allocated from dragon.core.device.cuda import memory_allocated
from dragon.core.device.cuda import set_cublas_flags
from dragon.core.device.cuda import set_cudnn_flags from dragon.core.device.cuda import set_cudnn_flags
from dragon.core.device.cuda import set_default_device from dragon.core.device.cuda import set_default_device
from dragon.core.device.cuda import set_device from dragon.core.device.cuda import set_device
......
...@@ -17,8 +17,10 @@ from dragon.core.ops.activation_ops import sigmoid ...@@ -17,8 +17,10 @@ from dragon.core.ops.activation_ops import sigmoid
from dragon.core.ops.activation_ops import tanh from dragon.core.ops.activation_ops import tanh
from dragon.core.ops.math_ops import abs from dragon.core.ops.math_ops import abs
from dragon.core.ops.math_ops import add from dragon.core.ops.math_ops import add
from dragon.core.ops.math_ops import affine
from dragon.core.ops.math_ops import argmax from dragon.core.ops.math_ops import argmax
from dragon.core.ops.math_ops import argmin from dragon.core.ops.math_ops import argmin
from dragon.core.ops.math_ops import atan2
from dragon.core.ops.math_ops import ceil from dragon.core.ops.math_ops import ceil
from dragon.core.ops.math_ops import clip from dragon.core.ops.math_ops import clip
from dragon.core.ops.math_ops import cos from dragon.core.ops.math_ops import cos
...@@ -60,7 +62,6 @@ from dragon.core.ops.math_ops import sqrt ...@@ -60,7 +62,6 @@ from dragon.core.ops.math_ops import sqrt
from dragon.core.ops.math_ops import square from dragon.core.ops.math_ops import square
from dragon.core.ops.math_ops import sub from dragon.core.ops.math_ops import sub
from dragon.core.ops.math_ops import sum from dragon.core.ops.math_ops import sum
from dragon.core.ops.normalization_ops import lp_normalize
from dragon.core.ops.sort_ops import top_k from dragon.core.ops.sort_ops import top_k
__all__ = [_s for _s in dir() if not _s.startswith('_')] __all__ = [_s for _s in dir() if not _s.startswith('_')]
...@@ -34,12 +34,15 @@ from dragon.core.ops.activation_ops import relu6 ...@@ -34,12 +34,15 @@ from dragon.core.ops.activation_ops import relu6
from dragon.core.ops.activation_ops import selu from dragon.core.ops.activation_ops import selu
from dragon.core.ops.activation_ops import silu from dragon.core.ops.activation_ops import silu
from dragon.core.ops.activation_ops import softmax from dragon.core.ops.activation_ops import softmax
from dragon.core.ops.array_ops import channel_shuffle
from dragon.core.ops.math_ops import moments from dragon.core.ops.math_ops import moments
from dragon.core.ops.normalization_ops import batch_norm from dragon.core.ops.normalization_ops import batch_norm
from dragon.core.ops.normalization_ops import channel_norm
from dragon.core.ops.normalization_ops import group_norm from dragon.core.ops.normalization_ops import group_norm
from dragon.core.ops.normalization_ops import instance_norm from dragon.core.ops.normalization_ops import instance_norm
from dragon.core.ops.normalization_ops import layer_norm from dragon.core.ops.normalization_ops import layer_norm
from dragon.core.ops.normalization_ops import local_response_norm from dragon.core.ops.normalization_ops import local_response_norm
from dragon.core.ops.normalization_ops import lp_norm
from dragon.core.ops.normalization_ops import sync_batch_norm from dragon.core.ops.normalization_ops import sync_batch_norm
from dragon.core.ops.vision_ops import bias_add from dragon.core.ops.vision_ops import bias_add
from dragon.core.ops.vision_ops import conv from dragon.core.ops.vision_ops import conv
......
...@@ -78,16 +78,13 @@ def cast_args(**kwargs): ...@@ -78,16 +78,13 @@ def cast_args(**kwargs):
return {'dtype': kwargs.get('dtype', 'float32')} return {'dtype': kwargs.get('dtype', 'float32')}
@register('ChannelAffine') @register('Affine')
def channel_affine_args(**kwargs): def affine_args(**kwargs):
return { return {'axes': kwargs.get('axes', None)}
'axis': kwargs.get('axis', -1),
'end_axis': kwargs.get('end_axis', kwargs.get('axis', -1)),
}
@register('ChannelNormalize') @register('ChannelNorm')
def channel_normalize_args(**kwargs): def channel_norm_args(**kwargs):
return { return {
'axis': kwargs.get('axis', -1), 'axis': kwargs.get('axis', -1),
'mean': kwargs.get('mean', None), 'mean': kwargs.get('mean', None),
...@@ -323,8 +320,8 @@ def loss_args(**kwargs): ...@@ -323,8 +320,8 @@ def loss_args(**kwargs):
return {'reduction': kwargs.get('reduction', 'MEAN')} return {'reduction': kwargs.get('reduction', 'MEAN')}
@register('LpNormalize') @register('LpNorm')
def lp_normalize_args(**kwargs): def lp_norm_args(**kwargs):
return { return {
'p': kwargs.get('p', 2), 'p': kwargs.get('p', 2),
'axis': kwargs.get('axis', -1), 'axis': kwargs.get('axis', -1),
......
...@@ -81,6 +81,7 @@ def binary_shape_spec(inputs, outputs): ...@@ -81,6 +81,7 @@ def binary_shape_spec(inputs, outputs):
@register([ @register([
'Add', 'Add',
'Atan2',
'BitwiseAnd', 'BitwiseAnd',
'BitwiseOr', 'BitwiseOr',
'BitwiseXor', 'BitwiseXor',
...@@ -403,7 +404,7 @@ def gemm_spec(args, inputs, outputs): ...@@ -403,7 +404,7 @@ def gemm_spec(args, inputs, outputs):
return outputs return outputs
@register('ChannelNormalize') @register('ChannelNorm')
def channel_normalize_spec(args, inputs, outputs): def channel_normalize_spec(args, inputs, outputs):
outputs[0]._dtype = args['dtype'] outputs[0]._dtype = args['dtype']
try: try:
......
...@@ -62,11 +62,23 @@ def current_device(): ...@@ -62,11 +62,23 @@ def current_device():
return backend.cudaGetDevice() return backend.cudaGetDevice()
def set_cublas_flags(allow_tf32=None):
"""Set the flags of cuBLAS library.
Parameters
----------
allow_tf32 : bool, optional, default=False
Allow TF32 tensor core operation or not.
"""
backend.cublasSetFlags(-1 if allow_tf32 is None else allow_tf32)
def set_cudnn_flags( def set_cudnn_flags(
enabled=True, enabled=None,
benchmark=False, benchmark=None,
deterministic=False, deterministic=None,
allow_tf32=False, allow_tf32=None,
): ):
"""Set the flags of cuDNN library. """Set the flags of cuDNN library.
...@@ -82,7 +94,11 @@ def set_cudnn_flags( ...@@ -82,7 +94,11 @@ def set_cudnn_flags(
Allow TF32 tensor core operation or not. Allow TF32 tensor core operation or not.
""" """
backend.cudnnSetFlags(enabled, benchmark, deterministic, allow_tf32) backend.cudnnSetFlags(
-1 if enabled is None else enabled,
-1 if benchmark is None else benchmark,
-1 if deterministic is None else deterministic,
-1 if allow_tf32 is None else allow_tf32)
def get_device_capability(device_index=None): def get_device_capability(device_index=None):
......
...@@ -122,107 +122,16 @@ def broadcast_to(inputs, shape, **kwargs): ...@@ -122,107 +122,16 @@ def broadcast_to(inputs, shape, **kwargs):
return OpLib.add('Expand', **args) return OpLib.add('Expand', **args)
@OpSchema.num_inputs(2, 3)
def channel_affine(inputs, axis=-1, end_axis=None, **kwargs):
r"""Apply affine transformation to each channel of input.
Parameters
----------
inputs : Sequence[dragon.Tensor]
The input, weight and optional bias tensor.
axis : int, optional, default=-1
The first channel axis.
end_axis : int, optional
The last channel axis.
Returns
-------
dragon.Tensor
The output tensor.
"""
outputs = kwargs.pop('outputs', [None])
if context.executing_eagerly():
return OpLib.execute(
'ChannelAffine', inputs, outputs=outputs,
axis=axis, end_axis=end_axis)
return OpLib.add('ChannelAffine', inputs,
axis=axis, end_axis=end_axis, **kwargs)
@OpSchema.num_inputs(1)
@OpSchema.convert_arg('perm')
def channel_normalize(
inputs,
mean,
std,
axis=-1,
dtype='float32',
perm=None,
**kwargs
):
"""Apply normalization to each channel of input.
:attr:`axis` can be negative:
```python
m = s = (1., 1., 1.)
x = dragon.constant([1, 2, 3])
print(dragon.channel_normalize(x, m, s, axis=0)) # [0., 1., 2.]
print(dragon.channel_normalize(x, m, s, axis=-1)) # Equivalent
```
If :attr:`perm` provided, :attr:`axis` is selected from the output layout:
```python
m, s = (1., 2., 3.), (1., 1., 1.)
x = dragon.constant([[1, 2, 3]])
# Provided 3 values to normalize the last axis
# with length 1, only the first value will be taken
print(dragon.channel_normalize(x, m, s, perm=(1, 0))) # [[0.], [1.], [2.]]
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
mean : Sequence[float], required
The mean to subtract.
std : Sequence[float], required
The standard deviation to divide.
axis : int, optional, default=-1
The channel axis.
dtype : str, optional, default='float32'
The output data type.
perm : Sequence[Union[int, dragon.Tensor]], optional
The output permutation.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = OpSchema.parse_args(locals())
if context.executing_eagerly():
return OpLib.execute(
'ChannelNormalize', inputs,
axis=axis, mean=mean, std=std, dtype=dtype,
ndim=len(args['perm']) if perm is not None else 0,
perm=args['perm'])
return OpLib.add('ChannelNormalize', **args)
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def channel_shuffle(inputs, axis=-1, group=1, **kwargs): def channel_shuffle(inputs, axis=-1, group=1, **kwargs):
"""Apply group shuffle to each channel of input. """Apply the group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_. `[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
Examples: Examples:
```python ```python
x = dragon.constant([1, 2, 3, 4]) x = dragon.constant([1, 2, 3, 4])
print(dragon.channel_shuffle(x, group=2)) # [1, 3, 2, 4] print(dragon.nn.channel_shuffle(x, group=2)) # [1, 3, 2, 4]
``` ```
Parameters Parameters
......
...@@ -82,6 +82,30 @@ def add(inputs, **kwargs): ...@@ -82,6 +82,30 @@ def add(inputs, **kwargs):
return OpLib.add('Add', inputs, **kwargs) return OpLib.add('Add', inputs, **kwargs)
@OpSchema.num_inputs(2, 3)
def affine(inputs, axis=-1, **kwargs):
"""Apply affine transformation to input.
Parameters
----------
inputs : Sequence[dragon.Tensor]
The input, scale and bias tensor.
axis : Union[int, Sequence[int]], optional, default=-1
The axis to apply.
Returns
-------
dragon.Tensor
The output tensor.
"""
axes = nest.flatten(axis)
outputs = kwargs.pop('outputs', [None])
if context.executing_eagerly():
return OpLib.execute('Affine', inputs, outputs=outputs, axes=axes)
return OpLib.add('Affine', inputs, axes=axes, **kwargs)
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def argmax(inputs, axis=0, keepdims=False, **kwargs): def argmax(inputs, axis=0, keepdims=False, **kwargs):
"""Compute the index of maximum elements along the given axis. """Compute the index of maximum elements along the given axis.
...@@ -149,6 +173,37 @@ def argmin(inputs, axis=0, keepdims=False, **kwargs): ...@@ -149,6 +173,37 @@ def argmin(inputs, axis=0, keepdims=False, **kwargs):
@OpSchema.num_inputs(2) @OpSchema.num_inputs(2)
def atan2(inputs, **kwargs):
r"""Compute the element-wise arc-tangent of two arguments.
.. math:: \text{out} = \text{arctan}(\frac{\text{input1}}{\text{input2}})
Examples:
```python
y = dragon.constant(1)
x = dragon.constant(2)
print(dragon.math.atan2([y, x])) # 0.46364761
```
Parameters
----------
inputs : Sequence[dragon.Tensor]
The input1 and input2 tensor.
Returns
-------
dragon.Tensor
The output tensor.
"""
inputs = constant_ops.remove_scalars(inputs)
if context.executing_eagerly():
return OpLib.execute('Atan2', inputs)
return OpLib.add('Atan2', inputs, **kwargs)
@OpSchema.num_inputs(2)
def bitwise_and(inputs, **kwargs): def bitwise_and(inputs, **kwargs):
r"""Compute the element-wise AND bitwise operation. r"""Compute the element-wise AND bitwise operation.
......
...@@ -72,6 +72,69 @@ def batch_norm( ...@@ -72,6 +72,69 @@ def batch_norm(
return OpLib.add('BatchNorm', **args) return OpLib.add('BatchNorm', **args)
@OpSchema.num_inputs(1)
@OpSchema.convert_arg('perm')
def channel_norm(
inputs,
mean,
std,
axis=-1,
dtype='float32',
perm=None,
**kwargs
):
"""Apply the normalization to each channel of input.
:attr:`axis` can be negative:
```python
m = s = (1., 1., 1.)
x = dragon.constant([1, 2, 3])
print(dragon.nn.channel_norm(x, m, s, axis=0)) # [0., 1., 2.]
print(dragon.nn.channel_norm(x, m, s, axis=-1)) # Equivalent
```
If :attr:`perm` provided, :attr:`axis` is selected from the output layout:
```python
m, s = (1., 2., 3.), (1., 1., 1.)
x = dragon.constant([[1, 2, 3]])
# Provided 3 values to normalize the last axis
# with length 1, only the first value will be taken
print(dragon.nn.channel_norm(x, m, s, perm=(1, 0))) # [[0.], [1.], [2.]]
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
mean : Sequence[float], required
The mean to subtract.
std : Sequence[float], required
The standard deviation to divide.
axis : int, optional, default=-1
The channel axis.
dtype : str, optional, default='float32'
The output data type.
perm : Sequence[Union[int, dragon.Tensor]], optional
The output permutation.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = OpSchema.parse_args(locals())
if context.executing_eagerly():
return OpLib.execute(
'ChannelNorm', inputs,
axis=axis, mean=mean, std=std, dtype=dtype,
ndim=len(args['perm']) if perm is not None else 0,
perm=args['perm'])
return OpLib.add('ChannelNorm', **args)
@OpSchema.num_inputs(3) @OpSchema.num_inputs(3)
def group_norm(inputs, axis=-1, group=0, epsilon=1e-5, **kwargs): def group_norm(inputs, axis=-1, group=0, epsilon=1e-5, **kwargs):
r"""Apply the group normalization. r"""Apply the group normalization.
...@@ -180,7 +243,7 @@ def layer_norm(inputs, axis=-1, epsilon=1e-5, **kwargs): ...@@ -180,7 +243,7 @@ def layer_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def lp_normalize( def lp_norm(
inputs, inputs,
axis=-1, axis=-1,
end_axis=None, end_axis=None,
...@@ -200,15 +263,15 @@ def lp_normalize( ...@@ -200,15 +263,15 @@ def lp_normalize(
```python ```python
x = dragon.constant([[1, 2, 3], [4, 5, 6]], 'float32') x = dragon.constant([[1, 2, 3], [4, 5, 6]], 'float32')
# A negative axis is the last-k axis # A negative axis is the last-k axis
print(dragon.math.lp_normalize(x, 1)) print(dragon.nn.lp_norm(x, 1))
print(dragon.math.lp_normalize(x, -1)) # Equivalent print(dragon.nn.lp_norm(x, -1)) # Equivalent
``` ```
More than one axis could be specified to reduce: More than one axis could be specified to reduce:
```python ```python
# Along the continuous axes: [axis, end_axis] # Along the continuous axes: [axis, end_axis]
print(dragon.math.lp_normalize(x, axis=0, end_axis=1)) print(dragon.nn.lp_norm(x, axis=0, end_axis=1))
``` ```
Parameters Parameters
...@@ -236,9 +299,9 @@ def lp_normalize( ...@@ -236,9 +299,9 @@ def lp_normalize(
reduction = reduction.upper() reduction = reduction.upper()
if context.executing_eagerly(): if context.executing_eagerly():
return OpLib.execute( return OpLib.execute(
'LpNormalize', inputs, p=p, axis=axis, end_axis=end_axis, 'LpNorm', inputs, p=p, axis=axis, end_axis=end_axis,
epsilon=epsilon, reduction=reduction) epsilon=epsilon, reduction=reduction)
return OpLib.add('LpNormalize', inputs, p=p, axis=axis, end_axis=end_axis, return OpLib.add('LpNorm', inputs, p=p, axis=axis, end_axis=end_axis,
epsilon=epsilon, reduction=reduction, **kwargs) epsilon=epsilon, reduction=reduction, **kwargs)
......
...@@ -24,9 +24,11 @@ class Adam(optimizer.Optimizer): ...@@ -24,9 +24,11 @@ class Adam(optimizer.Optimizer):
The **Adam** update is defined as: The **Adam** update is defined as:
.. math:: .. math::
\text{Adam}(g) = \frac{\text{lr} * m_{t}}{\sqrt{v_{t}} + \epsilon} \\ \text{Adam}(g) = \text{lr} * (\frac{\text{correction}* m_{t}}
{\sqrt{v_{t}} + \epsilon}) \\
\quad \\ \text{where}\quad \quad \\ \text{where}\quad
\begin{cases} \begin{cases}
\text{correction} = \sqrt{1 - \beta_{2}^{t}} / (1 - \beta_{1}^{t}) \\
m_{t} = \beta_{1} * m_{t-1} + (1 - \beta_{1}) * g \\ m_{t} = \beta_{1} * m_{t-1} + (1 - \beta_{1}) * g \\
v_{t} = \beta_{2} * v_{t-1} + (1 - \beta_{2}) * g^{2} v_{t} = \beta_{2} * v_{t-1} + (1 - \beta_{2}) * g^{2}
\end{cases} \end{cases}
...@@ -62,12 +64,13 @@ class AdamW(Adam): ...@@ -62,12 +64,13 @@ class AdamW(Adam):
The **AdamW** update is defined as: The **AdamW** update is defined as:
.. math:: .. math::
\text{AdamW}(g, p) = \text{lr} * (\frac{m_{t}}{\sqrt{v_{t}} + \epsilon} \text{AdamW}(g, p) = \text{lr} * (\frac{\text{correction} * m_{t}}
+ \lambda p) \\ {\sqrt{v_{t}} + \epsilon} + \lambda p) \\
\quad \\ \text{where}\quad \quad \\ \text{where}\quad
\begin{cases} \begin{cases}
\text{correction} = \sqrt{1 - \beta_{2}^{t}} / (1 - \beta_{1}^{t}) \\
m_{t} = \beta_{1} * m_{t-1} + (1 - \beta_{1}) * g \\ m_{t} = \beta_{1} * m_{t-1} + (1 - \beta_{1}) * g \\
v_{t} = \beta_{2} * v_{t-1} + (1 - \beta_{2}) * g^{2} v_{t} = \beta_{2} * v_{t-1} + (1 - \beta_{2}) * g^{2} \\
\end{cases} \end{cases}
""" """
......
...@@ -27,13 +27,13 @@ class RMSprop(optimizer.Optimizer): ...@@ -27,13 +27,13 @@ class RMSprop(optimizer.Optimizer):
\text{RMSprop}(g) = \text{lr} * m_{t} \\ \text{RMSprop}(g) = \text{lr} * m_{t} \\
\quad \\ \text{where} \quad \quad \\ \text{where} \quad
\begin{cases} \begin{cases}
v_{t} = \text{decay} * v_{t-1} + (1 - \text{decay}) * g^{2} \\ v_{t} = \alpha * v_{t-1} + (1 - \alpha) * g^{2} \\
m_{t} = \text{momentum} * m_{t-1} + \frac{g}{\sqrt{v_{t}} + \epsilon} m_{t} = \text{momentum} * m_{t-1} + \frac{g}{\sqrt{v_{t}} + \epsilon}
\end{cases} \end{cases}
""" """
def __init__(self, lr=0.01, momentum=0, decay=0.9, eps=1e-8, **kwargs): def __init__(self, lr=0.01, momentum=0, alpha=0.9, eps=1e-8, **kwargs):
r"""Create a ``RMSProp`` optimizer. r"""Create a ``RMSProp`` optimizer.
Parameters Parameters
...@@ -42,8 +42,8 @@ class RMSprop(optimizer.Optimizer): ...@@ -42,8 +42,8 @@ class RMSprop(optimizer.Optimizer):
The initial value to :math:`\text{lr}`. The initial value to :math:`\text{lr}`.
momentum : float, optional, default=0 momentum : float, optional, default=0
The initial value to :math:`\text{momentum}`. The initial value to :math:`\text{momentum}`.
decay : float, optional, default=0.9 alpha : float, optional, default=0.9
The initial value to :math:`\text{decay}`. The initial value to :math:`\alpha`.
eps : float, optional, default=1e-8 eps : float, optional, default=1e-8
The initial value to :math:`\epsilon`. The initial value to :math:`\epsilon`.
...@@ -51,5 +51,5 @@ class RMSprop(optimizer.Optimizer): ...@@ -51,5 +51,5 @@ class RMSprop(optimizer.Optimizer):
super(RMSprop, self).__init__(**kwargs) super(RMSprop, self).__init__(**kwargs)
self._set_hyper('lr', lr) self._set_hyper('lr', lr)
self._set_hyper('momentum', momentum) self._set_hyper('momentum', momentum)
self._set_hyper('decay', decay) self._set_hyper('alpha', alpha)
self._set_hyper('eps', eps) self._set_hyper('eps', eps)
...@@ -51,41 +51,6 @@ def cast_exporter(op_def, context): ...@@ -51,41 +51,6 @@ def cast_exporter(op_def, context):
return node, const_tensors return node, const_tensors
@export_util.register('ChannelAffine')
def channel_affine_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'ChannelAffine')
for arg in op_def.arg:
if arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i)
elif arg.name == 'end_axis':
helper.add_attribute(node, 'end_axis', arg.i)
return node, const_tensors
@export_util.register('ChannelNormalize')
def channel_normalize_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'ChannelNormalize')
for arg in op_def.arg:
if arg.name == 'mean':
helper.add_attribute(node, 'mean', arg.floats)
elif arg.name == 'std':
helper.add_attribute(node, 'std', arg.floats)
elif arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i)
elif arg.name == 'dtype':
helper.add_attribute(node, 'dtype', arg.s)
elif arg.name == 'perm':
helper.add_attribute(node, 'perm', arg.ints)
elif arg.name == 'perm_desc':
values = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'perm', values)
return node, const_tensors
@export_util.register('Concat') @export_util.register('Concat')
def concat_exporter(op_def, context): def concat_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals()) node, const_tensors = export_util.translate(**locals())
......
...@@ -31,6 +31,17 @@ def add_exporter(op_def, context): ...@@ -31,6 +31,17 @@ def add_exporter(op_def, context):
return node, const_tensors return node, const_tensors
@export_util.register('Affine')
def affine_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'Affine')
for arg in op_def.arg:
if arg.name == 'axes':
helper.add_attribute(node, 'axes', arg.ints)
return node, const_tensors
@export_util.register('Div') @export_util.register('Div')
def div_exporter(op_def, context): def div_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals()) node, const_tensors = export_util.translate(**locals())
......
...@@ -32,6 +32,28 @@ def batch_norm_exporter(op_def, context): ...@@ -32,6 +32,28 @@ def batch_norm_exporter(op_def, context):
return node, const_tensors return node, const_tensors
@export_util.register('ChannelNorm')
def channel_norm_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'ChannelNorm')
for arg in op_def.arg:
if arg.name == 'mean':
helper.add_attribute(node, 'mean', arg.floats)
elif arg.name == 'std':
helper.add_attribute(node, 'std', arg.floats)
elif arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i)
elif arg.name == 'dtype':
helper.add_attribute(node, 'dtype', arg.s)
elif arg.name == 'perm':
helper.add_attribute(node, 'perm', arg.ints)
elif arg.name == 'perm_desc':
values = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'perm', values)
return node, const_tensors
@export_util.register('GroupNorm') @export_util.register('GroupNorm')
def group_norm_exporter(op_def, context): def group_norm_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals()) node, const_tensors = export_util.translate(**locals())
...@@ -49,8 +71,8 @@ def group_norm_exporter(op_def, context): ...@@ -49,8 +71,8 @@ def group_norm_exporter(op_def, context):
return node, const_tensors return node, const_tensors
@export_util.register('LpNormalize') @export_util.register('LpNorm')
def lp_normalize_exporter(op_def, context): def lp_norm_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'LpNormalization' node.op_type = 'LpNormalization'
axis, end_axis = None, None axis, end_axis = None, None
......
...@@ -33,9 +33,6 @@ constexpr int CUDA_WARP_SIZE = 32; ...@@ -33,9 +33,6 @@ constexpr int CUDA_WARP_SIZE = 32;
/*! \brief The number of cuda threads in a block */ /*! \brief The number of cuda threads in a block */
constexpr int CUDA_THREADS = 256; constexpr int CUDA_THREADS = 256;
/*! \brief The maximum number of blocks to use in a default kernel call */
constexpr int CUDA_MAX_BLOCKS = 4096;
/*! \brief The maximum number of devices in a single machine */ /*! \brief The maximum number of devices in a single machine */
constexpr int CUDA_MAX_DEVICES = 16; constexpr int CUDA_MAX_DEVICES = 16;
...@@ -82,12 +79,15 @@ constexpr int CUDA_TENSOR_MAX_DIMS = 8; ...@@ -82,12 +79,15 @@ constexpr int CUDA_TENSOR_MAX_DIMS = 8;
for (size_t j = threadIdx.x; j < m; j += blockDim.x) for (size_t j = threadIdx.x; j < m; j += blockDim.x)
inline int CUDA_BLOCKS(const int N) { inline int CUDA_BLOCKS(const int N) {
return std::max( int device, sm_count, threads_per_sm;
std::min((N + CUDA_THREADS - 1) / CUDA_THREADS, CUDA_MAX_BLOCKS), 1); CUDA_CHECK(cudaGetDevice(&device));
} CUDA_CHECK(cudaDeviceGetAttribute(
&sm_count, cudaDevAttrMultiProcessorCount, device));
inline int CUDA_2D_BLOCKS(const int N) { CUDA_CHECK(cudaDeviceGetAttribute(
return std::max(std::min(N, CUDA_MAX_BLOCKS), 1); &threads_per_sm, cudaDevAttrMaxThreadsPerMultiProcessor, device));
const auto num_blocks = (N + CUDA_THREADS - 1) / CUDA_THREADS;
const auto max_blocks = sm_count * threads_per_sm / CUDA_THREADS * 32;
return std::max(1, std::min(num_blocks, max_blocks));
} }
#if CUDA_VERSION_MAX(9, 0) #if CUDA_VERSION_MAX(9, 0)
......
...@@ -84,6 +84,7 @@ DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Sub, T); ...@@ -84,6 +84,7 @@ DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Sub, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Mul, T); DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Mul, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Div, T); DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Div, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Pow, T); DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Pow, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Atan2, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Minimum, T); DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Minimum, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Maximum, T); DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Maximum, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Equal, bool); DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Equal, bool);
...@@ -434,6 +435,7 @@ DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(And, bool, std::logical_and); ...@@ -434,6 +435,7 @@ DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(And, bool, std::logical_and);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Or, bool, std::logical_or); DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Or, bool, std::logical_or);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Xor, bool, math::XorFunctor); DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Xor, bool, math::XorFunctor);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Pow, T, math::PowFunctor); DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Pow, T, math::PowFunctor);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Atan2, T, math::Atan2Functor);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Minimum, T, math::MinFunctor); DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Minimum, T, math::MinFunctor);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Maximum, T, math::MaxFunctor); DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Maximum, T, math::MaxFunctor);
#undef DEFINE_ROWWISE_COLWISE_BIANRY_FUNC #undef DEFINE_ROWWISE_COLWISE_BIANRY_FUNC
...@@ -469,6 +471,7 @@ DEFINE_BROADCAST_BINARY_FUNC(And, bool, std::logical_and); ...@@ -469,6 +471,7 @@ DEFINE_BROADCAST_BINARY_FUNC(And, bool, std::logical_and);
DEFINE_BROADCAST_BINARY_FUNC(Or, bool, std::logical_or); DEFINE_BROADCAST_BINARY_FUNC(Or, bool, std::logical_or);
DEFINE_BROADCAST_BINARY_FUNC(Xor, bool, math::XorFunctor); DEFINE_BROADCAST_BINARY_FUNC(Xor, bool, math::XorFunctor);
DEFINE_BROADCAST_BINARY_FUNC(Pow, T, math::PowFunctor); DEFINE_BROADCAST_BINARY_FUNC(Pow, T, math::PowFunctor);
DEFINE_BROADCAST_BINARY_FUNC(Atan2, T, math::Atan2Functor);
DEFINE_BROADCAST_BINARY_FUNC(Minimum, T, math::MinFunctor); DEFINE_BROADCAST_BINARY_FUNC(Minimum, T, math::MinFunctor);
DEFINE_BROADCAST_BINARY_FUNC(Maximum, T, math::MaxFunctor); DEFINE_BROADCAST_BINARY_FUNC(Maximum, T, math::MaxFunctor);
#undef DEFINE_BROADCAST_BINARY_FUNC #undef DEFINE_BROADCAST_BINARY_FUNC
...@@ -612,6 +615,9 @@ DEFINE_BINARY_FUNC(Div, float, float); ...@@ -612,6 +615,9 @@ DEFINE_BINARY_FUNC(Div, float, float);
DEFINE_BINARY_FUNC(Div, double, double); DEFINE_BINARY_FUNC(Div, double, double);
DEFINE_BINARY_FUNC(Pow, float, float); DEFINE_BINARY_FUNC(Pow, float, float);
DEFINE_BINARY_FUNC(Pow, double, double); DEFINE_BINARY_FUNC(Pow, double, double);
DEFINE_BINARY_FUNC(Atan2, float16, float16);
DEFINE_BINARY_FUNC(Atan2, float, float);
DEFINE_BINARY_FUNC(Atan2, double, double);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t); DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t); DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t);
DEFINE_BINARY_FUNC(Minimum, int, int); DEFINE_BINARY_FUNC(Minimum, int, int);
......
...@@ -388,6 +388,9 @@ DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor); ...@@ -388,6 +388,9 @@ DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, float16, math::PowFunctor); DEFINE_BINARY_FUNC(Pow, float16, float16, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor); DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor); DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Atan2, float16, float16, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, float, float, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, double, double, math::Atan2Functor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor);
......
...@@ -122,6 +122,17 @@ DRAGON_API void Pow( ...@@ -122,6 +122,17 @@ DRAGON_API void Pow(
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
DRAGON_API void Atan2(
const int A_ndim,
const int64_t* A_dims,
const int B_ndim,
const int64_t* B_dims,
const T* a,
const T* b,
T* y,
Context* ctx);
template <typename T, class Context>
DRAGON_API void Minimum( DRAGON_API void Minimum(
const int A_ndim, const int A_ndim,
const int64_t* A_dims, const int64_t* A_dims,
......
...@@ -550,6 +550,9 @@ DEFINE_BINARY_FUNC(Maximum, double, double, max); ...@@ -550,6 +550,9 @@ DEFINE_BINARY_FUNC(Maximum, double, double, max);
_SimpleBinaryFunc(N, Functor<InputT>(), a, b, y); \ _SimpleBinaryFunc(N, Functor<InputT>(), a, b, y); \
} }
DEFINE_BINARY_FUNC(Atan2, float16, float16, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, float, float, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, double, double, math::Atan2Functor);
DEFINE_BINARY_FUNC(BitwiseAnd, bool, bool, std::bit_and); DEFINE_BINARY_FUNC(BitwiseAnd, bool, bool, std::bit_and);
DEFINE_BINARY_FUNC(BitwiseAnd, uint8_t, uint8_t, std::bit_and); DEFINE_BINARY_FUNC(BitwiseAnd, uint8_t, uint8_t, std::bit_and);
DEFINE_BINARY_FUNC(BitwiseAnd, int8_t, int8_t, std::bit_and); DEFINE_BINARY_FUNC(BitwiseAnd, int8_t, int8_t, std::bit_and);
......
...@@ -342,7 +342,10 @@ _Where(const int N, const T* a, const T* b, const bool* c, T* y) { ...@@ -342,7 +342,10 @@ _Where(const int N, const T* a, const T* b, const bool* c, T* y) {
DRAGON_API void name<InputT, CUDAContext>( \ DRAGON_API void name<InputT, CUDAContext>( \
const int N, const InputT* x, OutputT* y, CUDAContext* ctx) { \ const int N, const InputT* x, OutputT* y, CUDAContext* ctx) { \
_SimpleUnaryFunc<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _SimpleUnaryFunc<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, Functor<InputT>(), x, y); \ N, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(x), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} }
DEFINE_UNARY_FUNC(BitwiseNot, bool, bool, math::BitNotFunctor); DEFINE_UNARY_FUNC(BitwiseNot, bool, bool, math::BitNotFunctor);
...@@ -706,6 +709,87 @@ DEFINE_APPLY_MASK_FUNC(float, float); ...@@ -706,6 +709,87 @@ DEFINE_APPLY_MASK_FUNC(float, float);
DEFINE_APPLY_MASK_FUNC(double, double); DEFINE_APPLY_MASK_FUNC(double, double);
#undef DEFINE_APPLY_MASK_FUNC #undef DEFINE_APPLY_MASK_FUNC
#define DEFINE_BINARY_FUNC(name, T, Functor) \
template <> \
DRAGON_API void name<T, CUDAContext>( \
const int N, const T* a, const T* b, T* y, CUDAContext* ctx) { \
using ScalarT = typename math::ScalarType<T>::type; \
using ScalarT2 = typename math::ScalarType<T>::type2; \
if ((N & 1) == 0 && sizeof(ScalarT) != sizeof(ScalarT2)) { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(N >> 1), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
N >> 1, \
Functor<ScalarT2>(), \
reinterpret_cast<const ScalarT2*>(a), \
reinterpret_cast<const ScalarT2*>(b), \
reinterpret_cast<ScalarT2*>(y)); \
} else { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(N), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
N, \
Functor<ScalarT>(), \
reinterpret_cast<const ScalarT*>(a), \
reinterpret_cast<const ScalarT*>(b), \
reinterpret_cast<ScalarT*>(y)); \
} \
}
DEFINE_BINARY_FUNC(Add, uint8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int64_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float16, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, double, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, uint8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int64_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, double, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, uint8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int64_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float16, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, double, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, uint8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int64_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float16, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Atan2, float16, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, float, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, double, math::Atan2Functor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int64_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, double, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, uint8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int64_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, double, math::MaxFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, InputT, OutputT, Functor) \ #define DEFINE_BINARY_FUNC(name, InputT, OutputT, Functor) \
template <> \ template <> \
DRAGON_API void name<InputT, CUDAContext>( \ DRAGON_API void name<InputT, CUDAContext>( \
...@@ -726,51 +810,6 @@ DEFINE_APPLY_MASK_FUNC(double, double); ...@@ -726,51 +810,6 @@ DEFINE_APPLY_MASK_FUNC(double, double);
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \ reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} }
DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int8_t, int8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int, int, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int64_t, int64_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float16, float16, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float, float, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, double, double, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int, int, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, float16, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float, float, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, double, double, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int, int, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float16, float16, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float, float, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, double, double, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int8_t, int8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int, int, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int64_t, int64_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float16, float16, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float, float, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, float16, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, float16, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float, float, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, double, double, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int, int, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, float16, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float, float, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, double, double, math::MaxFunctor);
DEFINE_BINARY_FUNC(BitwiseAnd, bool, bool, math::BitAndFunctor); DEFINE_BINARY_FUNC(BitwiseAnd, bool, bool, math::BitAndFunctor);
DEFINE_BINARY_FUNC(BitwiseAnd, uint8_t, uint8_t, math::BitAndFunctor); DEFINE_BINARY_FUNC(BitwiseAnd, uint8_t, uint8_t, math::BitAndFunctor);
DEFINE_BINARY_FUNC(BitwiseAnd, int8_t, int8_t, math::BitAndFunctor); DEFINE_BINARY_FUNC(BitwiseAnd, int8_t, int8_t, math::BitAndFunctor);
......
...@@ -126,6 +126,9 @@ template <typename T, class Context> ...@@ -126,6 +126,9 @@ template <typename T, class Context>
DRAGON_API void Pow(const int N, const T* a, const T* b, T* y, Context* ctx); DRAGON_API void Pow(const int N, const T* a, const T* b, T* y, Context* ctx);
template <typename T, class Context> template <typename T, class Context>
DRAGON_API void Atan2(const int N, const T* a, const T* b, T* y, Context* ctx);
template <typename T, class Context>
DRAGON_API void DRAGON_API void
Minimum(const int N, const T* a, const T* b, T* y, Context* ctx); Minimum(const int N, const T* a, const T* b, T* y, Context* ctx);
......
...@@ -108,13 +108,11 @@ void _GenericReduce( ...@@ -108,13 +108,11 @@ void _GenericReduce(
} \ } \
if (math::utils::IsRowwiseReduce( \ if (math::utils::IsRowwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \ num_dims, dims, out_dims.data(), &rows, &cols)) { \
_RowwiseReduce##name(rows, cols, scale, x, y); \ return _RowwiseReduce##name(rows, cols, scale, x, y); \
return; \
} \ } \
if (math::utils::IsColwiseReduce( \ if (math::utils::IsColwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \ num_dims, dims, out_dims.data(), &rows, &cols)) { \
_ColwiseReduce##name(rows, cols, scale, x, y); \ return _ColwiseReduce##name(rows, cols, scale, x, y); \
return; \
} \ } \
vec64_t transpose_axes(num_dims); \ vec64_t transpose_axes(num_dims); \
vec64_t transpose_strides(num_dims); \ vec64_t transpose_strides(num_dims); \
......
#include "dragon/utils/math/transform.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/math/reduce.h"
namespace dragon {
namespace math {
namespace {
template <typename T>
void _AffineChannel(
const int N,
const int C,
const T* x,
const T* scale,
const T* bias,
T* y) {
EigenArrayMap<T> Y(y, C, N);
ConstEigenArrayMap<T> X(x, C, N);
Y = X.colwise() * ConstEigenVectorArrayMap<T>(scale, C);
if (bias != nullptr) {
Y.colwise() += ConstEigenVectorArrayMap<T>(bias, C);
}
}
template <typename T>
void _AffineChannel(
const int N,
const int C,
const int S,
const T* x,
const T* scale,
const T* bias,
T* y) {
const auto CxS = C * S;
for (int i = 0; i < N; ++i) {
EigenArrayMap<T> Y(y + i * CxS, S, C);
ConstEigenArrayMap<T> X(x + i * CxS, S, C);
Y = X.rowwise() * ConstEigenVectorArrayMap<T>(scale, C).transpose();
if (bias != nullptr) {
Y.rowwise() += ConstEigenVectorArrayMap<T>(bias, C).transpose();
}
}
}
template <typename T>
void _AffineImpl(
const int num_dims,
const int64_t* dims,
const int num_axes,
const int64_t* axes,
const T* x,
const T* scale,
const T* bias,
T* y) {
if (num_dims == 2 && num_axes == 1 && axes[0] == 1) {
_AffineChannel(dims[0], dims[1], x, scale, bias, y);
} else if (num_dims == 3 && num_axes == 1 && axes[0] == 1) {
_AffineChannel(dims[0], dims[1], dims[2], x, scale, bias, y);
} else {
LOG(FATAL) << "Unsupported affine dimensions.";
}
}
} // namespace
template <>
void Affine<float16, CPUContext>(
const int num_dims,
const int64_t* dims,
const int num_axes,
const int64_t* axes,
const float16* x,
const float16* scale,
const float16* bias,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_AFFINE_FUNC(T) \
template <> \
void Affine<T, CPUContext>( \
const int num_dims, \
const int64_t* dims, \
const int num_axes, \
const int64_t* axes, \
const T* x, \
const T* scale, \
const T* bias, \
T* y, \
CPUContext* ctx) { \
vec64_t new_dims, new_axes; \
math::utils::CollapseReduceAxes( \
num_dims, dims, num_axes, axes, new_dims, new_axes); \
_AffineImpl( \
new_dims.size(), \
new_dims.data(), \
new_axes.size(), \
new_axes.data(), \
x, \
scale, \
bias, \
y); \
}
DEFINE_AFFINE_FUNC(float);
DEFINE_AFFINE_FUNC(double);
#undef DEFINE_AFFINE_FUNC
} // namespace math
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math/functional.h"
#include "dragon/utils/math/reduce.h"
#include "dragon/utils/math/transform.h"
#include "dragon/utils/math/types.h"
#include "dragon/utils/math/utils.h"
namespace dragon {
namespace math {
namespace {
template <typename T>
__global__ void _AffineChannel(
const int NxC,
const int C,
const T* x,
const T* scale,
const T* bias,
T* y) {
auto op3 = math::FMAFunctor<T>();
auto op2 = math::MultipliesFunctor<T>();
CUDA_1D_KERNEL_LOOP(i, NxC) {
if (bias != nullptr) {
y[i] = op3(x[i], __ldg(scale + i % C), __ldg(bias + i % C));
} else {
y[i] = op2(x[i], __ldg(scale + i % C));
}
}
}
template <typename T>
__global__ void _AffineChannel(
const int NxCxS,
const int C,
const int S,
const T* x,
const T* scale,
const T* bias,
T* y) {
auto op3 = math::FMAFunctor<T>();
auto op2 = math::MultipliesFunctor<T>();
CUDA_1D_KERNEL_LOOP(i, NxCxS) {
const int j = (i / S) % C;
if (bias != nullptr) {
y[i] = op3(x[i], __ldg(scale + j), __ldg(bias + j));
} else {
y[i] = op2(x[i], __ldg(scale + j));
}
}
}
template <typename T>
void _AffineImpl(
const int num_dims,
const int64_t* dims,
const int num_axes,
const int64_t* axes,
const T* x,
const T* scale,
const T* bias,
T* y,
CUDAContext* ctx) {
const auto N = math::utils::Prod(num_dims, dims);
if (num_dims == 2 && num_axes == 1 && axes[0] == 1) {
_AffineChannel<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, dims[1], x, scale, bias, y);
} else if (num_dims == 3 && num_axes == 1 && axes[0] == 1) {
_AffineChannel<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, dims[1], dims[2], x, scale, bias, y);
} else {
LOG(FATAL) << "Unsupported affine dimensions.";
}
}
} // namespace
#define DEFINE_AFFINE_FUNC(T) \
template <> \
void Affine<T, CUDAContext>( \
const int num_dims, \
const int64_t* dims, \
const int num_axes, \
const int64_t* axes, \
const T* x, \
const T* scale, \
const T* bias, \
T* y, \
CUDAContext* ctx) { \
vec64_t new_dims, new_axes; \
math::utils::CollapseReduceAxes( \
num_dims, dims, num_axes, axes, new_dims, new_axes); \
_AffineImpl( \
new_dims.size(), \
new_dims.data(), \
new_axes.size(), \
new_axes.data(), \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<const math::ScalarType<T>::type*>(scale), \
reinterpret_cast<const math::ScalarType<T>::type*>(bias), \
reinterpret_cast<math::ScalarType<T>::type*>(y), \
ctx); \
}
DEFINE_AFFINE_FUNC(float);
DEFINE_AFFINE_FUNC(float16);
DEFINE_AFFINE_FUNC(double);
#undef DEFINE_AFFINE_FUNC
} // namespace math
} // namespace dragon
#endif // USE_CUDA
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_MATH_TRANSFORM_H_
#define DRAGON_UTILS_MATH_TRANSFORM_H_
#include "dragon/core/context.h"
namespace dragon {
namespace math {
template <typename T, class Context>
DRAGON_API void Affine(
const int num_dims,
const int64_t* dims,
const int num_axes,
const int64_t* axes,
const T* x,
const T* scale,
const T* bias,
T* y,
Context* ctx);
} // namespace math
} // namespace dragon
#endif // DRAGON_UTILS_MATH_TRANSFORM_H_
...@@ -141,8 +141,7 @@ void _TransposeImpl( ...@@ -141,8 +141,7 @@ void _TransposeImpl(
CUDAContext* ctx) { CUDAContext* ctx) {
auto aligned_size = sizeof(T); auto aligned_size = sizeof(T);
if (axes.back() == D - 1) { if (axes.back() == D - 1) {
const auto N = math::utils::Prod(D, dims.data()); aligned_size = utils::GetAlignedSize<T, 16>(dims[D - 1], x, y);
aligned_size = utils::GetAlignedSize<T, 16>(N, x, y);
} }
SimpleArray<int, D> X_dims, X_strides, Y_dims; SimpleArray<int, D> X_dims, X_strides, Y_dims;
for (int i = 0; i < D; ++i) { for (int i = 0; i < D; ++i) {
......
...@@ -27,6 +27,7 @@ template <typename T> ...@@ -27,6 +27,7 @@ template <typename T>
class ScalarType { class ScalarType {
public: public:
typedef T type; typedef T type;
typedef T type2;
}; };
#if defined(__CUDACC__) #if defined(__CUDACC__)
...@@ -34,6 +35,7 @@ template <> ...@@ -34,6 +35,7 @@ template <>
class ScalarType<float16> { class ScalarType<float16> {
public: public:
typedef half type; typedef half type;
typedef half2 type2;
}; };
#endif #endif
......
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
#include "dragon/utils/conversions.h" #include "dragon/utils/conversions.h"
#if defined(__CUDACC__) #if defined(__CUDACC__)
#define MATH_UTILS_DECL inline __host__ __device__ #define HOSTDEVICE_DECL inline __host__ __device__
#else #else
#define MATH_UTILS_DECL inline #define HOSTDEVICE_DECL inline
#endif #endif
#define FIXED_DIVISOR_DIV_MOD(d, n, q, r) \ #define FIXED_DIVISOR_DIV_MOD(d, n, q, r) \
...@@ -41,28 +41,28 @@ namespace utils { ...@@ -41,28 +41,28 @@ namespace utils {
template < template <
typename T, typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0> typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
MATH_UTILS_DECL T IsInf(const T x) { HOSTDEVICE_DECL T IsInf(const T x) {
return false; return false;
} }
template < template <
typename T, typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0> typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
MATH_UTILS_DECL T IsNaN(const T x) { HOSTDEVICE_DECL T IsNaN(const T x) {
return false; return false;
} }
template < template <
typename T, typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0> typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
MATH_UTILS_DECL T IsFinite(const T x) { HOSTDEVICE_DECL T IsFinite(const T x) {
return true; return true;
} }
template < template <
typename T, typename T,
typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0> typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
MATH_UTILS_DECL bool IsInf(T x) { HOSTDEVICE_DECL bool IsInf(T x) {
#if defined(__CUDACC__) #if defined(__CUDACC__)
return isinf(x); return isinf(x);
#else #else
...@@ -73,7 +73,7 @@ MATH_UTILS_DECL bool IsInf(T x) { ...@@ -73,7 +73,7 @@ MATH_UTILS_DECL bool IsInf(T x) {
template < template <
typename T, typename T,
typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0> typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
MATH_UTILS_DECL bool IsNaN(T x) { HOSTDEVICE_DECL bool IsNaN(T x) {
#if defined(__CUDACC__) #if defined(__CUDACC__)
return isnan(x); return isnan(x);
#else #else
...@@ -84,7 +84,7 @@ MATH_UTILS_DECL bool IsNaN(T x) { ...@@ -84,7 +84,7 @@ MATH_UTILS_DECL bool IsNaN(T x) {
template < template <
typename T, typename T,
typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0> typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
MATH_UTILS_DECL bool IsFinite(T x) { HOSTDEVICE_DECL bool IsFinite(T x) {
#if defined(__CUDACC__) #if defined(__CUDACC__)
return isfinite(x); return isfinite(x);
#else #else
...@@ -106,27 +106,27 @@ inline bool IsFinite(float16 x) { ...@@ -106,27 +106,27 @@ inline bool IsFinite(float16 x) {
} }
template <typename T> template <typename T>
MATH_UTILS_DECL bool IsAGeZeroAndALtB(const T a, const T b) { HOSTDEVICE_DECL bool IsAGeZeroAndALtB(const T a, const T b) {
return static_cast<unsigned int>(a) < static_cast<unsigned int>(b); return static_cast<unsigned int>(a) < static_cast<unsigned int>(b);
} }
template <typename T> template <typename T>
MATH_UTILS_DECL T Sign(const T x) { HOSTDEVICE_DECL T Sign(const T x) {
return x > T(0) ? T(1) : (x < T(0) ? T(-1) : T(0)); return x > T(0) ? T(1) : (x < T(0) ? T(-1) : T(0));
} }
template <typename T> template <typename T>
MATH_UTILS_DECL T Identity(const T x) { HOSTDEVICE_DECL T Identity(const T x) {
return x; return x;
} }
template <typename T> template <typename T>
MATH_UTILS_DECL T Square(const T x) { HOSTDEVICE_DECL T Square(const T x) {
return x * x; return x * x;
} }
template <typename T> template <typename T>
MATH_UTILS_DECL T Cube(const T x) { HOSTDEVICE_DECL T Cube(const T x) {
return x * x * x; return x * x * x;
} }
...@@ -247,4 +247,6 @@ void IncreaseIndexInDims(const int num_dims, const DimT* dims, IndexT* index) { ...@@ -247,4 +247,6 @@ void IncreaseIndexInDims(const int num_dims, const DimT* dims, IndexT* index) {
} // namespace dragon } // namespace dragon
#undef HOSTDEVICE_DECL
#endif // DRAGON_UTILS_MATH_UTILS_H_ #endif // DRAGON_UTILS_MATH_UTILS_H_
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "dragon/utils/math/random.h" #include "dragon/utils/math/random.h"
#include "dragon/utils/math/reduce.h" #include "dragon/utils/math/reduce.h"
#include "dragon/utils/math/sort.h" #include "dragon/utils/math/sort.h"
#include "dragon/utils/math/transform.h"
#include "dragon/utils/math/transpose.h" #include "dragon/utils/math/transpose.h"
#include "dragon/utils/math/types.h" #include "dragon/utils/math/types.h"
#include "dragon/utils/math/utils.h" #include "dragon/utils/math/utils.h"
......
...@@ -284,39 +284,6 @@ void BooleanMaskGrad( ...@@ -284,39 +284,6 @@ void BooleanMaskGrad(
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void ChannelAffine(
const int N,
const int S,
const int C,
const T* x,
const T* scale,
const T* bias,
T* y,
Context* ctx);
template <typename InputT, typename OutputT, class Context>
void ChannelNormalize(
const int axis,
const int num_dims,
const int64_t* x_strides,
const int64_t* y_dims,
const InputT* x,
const float* mean,
const float* std,
OutputT* y,
Context* ctx);
template <typename T, class Context>
void ChannelShuffle(
const int N,
const int S,
const int C,
const int G,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
void ConstPad( void ConstPad(
const int num_dims, const int num_dims,
const int64_t* x_dims, const int64_t* x_dims,
...@@ -813,6 +780,18 @@ void TopK( ...@@ -813,6 +780,18 @@ void TopK(
* NormalizationOp Kernels * NormalizationOp Kernels
*/ */
template <typename InputT, typename OutputT, class Context>
void ChannelNorm(
const int axis,
const int num_dims,
const int64_t* x_strides,
const int64_t* y_dims,
const InputT* x,
const float* mean,
const float* std,
OutputT* y,
Context* ctx);
template <typename T, typename AccT, class Context> template <typename T, typename AccT, class Context>
void BatchNormExpectation( void BatchNormExpectation(
const int N, const int N,
...@@ -923,7 +902,7 @@ void GroupNormGrad( ...@@ -923,7 +902,7 @@ void GroupNormGrad(
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void L1Normalize( void L1Norm(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -934,7 +913,7 @@ void L1Normalize( ...@@ -934,7 +913,7 @@ void L1Normalize(
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void L1NormalizeGrad( void L1NormGrad(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -946,7 +925,7 @@ void L1NormalizeGrad( ...@@ -946,7 +925,7 @@ void L1NormalizeGrad(
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void L2Normalize( void L2Norm(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -957,7 +936,7 @@ void L2Normalize( ...@@ -957,7 +936,7 @@ void L2Normalize(
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void L2NormalizeGrad( void L2NormGrad(
const int N, const int N,
const int S, const int S,
const int C, const int C,
...@@ -1012,19 +991,23 @@ void LSTMCellGrad( ...@@ -1012,19 +991,23 @@ void LSTMCellGrad(
* TrainingOp Kernels * TrainingOp Kernels
*/ */
template <typename T, class Context> template <typename T, typename CopyT, class Context>
void Adam( void Adam(
const int N, const int N,
const float lr, const float lr,
const float beta1, const float beta1,
const float beta2, const float beta2,
const float eps, const float eps,
T* g, const float wd,
const T* x,
const T* g,
T* m, T* m,
T* v, T* v,
T* y,
CopyT* y_copy,
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, typename CopyT, class Context>
void AdamW( void AdamW(
const int N, const int N,
const float lr, const float lr,
...@@ -1033,39 +1016,53 @@ void AdamW( ...@@ -1033,39 +1016,53 @@ void AdamW(
const float eps, const float eps,
const float wd, const float wd,
const T* x, const T* x,
T* g, const T* g,
T* m, T* m,
T* v, T* v,
T* y,
CopyT* y_copy,
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, typename CopyT, class Context>
void MomentumSGD( void MomentumSGD(
const int N, const int N,
const float lr, const float lr,
const float momentum, const float momentum,
T* g, const float wd,
const T* x,
const T* g,
T* m, T* m,
T* y,
CopyT* y_copy,
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, typename CopyT, class Context>
void NesterovSGD( void NesterovSGD(
const int N, const int N,
const float lr, const float lr,
const float momentum, const float momentum,
T* g, const float wd,
const T* x,
const T* g,
T* m, T* m,
T* y,
CopyT* y_copy,
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, typename CopyT, class Context>
void RMSprop( void RMSprop(
const int N, const int N,
const float lr, const float lr,
const float momentum, const float momentum,
const float decay, const float alpha,
const float eps, const float eps,
T* g, const float wd,
const T* x,
const T* g,
T* m, T* m,
T* v, T* v,
T* y,
CopyT* y_copy,
Context* ctx); Context* ctx);
/* /*
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!