Commit 494774d3 by Ting PAN

Optimize training update operators

Summary:
This commit fuses the weight decay and mixed precision conversion
into update kernels to get lower training latency.
1 parent fb47d86f
Showing with 1703 additions and 1326 deletions
......@@ -418,9 +418,9 @@ class Normalize(Layer):
def __call__(self, bottom):
if len(self.blobs) == 0:
self.build(bottom)
outputs = [normalization_ops.lp_normalize(bottom, **self.norm_args)]
outputs = [normalization_ops.lp_norm(bottom, **self.norm_args)]
outputs += [blob['data'] for blob in self.blobs]
return array_ops.channel_affine(outputs, **self.scale_args)
return math_ops.affine(outputs, **self.scale_args)
class Permute(Layer):
......@@ -591,8 +591,7 @@ class Scale(Layer):
param = layer_param.scale_param
self.axis = param.axis
self.num_axes = param.num_axes
end_axis = -1 if self.num_axes < 1 else self.axis + self.num_axes - 1
self.call_args = {'axis': self.axis, 'end_axis': end_axis}
self.call_args = {'axis': list(range(self.axis, self.axis + self.num_axes))}
self.filler = caffe_pb2.FillerParameter(type='constant', value=1)
self.filler = param.filler if param.HasField('filler') else self.filler
self.bias_filler = param.bias_filler
......@@ -609,7 +608,7 @@ class Scale(Layer):
if len(self.blobs) == 0:
self.build(bottom)
inputs = [bottom] + [blob['data'] for blob in self.blobs]
return array_ops.channel_affine(inputs, **self.call_args)
return math_ops.affine(inputs, **self.call_args)
class Slice(Layer):
......
......@@ -16,8 +16,8 @@ from __future__ import print_function
from dragon.core.framework import workspace
from dragon.core.io.kpl_record import KPLRecordDataset
from dragon.core.ops import array_ops
from dragon.core.ops import framework_ops
from dragon.core.ops import normalization_ops
from dragon.utils import vision
from dragon.vm.caffe.core.layer import Layer
......@@ -121,5 +121,5 @@ class Data(Layer):
data._shape = (self.data_args['batch_size'],
None, None, len(self.norm_args['mean']))
label._shape = (self.data_args['batch_size'], None)
data = array_ops.channel_normalize(data, **self.norm_args)
data = normalization_ops.channel_norm(data, **self.norm_args)
return data, label
......@@ -9,6 +9,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# ---[ Compiler flags
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_ENABLE_EXTENDED_ALIGNED_STORAGE")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}
/wd4003 /wd4114
......
......@@ -36,16 +36,6 @@ dragon
`cast(...) <dragon/cast.html>`_
: Cast the data type of input.
`channel_affine(...) <dragon/channel_affine.html>`_
: Apply affine transformation to each channel of input.
`channel_normalize(...) <dragon/channel_normalize.html>`_
: Apply normalization to each channel of input.
`channel_shuffle(...) <dragon/channel_shuffle.html>`_
: Apply group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`concat(...) <dragon/concat.html>`_
: Concatenate the inputs along the given axis.
......@@ -211,9 +201,6 @@ dragon
dragon/boolean_mask
dragon/broadcast_to
dragon/cast
dragon/channel_affine
dragon/channel_normalize
dragon/channel_shuffle
dragon/concat
dragon/constant
dragon/device
......
......@@ -24,6 +24,9 @@ dragon.cuda
`memory_allocated(...) <cuda/memory_allocated.html>`_
: Return the size of memory used by tensors in current workspace.
`set_cublas_flags(...) <cuda/set_cublas_flags.html>`_
: Set the flags of cuBLAS library.
`set_cudnn_flags(...) <cuda/set_cudnn_flags.html>`_
: Set the flags of cuDNN library.
......@@ -44,6 +47,7 @@ dragon.cuda
cuda/get_device_capability
cuda/is_available
cuda/memory_allocated
cuda/set_cublas_flags
cuda/set_cudnn_flags
cuda/set_default_device
cuda/set_device
......
channel_affine
==============
set_cublas_flags
================
.. autofunction:: dragon.channel_affine
.. autofunction:: dragon.cuda.set_cublas_flags
.. raw:: html
<style>
h1:before {
content: "dragon.";
content: "dragon.cuda.";
color: #103d3e;
}
</style>
......@@ -12,12 +12,18 @@ dragon.math
`add(...) <math/add.html>`_
: Compute the element-wise addition.
`affine(...) <math/affine.html>`_
: Apply the affine transformation to input.
`argmax(...) <math/argmax.html>`_
: Compute the index of maximum elements along the given axis.
`argmin(...) <math/argmin.html>`_
: Compute the index of minimum elements along the given axis.
`atan2(...) <math/atan2.html>`_
: Compute the element-wise arc-tangent of two arguments.
`ceil(...) <math/ceil.html>`_
: Compute the smallest integer not less than input.
......@@ -81,9 +87,6 @@ dragon.math
`logical_xor(...) <math/logical_xor.html>`_
: Compute the element-wise XOR logical operation.
`lp_normalize(...) <math/lp_normalize.html>`_
: Apply the lp normalization.
`matmul(...) <math/matmul.html>`_
: Compute the matrix multiplication.
......@@ -158,8 +161,10 @@ dragon.math
math/abs
math/add
math/affine
math/argmax
math/argmin
math/atan2
math/ceil
math/clip
math/cos
......@@ -181,7 +186,6 @@ dragon.math
math/logical_not
math/logical_or
math/logical_xor
math/lp_normalize
math/matmul
math/max
math/maximum
......
lp_normalize
============
affine
======
.. autofunction:: dragon.math.lp_normalize
.. autofunction:: dragon.math.affine
.. raw:: html
......
channel_normalize
=================
atan2
=====
.. autofunction:: dragon.channel_normalize
.. autofunction:: dragon.math.atan2
.. raw:: html
<style>
h1:before {
content: "dragon.";
content: "dragon.math.";
color: #103d3e;
}
</style>
......@@ -28,6 +28,13 @@ dragon.nn
`bias_add(...) <nn/bias_add.html>`_
: Add the bias across channels to input.
`channel_norm(...) <nn/channel_norm.html>`_
: Apply the normalization to each channel of input.
`channel_shuffle(...) <nn/channel_shuffle.html>`_
: Apply the group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`conv(...) <nn/conv.html>`_
: Apply the n-dimension convolution.
......@@ -107,6 +114,9 @@ dragon.nn
`log_softmax(...) <nn/log_softmax.html>`_
: Compute the composite of logarithm and softmax.
`lp_norm(...) <nn/lp_norm.html>`_
: Apply the lp normalization.
`moments(...) <nn/moments.html>`_
: Compute the mean and variance of input along the given axis.
......@@ -157,6 +167,8 @@ dragon.nn
nn/RNN
nn/batch_norm
nn/bias_add
nn/channel_norm
nn/channel_shuffle
nn/conv
nn/conv_transpose
nn/conv1d
......@@ -180,6 +192,7 @@ dragon.nn
nn/leaky_relu
nn/local_response_norm
nn/log_softmax
nn/lp_norm
nn/moments
nn/pool
nn/pool1d
......
channel_norm
============
.. autofunction:: dragon.nn.channel_norm
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
channel_shuffle
===============
.. autofunction:: dragon.channel_shuffle
.. autofunction:: dragon.nn.channel_shuffle
.. raw:: html
<style>
h1:before {
content: "dragon.";
content: "dragon.nn.";
color: #103d3e;
}
</style>
lp_norm
=======
.. autofunction:: dragon.nn.lp_norm
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
......@@ -21,6 +21,9 @@ vm.tensorflow.math
`argmin(...) <math/argmin.html>`_
: Compute the index of minimum elements along the given axis.
`atan2(...) <math/atan2.html>`_
: Compute the element-wise arc-tangent of two arguments.
`ceil(...) <math/ceil.html>`_
: Compute the smallest integer not less than input.
......@@ -134,6 +137,7 @@ vm.tensorflow.math
math/add_n
math/argmax
math/argmin
math/atan2
math/ceil
math/cos
math/cumsum
......
channel_normalize
=================
atan2
=====
.. autofunction:: dragon.vm.torch.channel_normalize
.. autofunction:: dragon.vm.tensorflow.math.atan2
.. raw:: html
<style>
h1:before {
content: "torch.";
content: "tf.math.";
color: #103d3e;
}
</style>
......@@ -51,6 +51,9 @@ vm.torch
`argsort(...) <torch/argsort.html>`_
: Return the index of sorted elements along the given dimension.
`atan2(...) <torch/atan2.html>`_
: Compute the element-wise arc-tangent of two arguments.
`baddbmm(...) <torch/baddbmm.html>`_
: Add input to the result of batched matrix-matrix multiplication.
......@@ -75,12 +78,6 @@ vm.torch
`ceil(...) <torch/ceil.html>`_
: Compute the smallest integer not less than input.
`channel_affine(...) <torch/channel_affine.html>`_
: Apply affine transformation to each channel of input.
`channel_normalize(...) <torch/channel_normalize.html>`_
: Apply normalization to each channel of input.
`chunk(...) <torch/chunk.html>`_
: Split input into a specific number of chunks.
......@@ -345,6 +342,7 @@ vm.torch
torch/argmax
torch/argmin
torch/argsort
torch/atan2
torch/baddbmm
torch/bitwise_and
torch/bitwise_not
......@@ -353,8 +351,6 @@ vm.torch
torch/bmm
torch/cat
torch/ceil
torch/channel_affine
torch/channel_normalize
torch/chunk
torch/clamp
torch/cos
......
......@@ -73,6 +73,10 @@ argsort
#######
.. automethod:: dragon.vm.torch.Tensor.argsort
atan2
#####
.. automethod:: dragon.vm.torch.Tensor.atan2
backward
########
.. automethod:: dragon.vm.torch.Tensor.backward
......@@ -699,6 +703,7 @@ zero\_
.. _torch.argmax(...): argmax.html
.. _torch.argmin(...): argmin.html
.. _torch.argsort(...): argsort.html
.. _torch.atan2(...): atan2.html
.. _torch.baddbmm(...): baddbmm.html
.. _torch.bitwise_and(...): bitwise_and.html
.. _torch.bitwise_not(...): bitwise_not.html
......
channel_affine
==============
atan2
=====
.. autofunction:: dragon.vm.torch.channel_affine
.. autofunction:: dragon.vm.torch.atan2
.. raw:: html
......
......@@ -6,12 +6,16 @@ vm.torch.backends
Modules
-------
`Module cuda <backends/cuda.html>`_
: The CUDA backend module.
`Module cudnn <backends/cudnn.html>`_
: The cuDNN backend module.
.. toctree::
:hidden:
backends/cuda
backends/cudnn
.. raw:: html
......
cuda
====
Properties
----------
matmul.allow_tf32
#################
.. data:: dragon.vm.torch.backends.cuda.matmul.allow_tf32
:annotation: = False
The flag that allows TF32 math type for matmul or not.
Functions
---------
is_built
########
.. automethod:: dragon.vm.torch.backends.cuda.is_built
.. raw:: html
<style>
h1:before {
content: "torch.backends.";
color: #103d3e;
}
</style>
......@@ -24,8 +24,8 @@ vm.torch.nn
`class AdaptiveMaxPool3d <nn/AdaptiveMaxPool3d.html>`_
: Apply the 3d adaptive max pooling.
`class AffineChannel <nn/AffineChannel.html>`_
: Apply affine transformation along the channels.
`class Affine <nn/Affine.html>`_
: Apply the affine transformation.
`class AvgPool1d <nn/AvgPool1d.html>`_
: Apply the 1d average pooling.
......@@ -312,7 +312,7 @@ vm.torch.nn
nn/AdaptiveMaxPool1d
nn/AdaptiveMaxPool2d
nn/AdaptiveMaxPool3d
nn/AffineChannel
nn/Affine
nn/AvgPool1d
nn/AvgPool2d
nn/AvgPool3d
......
AffineChannel
=============
Affine
======
.. autoclass:: dragon.vm.torch.nn.AffineChannel
.. autoclass:: dragon.vm.torch.nn.Affine
__init__
--------
.. automethod:: dragon.vm.torch.nn.AffineChannel.__init__
.. automethod:: dragon.vm.torch.nn.Affine.__init__
.. _torch.channel_affine(...): ../channel_affine.html
.. _torch.nn.functional.affine(...): functional/affine.html
.. raw:: html
......
......@@ -24,6 +24,9 @@ vm.torch.nn.functional
`adaptive_max_pool3d(...) <functional/adaptive_max_pool3d.html>`_
: Apply the 3d adaptive max pooling to input.
`affine(...) <functional/affine.html>`_
: Apply the affine transformation to input.
`avg_pool1d(...) <functional/avg_pool1d.html>`_
: Apply the 1d average pooling to input.
......@@ -40,8 +43,11 @@ vm.torch.nn.functional
`binary_cross_entropy_with_logits(...) <functional/binary_cross_entropy_with_logits.html>`_
: Compute the sigmoid cross entropy with contiguous target.
`channel_norm(...) <nn/channel_norm.html>`_
: Apply the normalization to each channel of input.
`channel_shuffle(...) <functional/channel_shuffle.html>`_
: Apply group shuffle to each channel of input.
: Apply the group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`conv1d(...) <functional/conv1d.html>`_
......@@ -229,11 +235,13 @@ vm.torch.nn.functional
functional/adaptive_max_pool1d
functional/adaptive_max_pool2d
functional/adaptive_max_pool3d
functional/affine
functional/avg_pool1d
functional/avg_pool2d
functional/avg_pool3d
functional/batch_norm
functional/binary_cross_entropy_with_logits
functional/channel_norm
functional/channel_shuffle
functional/conv1d
functional/conv2d
......
affine
======
.. autofunction:: dragon.vm.torch.nn.functional.affine
.. _torch.nn.affine(...): ../Affine.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
channel_norm
============
.. autofunction:: dragon.vm.torch.nn.functional.channel_norm
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
......@@ -56,15 +56,16 @@ class CUDAObjects {
auto& handle = handles[stream_id];
CUBLAS_CHECK(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasSetStream(handle, stream(device_id, stream_id)));
}
auto& handle = handles[stream_id];
#if CUDA_VERSION >= 11000
if (cudnn_allow_tf32_) {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
} else {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#endif
if (cublas_allow_tf32_) {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
} else {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
return handles[stream_id];
#endif
return handle;
}
/*! \brief Return the specified cudnn handle */
......@@ -150,6 +151,9 @@ class CUDAObjects {
Map<string, ncclComm_t> nccl_comms_[CUDA_MAX_DEVICES];
#endif
/*! \brief The flag that allows cuBLAS TF32 math type or not */
bool cublas_allow_tf32_ = false;
/*! \brief The flag that uses cuDNN or not */
bool cudnn_enabled_ = true;
......
......@@ -20,32 +20,32 @@ namespace dragon {
/*!
* \brief Registry to create class instances.
*/
template <class KeyType, class ObjectType, class... Args>
template <class KeyT, class ClassT, class... Args>
class Registry {
public:
typedef std::function<ObjectType*(Args...)> Creator;
typedef std::function<ClassT*(Args...)> Creator;
/*! \brief Create an instance of specified class */
ObjectType* Create(const KeyType& key, Args... args) {
ClassT* Create(const KeyT& key, Args... args) {
CHECK(registry_.count(key)) << "\nKey(" << key << ") has not registered.";
return registry_[key](args...);
}
/*! \brief Return whether the specified class is registered */
bool Has(const KeyType& key) {
bool Has(const KeyT& key) {
return (registry_.count(key)) != 0;
}
/*! \brief Register a class with the creator */
void Register(const KeyType& key, Creator creator) {
void Register(const KeyT& key, Creator creator) {
CHECK(!registry_.count(key))
<< "\nKey(" << key << ") has already registered.";
registry_[key] = creator;
}
/*! \brief Return the key of registered classes */
vector<KeyType> keys() {
vector<KeyType> ret;
vector<KeyT> keys() {
vector<KeyT> ret;
for (const auto& it : registry_) {
ret.push_back(it.first);
}
......@@ -54,50 +54,49 @@ class Registry {
private:
/*! \brief The registry map */
Map<KeyType, Creator> registry_;
Map<KeyT, Creator> registry_;
};
/*!
* \brief Register creator into the registry.
*/
template <class KeyType, class ObjectType, class... Args>
template <class KeyT, class ClassT, class... Args>
class Registerer {
public:
/*! \brief Constructor with key and creator */
Registerer(
const KeyType& key,
Registry<KeyType, ObjectType, Args...>* registry,
typename Registry<KeyType, ObjectType, Args...>::Creator creator,
const KeyT& key,
Registry<KeyT, ClassT, Args...>* registry,
typename Registry<KeyT, ClassT, Args...>::Creator creator,
const string& help_msg = "") {
registry->Register(key, creator);
}
/*! \brief Return the default creator */
template <class DerivedType>
static ObjectType* DefaultCreator(Args... args) {
return new DerivedType(args...);
template <class DerivedT>
static ClassT* DefaultCreator(Args... args) {
return new DerivedT(args...);
}
};
// Used in *.h files
#define DECLARE_TYPED_REGISTRY(RegistryName, KeyType, ObjectType, ...) \
DRAGON_API Registry<KeyType, ObjectType, ##__VA_ARGS__>* RegistryName(); \
typedef Registerer<KeyType, ObjectType, ##__VA_ARGS__> \
Registerer##RegistryName;
// Used in *.cc files
#define DEFINE_TYPED_REGISTRY(RegistryName, KeyType, ObjectType, ...) \
Registry<KeyType, ObjectType, ##__VA_ARGS__>* RegistryName() { \
static Registry<KeyType, ObjectType, ##__VA_ARGS__>* registry = \
new Registry<KeyType, ObjectType, ##__VA_ARGS__>(); \
return registry; \
// Used in *.h files.
#define DECLARE_TYPED_REGISTRY(RegistryName, KeyT, ClassT, ...) \
DRAGON_API Registry<KeyT, ClassT, ##__VA_ARGS__>* RegistryName(); \
typedef Registerer<KeyT, ClassT, ##__VA_ARGS__> Registerer##RegistryName;
// Used in *.cc files.
#define DEFINE_TYPED_REGISTRY(RegistryName, KeyT, ClassT, ...) \
Registry<KeyT, ClassT, ##__VA_ARGS__>* RegistryName() { \
static Registry<KeyT, ClassT, ##__VA_ARGS__>* registry = \
new Registry<KeyT, ClassT, ##__VA_ARGS__>(); \
return registry; \
}
#define DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
DECLARE_TYPED_REGISTRY(RegistryName, string, ObjectType, ##__VA_ARGS__)
#define DECLARE_REGISTRY(RegistryName, ClassT, ...) \
DECLARE_TYPED_REGISTRY(RegistryName, string, ClassT, ##__VA_ARGS__)
#define DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
DEFINE_TYPED_REGISTRY(RegistryName, string, ObjectType, ##__VA_ARGS__)
#define DEFINE_REGISTRY(RegistryName, ClassT, ...) \
DEFINE_TYPED_REGISTRY(RegistryName, string, ClassT, ##__VA_ARGS__)
#define REGISTER_TYPED_CLASS(RegistryName, key, ...) \
static Registerer##RegistryName ANONYMOUS_VARIABLE(g_##RegistryName)( \
......
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
void _ChannelAffine(
const int N,
const int S,
const int C,
const T* x,
const T* scale,
const T* bias,
T* y) {
if (S == 1) {
if (bias != nullptr) {
EigenArrayMap<T>(y, C, N) = (ConstEigenArrayMap<T>(x, C, N).colwise() *
ConstEigenVectorArrayMap<T>(scale, C))
.colwise() +
ConstEigenVectorArrayMap<T>(bias, C);
} else {
EigenArrayMap<T>(y, C, N) = ConstEigenArrayMap<T>(x, C, N).colwise() *
ConstEigenVectorArrayMap<T>(scale, C);
}
return;
}
for (int i = 0; i < N; ++i) {
for (int j = 0; j < C; ++j) {
if (bias != nullptr) {
EigenVectorArrayMap<T>(y, S) =
ConstEigenVectorArrayMap<T>(x, S) * scale[j] + bias[j];
} else {
EigenVectorArrayMap<T>(y, S) =
ConstEigenVectorArrayMap<T>(x, S) * scale[j];
}
x += S;
y += S;
}
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void ChannelAffine<float16, CPUContext>(
const int N,
const int S,
const int C,
const float16* x,
const float16* w,
const float16* b,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ChannelAffine<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const T* x, \
const T* scale, \
const T* bias, \
T* y, \
CPUContext* ctx) { \
_ChannelAffine(N, S, C, x, scale, bias, y); \
}
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T, typename AccT>
__global__ void _ChannelAffine(
const int NxCxS,
const int S,
const int C,
const T* x,
const T* scale,
T* y) {
CUDA_1D_KERNEL_LOOP(i, NxCxS) {
y[i] = convert::To<T>(
convert::To<AccT>(x[i]) *
convert::To<AccT>(__ldg(scale + (i / S) % C)));
}
}
template <typename T, typename AccT>
__global__ void _ChannelAffine(
const int NxCxS,
const int S,
const int C,
const T* x,
const T* scale,
const T* bias,
T* y) {
CUDA_1D_KERNEL_LOOP(i, NxCxS) {
const int j = (i / S) % C;
y[i] = convert::To<T>(
fma(convert::To<AccT>(x[i]),
convert::To<AccT>(__ldg(scale + j)),
convert::To<AccT>(__ldg(bias + j))));
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ChannelAffine<T, CUDAContext>( \
const int N, \
const int S, \
const int C, \
const T* x, \
const T* scale, \
const T* bias, \
T* y, \
CUDAContext* ctx) { \
const auto NxCxS = N * C * S; \
if (bias != nullptr) { \
_ChannelAffine<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<CUDA_BLOCKS(NxCxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxCxS, \
S, \
C, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<const math::ScalarType<T>::type*>(scale), \
reinterpret_cast<const math::ScalarType<T>::type*>(bias), \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
} else { \
_ChannelAffine<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<CUDA_BLOCKS(NxCxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxCxS, \
S, \
C, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<const math::ScalarType<T>::type*>(scale), \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
} \
}
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
void _ChannelShuffle(
const int N,
const int S,
const int G,
const int K,
const T* x,
T* y) {
for (int i = 0; i < N; ++i) {
for (int gi = 0; gi < G; ++gi) {
for (int ki = 0; ki < K; ++ki) {
std::memcpy(
y + ((i * K + ki) * G + gi) * S,
x + ((i * G + gi) * K + ki) * S,
S * sizeof(T));
}
}
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ChannelShuffle<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const int G, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_ChannelShuffle(N, S, G, C / G, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
__global__ void _ChannelShuffle(
const int NxCxS,
const int S,
const int G,
const int K,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(index, NxCxS) {
const int j = index % S;
const int gi = index / S % G;
const int ki = index / S / G % K;
const int i = index / S / G / K;
y[index] = x[((i * G + gi) * K + ki) * S + j];
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ChannelShuffle<T, CUDAContext>( \
const int N, \
const int S, \
const int C, \
const int G, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
const auto NxCxS = N * C * S; \
_ChannelShuffle<<< \
CUDA_BLOCKS(NxCxS), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(NxCxS, S, G, C / G, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
......@@ -52,14 +52,23 @@ __global__ void _ComputeCounts(
CUDAContext* ctx) { \
math::Copy(dim, x, y, ctx); \
auto policy = thrust::cuda::par.on(ctx->cuda_stream()); \
auto* data = reinterpret_cast<math::ScalarType<T>::type*>(y); \
thrust::device_vector<int> order1(dim), order2(dim); \
thrust::sequence(policy, order1.begin(), order1.end()); \
thrust::sequence(policy, order2.begin(), order2.end()); \
thrust::sort_by_key( \
policy, y, y + dim, order1.begin(), math::LessFunctor<T>()); \
policy, \
data, \
data + dim, \
order1.begin(), \
math::LessFunctor<math::ScalarType<T>::type>()); \
auto last = thrust::unique_by_key( \
policy, y, y + dim, order2.begin(), math::EqualFunctor<T>()); \
int n = num[0] = last.first - y; \
policy, \
data, \
data + dim, \
order2.begin(), \
math::EqualFunctor<math::ScalarType<T>::type>()); \
int n = num[0] = last.first - data; \
if (inverse_index) { \
_RemapInverse<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
dim, n, order1.data(), order2.data(), inverse_index); \
......
......@@ -8,7 +8,7 @@ namespace kernels {
namespace {
template <typename InputT, typename OutputT>
void _ChannelNormalize(
void _ChannelNorm(
const int axis,
const int num_dims,
const int64_t* x_strides,
......@@ -19,15 +19,14 @@ void _ChannelNormalize(
OutputT* y) {
const auto N = math::utils::Prod(num_dims, y_dims);
vec64_t idx(num_dims, 0);
int64_t xi, wi;
for (int yi = 0; yi < N; ++yi) {
xi = 0;
int64_t xi = 0, wi;
for (int d = num_dims - 1; d >= 0; --d) {
xi += idx[d] * x_strides[d];
if (d == axis) wi = idx[d];
}
y[yi] =
convert::To<OutputT>((convert::To<float>(x[xi]) - mean[wi]) / std[wi]);
const float val = convert::To<float>(x[xi]);
y[yi] = convert::To<OutputT>((val - mean[wi]) / std[wi]);
math::utils::IncreaseIndexInDims(num_dims, y_dims, idx.data());
}
}
......@@ -36,19 +35,19 @@ void _ChannelNormalize(
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \
void ChannelNormalize<InputT, OutputT, CPUContext>( \
const int axis, \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const InputT* x, \
const float* mean, \
const float* std, \
OutputT* y, \
CPUContext* ctx) { \
_ChannelNormalize(axis, num_dims, x_strides, y_dims, x, mean, std, y); \
#define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \
void ChannelNorm<InputT, OutputT, CPUContext>( \
const int axis, \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const InputT* x, \
const float* mean, \
const float* std, \
OutputT* y, \
CPUContext* ctx) { \
_ChannelNorm(axis, num_dims, x_strides, y_dims, x, mean, std, y); \
}
DEFINE_KERNEL_LAUNCHER(uint8_t, float16);
......
......@@ -11,7 +11,7 @@ namespace kernels {
namespace {
template <typename InputT, typename OutputT, int D>
__global__ void _ChannelNormalize(
__global__ void _ChannelNorm(
const int N,
const int axis,
const int num_dims,
......@@ -38,31 +38,27 @@ __global__ void _ChannelNormalize(
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \
void ChannelNormalize<InputT, OutputT, CUDAContext>( \
const int axis, \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const InputT* x, \
const float* mean, \
const float* std, \
OutputT* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides, Y_dims; \
const auto N = math::utils::Prod(num_dims, y_dims); \
for (int i = 0; i < num_dims; ++i) { \
X_strides.data[i] = x_strides[i]; \
Y_dims.data[i] = y_dims[i]; \
} \
_ChannelNormalize<<< \
CUDA_BLOCKS(N), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
N, axis, num_dims, X_strides, Y_dims, x, mean, std, y); \
#define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \
void ChannelNorm<InputT, OutputT, CUDAContext>( \
const int axis, \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const InputT* x, \
const float* mean, \
const float* std, \
OutputT* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides, Y_dims; \
const auto N = math::utils::Prod(num_dims, y_dims); \
for (int i = 0; i < num_dims; ++i) { \
X_strides.data[i] = x_strides[i]; \
Y_dims.data[i] = y_dims[i]; \
} \
_ChannelNorm<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, axis, num_dims, X_strides, Y_dims, x, mean, std, y); \
}
DEFINE_KERNEL_LAUNCHER(uint8_t, float16);
......
......@@ -8,7 +8,7 @@ namespace kernels {
namespace {
template <typename T>
void _L1Normalize(
void _L1Norm(
const int N,
const int S,
const int C,
......@@ -28,7 +28,7 @@ void _L1Normalize(
}
template <typename T>
void _L2Normalize(
void _L2Norm(
const int N,
const int S,
const int C,
......@@ -48,7 +48,7 @@ void _L2Normalize(
}
template <typename T>
void _L1NormalizeGrad(
void _L1NormGrad(
const int N,
const int S,
const int C,
......@@ -73,7 +73,7 @@ void _L1NormalizeGrad(
}
template <typename T>
void _L2NormalizeGrad(
void _L2NormGrad(
const int N,
const int S,
const int C,
......@@ -101,7 +101,7 @@ void _L2NormalizeGrad(
/* ------------------- Launcher Separator ------------------- */
template <>
void L1Normalize<float16, CPUContext>(
void L1Norm<float16, CPUContext>(
const int N,
const int S,
const int C,
......@@ -114,7 +114,7 @@ void L1Normalize<float16, CPUContext>(
}
template <>
void L2Normalize<float16, CPUContext>(
void L2Norm<float16, CPUContext>(
const int N,
const int S,
const int C,
......@@ -127,7 +127,7 @@ void L2Normalize<float16, CPUContext>(
}
template <>
void L1NormalizeGrad<float16, CPUContext>(
void L1NormGrad<float16, CPUContext>(
const int N,
const int S,
const int C,
......@@ -138,10 +138,10 @@ void L1NormalizeGrad<float16, CPUContext>(
float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
} // L1NormalizeGrad
} // L1NormGrad
template <>
void L2NormalizeGrad<float16, CPUContext>(
void L2NormGrad<float16, CPUContext>(
const int N,
const int S,
const int C,
......@@ -152,7 +152,7 @@ void L2NormalizeGrad<float16, CPUContext>(
float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
} // L2NormalizeGrad
} // L2NormGrad
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
......@@ -183,14 +183,14 @@ void L2NormalizeGrad<float16, CPUContext>(
_##name<T>(N, S, C, normalizer, eps, dy, x, dx); \
}
DEFINE_KERNEL_LAUNCHER(L1Normalize, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, double);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, double);
DEFINE_KERNEL_LAUNCHER(L1Norm, float);
DEFINE_KERNEL_LAUNCHER(L1Norm, double);
DEFINE_KERNEL_LAUNCHER(L2Norm, float);
DEFINE_KERNEL_LAUNCHER(L2Norm, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormGrad, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormGrad, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
......
......@@ -12,7 +12,7 @@ namespace kernels {
namespace {
template <typename T, typename AccT>
__global__ void _L1Normalize(
__global__ void _L1Norm(
const int NxS,
const int S,
const int C,
......@@ -41,7 +41,7 @@ __global__ void _L1Normalize(
}
template <typename T, typename AccT>
__global__ void _L2Normalize(
__global__ void _L2Norm(
const int NxS,
const int S,
const int C,
......@@ -70,7 +70,7 @@ __global__ void _L2Normalize(
}
template <typename T, typename AccT>
__global__ void _L1NormalizeGrad(
__global__ void _L1NormGrad(
const int NxS,
const int S,
const int C,
......@@ -107,7 +107,7 @@ __global__ void _L1NormalizeGrad(
}
template <typename T, typename AccT>
__global__ void _L2NormalizeGrad(
__global__ void _L2NormGrad(
const int NxS,
const int S,
const int C,
......@@ -195,18 +195,18 @@ __global__ void _L2NormalizeGrad(
reinterpret_cast<math::ScalarType<T>::type*>(dx)); \
}
DEFINE_KERNEL_LAUNCHER(L1Normalize, float16, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, float, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, double, double);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float16, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float16, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, double, double);
DEFINE_KERNEL_LAUNCHER(L1Norm, float16, float);
DEFINE_KERNEL_LAUNCHER(L1Norm, float, float);
DEFINE_KERNEL_LAUNCHER(L1Norm, double, double);
DEFINE_KERNEL_LAUNCHER(L2Norm, float16, float);
DEFINE_KERNEL_LAUNCHER(L2Norm, float, float);
DEFINE_KERNEL_LAUNCHER(L2Norm, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormGrad, float16, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormGrad, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormGrad, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormGrad, float16, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormGrad, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormGrad, double, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
......
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
template <>
void Adam<float, CPUContext>(
namespace {
template <typename T, typename CopyT>
void _Adam(
const int N,
const float lr,
const float beta1,
const float beta2,
const float eps,
float* g,
float* m,
float* v,
CPUContext* ctx) {
const T lr,
const T beta1,
const T beta2,
const T eps,
const T wd,
const T* x,
const T* g,
T* m,
T* v,
T* y,
CopyT* y_copy) {
for (int i = 0; i < N; ++i) {
float gi = g[i];
float mi = m[i] = m[i] * beta1 + gi * (1 - beta1);
float vi = v[i] = v[i] * beta2 + gi * gi * (1 - beta2);
g[i] = lr * mi / (std::sqrt(vi) + eps);
const T gi = wd > T(0) ? std::fma(wd, x[i], g[i]) : g[i];
const T mi = m[i] = std::fma(beta1, m[i], (T(1) - beta1) * gi);
const T vi = v[i] = std::fma(beta2, v[i], (T(1) - beta2) * gi * gi);
y[i] -= lr * mi / (std::sqrt(vi) + eps);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
}
}
template <>
void AdamW<float, CPUContext>(
template <typename T, typename CopyT>
void _AdamW(
const int N,
const float lr,
const float beta1,
const float beta2,
const float eps,
const float wd,
const float* x,
float* g,
float* m,
float* v,
CPUContext* ctx) {
const T lr,
const T beta1,
const T beta2,
const T eps,
const T wd,
const T* x,
const T* g,
T* m,
T* v,
T* y,
CopyT* y_copy) {
for (int i = 0; i < N; ++i) {
float gi = g[i];
float mi = m[i] = m[i] * beta1 + gi * (1 - beta1);
float vi = v[i] = v[i] * beta2 + gi * gi * (1 - beta2);
g[i] = lr * mi / (std::sqrt(vi) + eps) + wd * x[i];
const T gi = g[i];
const T mi = m[i] = std::fma(beta1, m[i], (T(1) - beta1) * gi);
const T vi = v[i] = std::fma(beta2, v[i], (T(1) - beta2) * gi * gi);
y[i] -= wd > T(0) ? std::fma(wd, x[i], lr * mi / (std::sqrt(vi) + eps))
: lr * mi / (std::sqrt(vi) + eps);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
}
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
template <> \
void name<T, CopyT, CPUContext>( \
const int N, \
const float lr, \
const float beta1, \
const float beta2, \
const float eps, \
const float wd, \
const T* x, \
const T* g, \
T* m, \
T* v, \
T* y, \
CopyT* y_copy, \
CPUContext* ctx) { \
_##name( \
N, \
convert::To<T>(lr), \
convert::To<T>(beta1), \
convert::To<T>(beta2), \
convert::To<T>(eps), \
convert::To<T>(wd), \
x, \
g, \
m, \
v, \
y, \
y_copy); \
}
DEFINE_KERNEL_LAUNCHER(Adam, float, float16);
DEFINE_KERNEL_LAUNCHER(Adam, float, float);
DEFINE_KERNEL_LAUNCHER(Adam, double, double);
DEFINE_KERNEL_LAUNCHER(AdamW, float, float16);
DEFINE_KERNEL_LAUNCHER(AdamW, float, float);
DEFINE_KERNEL_LAUNCHER(AdamW, double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -9,25 +10,32 @@ namespace kernels {
namespace {
template <typename T>
template <typename T, typename CopyT>
__global__ void _Adam(
const int N,
const T lr,
const T beta1,
const T beta2,
const T eps,
T* g,
const T wd,
const T* x,
const T* g,
T* m,
T* v) {
T* v,
T* y,
CopyT* y_copy) {
CUDA_1D_KERNEL_LOOP(i, N) {
T gi = g[i];
T mi = m[i] = m[i] * beta1 + gi * (1 - beta1);
T vi = v[i] = v[i] * beta2 + gi * gi * (1 - beta2);
g[i] = lr * mi / (sqrt(vi) + eps);
const T gi = wd > T(0) ? fma(wd, x[i], g[i]) : g[i];
const T mi = m[i] = fma(beta1, m[i], (T(1) - beta1) * gi);
const T vi = v[i] = fma(beta2, v[i], (T(1) - beta2) * gi * gi);
y[i] -= lr * mi / (sqrt(vi) + eps);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
}
}
template <typename T>
template <typename T, typename CopyT>
__global__ void _AdamW(
const int N,
const T lr,
......@@ -36,14 +44,20 @@ __global__ void _AdamW(
const T eps,
const T wd,
const T* x,
T* g,
const T* g,
T* m,
T* v) {
T* v,
T* y,
CopyT* y_copy) {
CUDA_1D_KERNEL_LOOP(i, N) {
T gi = g[i];
T mi = m[i] = m[i] * beta1 + gi * (1 - beta1);
T vi = v[i] = v[i] * beta2 + gi * gi * (1 - beta2);
g[i] = lr * mi / (sqrt(vi) + eps) + wd * x[i];
const T gi = g[i];
const T mi = m[i] = fma(beta1, m[i], (T(1) - beta1) * gi);
const T vi = v[i] = fma(beta2, v[i], (T(1) - beta2) * gi * gi);
y[i] -= wd > T(0) ? fma(wd, x[i], lr * mi / (sqrt(vi) + eps))
: lr * mi / (sqrt(vi) + eps);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
}
}
......@@ -51,37 +65,44 @@ __global__ void _AdamW(
/* ------------------- Launcher Separator ------------------- */
template <>
void Adam<float, CUDAContext>(
const int N,
const float lr,
const float beta1,
const float beta2,
const float eps,
float* g,
float* m,
float* v,
CUDAContext* ctx) {
_Adam<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, lr, beta1, beta2, eps, g, m, v);
}
#define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
template <> \
void name<T, CopyT, CUDAContext>( \
const int N, \
const float lr, \
const float beta1, \
const float beta2, \
const float eps, \
const float wd, \
const T* x, \
const T* g, \
T* m, \
T* v, \
T* y, \
CopyT* y_copy, \
CUDAContext* ctx) { \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
convert::To<T>(lr), \
convert::To<T>(beta1), \
convert::To<T>(beta2), \
convert::To<T>(eps), \
convert::To<T>(wd), \
x, \
g, \
m, \
v, \
y, \
reinterpret_cast<math::ScalarType<CopyT>::type*>(y_copy)); \
}
template <>
void AdamW<float, CUDAContext>(
const int N,
const float lr,
const float beta1,
const float beta2,
const float eps,
const float wd,
const float* x,
float* g,
float* m,
float* v,
CUDAContext* ctx) {
_AdamW<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, lr, beta1, beta2, eps, wd, x, g, m, v);
}
DEFINE_KERNEL_LAUNCHER(Adam, float, float16);
DEFINE_KERNEL_LAUNCHER(Adam, float, float);
DEFINE_KERNEL_LAUNCHER(Adam, double, double);
DEFINE_KERNEL_LAUNCHER(AdamW, float, float16);
DEFINE_KERNEL_LAUNCHER(AdamW, float, float);
DEFINE_KERNEL_LAUNCHER(AdamW, double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
......
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
template <>
void RMSprop<float, CPUContext>(
namespace {
template <typename T, typename CopyT>
void _RMSprop(
const int N,
const float lr,
const float momentum,
const float decay,
const float eps,
float* g,
float* m,
float* v,
CPUContext* ctx) {
const T lr,
const T momentum,
const T alpha,
const T eps,
const T wd,
const T* x,
const T* g,
T* m,
T* v,
T* y,
CopyT* y_copy) {
for (int i = 0; i < N; ++i) {
float gi = g[i];
float vi = v[i] = decay * v[i] + (1 - decay) * gi * gi;
float mi = m[i] = std::fma(momentum, m[i], gi / (std::sqrt(vi) + eps));
g[i] = lr * mi;
const T gi = wd > T(0) ? std::fma(wd, x[i], g[i]) : g[i];
const T vi = v[i] = std::fma(alpha, v[i], (T(1) - alpha) * gi * gi);
const T mi = m[i] = std::fma(momentum, m[i], gi / (std::sqrt(vi) + eps));
y[i] -= lr * mi;
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
}
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
template <> \
void name<T, CopyT, CPUContext>( \
const int N, \
const float lr, \
const float momentum, \
const float alpha, \
const float eps, \
const float wd, \
const T* x, \
const T* g, \
T* m, \
T* v, \
T* y, \
CopyT* y_copy, \
CPUContext* ctx) { \
_##name( \
N, \
convert::To<T>(lr), \
convert::To<T>(momentum), \
convert::To<T>(alpha), \
convert::To<T>(eps), \
convert::To<T>(wd), \
x, \
g, \
m, \
v, \
y, \
y_copy); \
}
DEFINE_KERNEL_LAUNCHER(RMSprop, float, float16);
DEFINE_KERNEL_LAUNCHER(RMSprop, float, float);
DEFINE_KERNEL_LAUNCHER(RMSprop, double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -9,21 +10,28 @@ namespace kernels {
namespace {
template <typename T>
template <typename T, typename CopyT>
__global__ void _RMSprop(
const int N,
const T lr,
const T momentum,
const T decay,
const T alpha,
const T eps,
T* g,
const T wd,
const T* x,
const T* g,
T* m,
T* v) {
T* v,
T* y,
CopyT* y_copy) {
CUDA_1D_KERNEL_LOOP(i, N) {
T gi = g[i];
T vi = v[i] = decay * v[i] + (1 - decay) * gi * gi;
T mi = m[i] = fma(momentum, m[i], gi / (sqrt(vi) + eps));
g[i] = lr * mi;
const T gi = wd > T(0) ? fma(wd, x[i], g[i]) : g[i];
const T vi = v[i] = fma(alpha, v[i], (T(1) - alpha) * gi * gi);
const T mi = m[i] = fma(momentum, m[i], gi / (std::sqrt(vi) + eps));
y[i] -= lr * mi;
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
}
}
......@@ -31,20 +39,41 @@ __global__ void _RMSprop(
/* ------------------- Launcher Separator ------------------- */
template <>
void RMSprop<float, CUDAContext>(
const int N,
const float lr,
const float momentum,
const float decay,
const float eps,
float* g,
float* m,
float* v,
CUDAContext* ctx) {
_RMSprop<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, lr, momentum, decay, eps, g, m, v);
}
#define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
template <> \
void name<T, CopyT, CUDAContext>( \
const int N, \
const float lr, \
const float momentum, \
const float alpha, \
const float eps, \
const float wd, \
const T* x, \
const T* g, \
T* m, \
T* v, \
T* y, \
CopyT* y_copy, \
CUDAContext* ctx) { \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
convert::To<T>(lr), \
convert::To<T>(momentum), \
convert::To<T>(alpha), \
convert::To<T>(eps), \
convert::To<T>(wd), \
x, \
g, \
m, \
v, \
y, \
reinterpret_cast<math::ScalarType<CopyT>::type*>(y_copy)); \
}
DEFINE_KERNEL_LAUNCHER(RMSprop, float, float16);
DEFINE_KERNEL_LAUNCHER(RMSprop, float, float);
DEFINE_KERNEL_LAUNCHER(RMSprop, double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
......
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
template <>
void MomentumSGD<float, CPUContext>(
namespace {
template <typename T, typename CopyT>
void _MomentumSGD(
const int N,
const float lr,
const float momentum,
float* g,
float* m,
CPUContext* ctx) {
const T lr,
const T momentum,
const T wd,
const T* x,
const T* g,
T* m,
T* y,
CopyT* y_copy) {
for (int i = 0; i < N; ++i) {
float mi = m[i] = std::fma(momentum, m[i], g[i]);
g[i] = lr * mi;
const T gi = wd > T(0) ? std::fma(wd, x[i], g[i]) : g[i];
const T mi = m[i] = std::fma(momentum, m[i], gi);
y[i] -= lr * mi;
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
}
}
template <>
void NesterovSGD<float, CPUContext>(
template <typename T, typename CopyT>
void _NesterovSGD(
const int N,
const float lr,
const float momentum,
float* g,
float* m,
CPUContext* ctx) {
const T lr,
const T momentum,
const T wd,
const T* x,
const T* g,
T* m,
T* y,
CopyT* y_copy) {
for (int i = 0; i < N; ++i) {
float gi = g[i];
float mi = m[i] = std::fma(momentum, m[i], gi);
g[i] = lr * std::fma(momentum, mi, gi);
const T gi = wd > T(0) ? std::fma(wd, x[i], g[i]) : g[i];
const T mi = m[i] = std::fma(momentum, m[i], gi);
y[i] -= lr * std::fma(momentum, mi, gi);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
}
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
template <> \
void name<T, CopyT, CPUContext>( \
const int N, \
const float lr, \
const float momentum, \
const float wd, \
const T* x, \
const T* g, \
T* m, \
T* y, \
CopyT* y_copy, \
CPUContext* ctx) { \
_##name( \
N, \
convert::To<T>(lr), \
convert::To<T>(momentum), \
convert::To<T>(wd), \
x, \
g, \
m, \
y, \
y_copy); \
}
DEFINE_KERNEL_LAUNCHER(MomentumSGD, float, float16);
DEFINE_KERNEL_LAUNCHER(MomentumSGD, float, float);
DEFINE_KERNEL_LAUNCHER(MomentumSGD, double, double);
DEFINE_KERNEL_LAUNCHER(NesterovSGD, float, float16);
DEFINE_KERNEL_LAUNCHER(NesterovSGD, float, float);
DEFINE_KERNEL_LAUNCHER(NesterovSGD, double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -9,22 +10,45 @@ namespace kernels {
namespace {
template <typename T>
__global__ void
_MomentumSGD(const int N, const T lr, const T momentum, T* g, T* m) {
template <typename T, typename CopyT>
__global__ void _MomentumSGD(
const int N,
const T lr,
const T momentum,
const T wd,
const T* x,
const T* g,
T* m,
T* y,
CopyT* y_copy) {
CUDA_1D_KERNEL_LOOP(i, N) {
T mi = m[i] = fma(momentum, m[i], g[i]);
g[i] = lr * mi;
const T gi = wd > T(0) ? fma(wd, x[i], g[i]) : g[i];
const T mi = m[i] = fma(momentum, m[i], gi);
y[i] -= lr * mi;
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
}
}
template <typename T>
__global__ void
_NesterovSGD(const int N, const T lr, const T momentum, T* g, T* m) {
template <typename T, typename CopyT>
__global__ void _NesterovSGD(
const int N,
const T lr,
const T momentum,
const T wd,
const T* x,
const T* g,
T* m,
T* y,
CopyT* y_copy) {
CUDA_1D_KERNEL_LOOP(i, N) {
T gi = g[i];
T mi = m[i] = fma(momentum, m[i], gi);
g[i] = lr * fma(momentum, mi, gi);
const T gi = wd > T(0) ? fma(wd, x[i], g[i]) : g[i];
const T mi = m[i] = fma(momentum, m[i], gi);
y[i] -= lr * fma(momentum, mi, gi);
if (y_copy != nullptr) {
y_copy[i] = convert::To<CopyT>(y[i]);
}
}
}
......@@ -32,29 +56,38 @@ _NesterovSGD(const int N, const T lr, const T momentum, T* g, T* m) {
/* ------------------- Launcher Separator ------------------- */
template <>
void MomentumSGD<float, CUDAContext>(
const int N,
const float lr,
const float momentum,
float* g,
float* m,
CUDAContext* ctx) {
_MomentumSGD<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, lr, momentum, g, m);
}
#define DEFINE_KERNEL_LAUNCHER(name, T, CopyT) \
template <> \
void name<T, CopyT, CUDAContext>( \
const int N, \
const float lr, \
const float momentum, \
const float wd, \
const T* x, \
const T* g, \
T* m, \
T* y, \
CopyT* y_copy, \
CUDAContext* ctx) { \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
convert::To<T>(lr), \
convert::To<T>(momentum), \
convert::To<T>(wd), \
x, \
g, \
m, \
y, \
reinterpret_cast<math::ScalarType<CopyT>::type*>(y_copy)); \
}
template <>
void NesterovSGD<float, CUDAContext>(
const int N,
const float lr,
const float momentum,
float* g,
float* m,
CUDAContext* ctx) {
_NesterovSGD<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, lr, momentum, g, m);
}
DEFINE_KERNEL_LAUNCHER(MomentumSGD, float, float16);
DEFINE_KERNEL_LAUNCHER(MomentumSGD, float, float);
DEFINE_KERNEL_LAUNCHER(MomentumSGD, double, double);
DEFINE_KERNEL_LAUNCHER(NesterovSGD, float, float16);
DEFINE_KERNEL_LAUNCHER(NesterovSGD, float, float);
DEFINE_KERNEL_LAUNCHER(NesterovSGD, double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
......
......@@ -91,16 +91,24 @@ void RegisterModule_cuda(py::module& m) {
#endif
});
/*! \brief Set the flags of cuBLAS library */
m.def("cublasSetFlags", [](int allow_tf32) {
#ifdef USE_CUDA
auto& ctx = CUDAContext::objects();
if (allow_tf32 >= 0) ctx.cublas_allow_tf32_ = allow_tf32;
#endif
});
/*! \brief Set the flags of cuDNN library */
m.def(
"cudnnSetFlags",
[](bool enabled, bool benchmark, bool deterministic, bool allow_tf32) {
[](int enabled, int benchmark, int deterministic, int allow_tf32) {
#ifdef USE_CUDA
auto& cuda_objects = CUDAContext::objects();
cuda_objects.cudnn_enabled_ = enabled;
cuda_objects.cudnn_deterministic_ = deterministic;
cuda_objects.cudnn_benchmark_ = benchmark;
cuda_objects.cudnn_allow_tf32_ = allow_tf32;
auto& ctx = CUDAContext::objects();
if (enabled >= 0) ctx.cudnn_enabled_ = enabled;
if (benchmark >= 0) ctx.cudnn_benchmark_ = benchmark;
if (deterministic >= 0) ctx.cudnn_deterministic_ = deterministic;
if (allow_tf32 >= 0) ctx.cudnn_allow_tf32_ = allow_tf32;
#endif
});
......
......@@ -132,8 +132,8 @@ PYBIND11_MODULE(libdragon_python, m) {
PRINT(INFO) << GetVerboseDef(def.DebugString(), "graph");
}
}
// Return the graph name may be different from the def
// We will make a unique dummy name on creating the graph
// Return the graph name may be different from the def.
// We will make a unique dummy name on creating the graph.
return graph->name();
})
......@@ -175,8 +175,8 @@ PYBIND11_MODULE(libdragon_python, m) {
GraphDef init_graph, pred_graph;
onnx::ONNXBackend onnx_backend;
onnx_backend.Prepare(model_path, &init_graph, &pred_graph);
// Serializing to Python is intractable
// We should apply the initializer immediately
// Serializing to Python is intractable.
// We should apply the initializer immediately.
self->RunGraph(self->CreateGraph(init_graph)->name());
return py::bytes(pred_graph.SerializeAsString());
});
......
......@@ -24,14 +24,14 @@ PythonPluginOp<Context>::PythonPluginOp(const OperatorDef& def, Workspace* ws)
Py_Initialize();
auto* module = PyImport_ImportModule(module_name_.c_str());
CHECK(module) << "\nFailed to import module: " << module;
auto* module_dict = PyModule_GetDict(module);
auto* op_class = PyDict_GetItemString(module_dict, class_name_.c_str());
CHECK(op_class) << "\nFailed to import class: " << class_name_
<< " from module: " << module_name_;
self_ = PyObject_CallObject(op_class, NULL);
// Project inputs and outputs.
// Set inputs and outputs.
inputs_ = PyList_New(InputSize());
outputs_ = PyList_New(OutputSize());
for (int i = 0; i < InputSize(); i++) {
......@@ -41,16 +41,15 @@ PythonPluginOp<Context>::PythonPluginOp(const OperatorDef& def, Workspace* ws)
PyList_SetItem(outputs_, i, PyBytes_FromStdString(Output(i)->name()));
}
// Set: self.kwargs_str
// Attr: "kwargs_str"
PyObject_SetAttr(
self_,
PyBytes_FromRawString("kwargs_str"),
PyBytes_FromStdString(kwargs_str_));
// Method: self.setup(inputs, outputs)
if (PyObject_HasAttr(self_, PyBytes_FromRawString("setup"))) {
CHECK(PyObject_CallMethod(self_, "setup", "OO", inputs_, outputs_))
<< CallMethodHelper("setup");
<< CallMethodHelper("setup"); // Method: setup(inputs, outputs)
}
}
......@@ -67,27 +66,24 @@ string PythonPluginOp<Context>::CallMethodHelper(const string& method_name) {
template <class Context>
void PythonPluginOp<Context>::RunOnDevice() {
// GIL may have been released
// GIL may have been released.
pybind11::gil_scoped_acquire g;
// Atrribute: self.phase
// Attr: phase
PyObject_SetAttr(
self_, PyBytes_FromRawString("phase"), PyBytes_FromStdString(phase()));
// Method: self.reshape(input, outputs)
if (PyObject_HasAttr(self_, PyBytes_FromRawString("reshape"))) {
CHECK(PyObject_CallMethod(self_, "reshape", "OO", inputs_, outputs_))
<< CallMethodHelper("reshape");
<< CallMethodHelper("reshape"); // Method: reshape(input, outputs)
}
// Method: self.run(input, outputs)
// Method: self.forward(input, outputs)
if (PyObject_HasAttr(self_, PyBytes_FromRawString("forward"))) {
CHECK(PyObject_CallMethod(self_, "forward", "OO", inputs_, outputs_))
<< CallMethodHelper("forward");
<< CallMethodHelper("forward"); // Method: run(input, outputs)
} else if (PyObject_HasAttr(self_, PyBytes_FromRawString("run"))) {
CHECK(PyObject_CallMethod(self_, "run", "OO", inputs_, outputs_))
<< CallMethodHelper("run");
<< CallMethodHelper("run"); // Method: forward(input, outputs)
}
}
......
......@@ -13,7 +13,6 @@ void ONNXBackend::Prepare(
ModelProto onnx_model;
CHECK(ReadProtoFromBinaryFile(onnx_model_path.c_str(), &onnx_model))
<< "\nFailed to parse the onnx model.";
int opset_version = -1;
for (const auto& imp : onnx_model.opset_import()) {
if ((!imp.has_domain()) || imp.domain().empty()) {
......@@ -31,7 +30,6 @@ void ONNXBackend::Prepare(
std::cout << "Unrecognized operator set " << opset_version << std::endl;
}
}
if (opset_version < 0) {
if (onnx_model.ir_version() >= 0x00000003) {
LOG(FATAL) << "Model with IR version >= 3 "
......@@ -40,7 +38,6 @@ void ONNXBackend::Prepare(
opset_version = 1;
}
}
ONNXToDragon(onnx_model, opset_version, true, init_graph, pred_graph);
}
......@@ -52,22 +49,23 @@ void ONNXBackend::ONNXToDragon(
GraphDef* pred_graph) {
ModelProto init_model = ModelProto();
ModelProto pred_model = onnx_model;
pred_graph->set_name(onnx_model.graph().name());
init_graph->set_name(onnx_model.graph().name() + "/init");
ValueInfoMap graph_value_infos{};
InitializerMap graph_initializer{};
for (const auto& vi : onnx_model.graph().input())
graph_value_infos[vi.name()].CopyFrom(vi);
for (const auto& vi : onnx_model.graph().output())
graph_value_infos[vi.name()].CopyFrom(vi);
for (const auto& vi : onnx_model.graph().value_info())
graph_value_infos[vi.name()].CopyFrom(vi);
// Collect graph inputs.
for (const auto& v : onnx_model.graph().input()) {
graph_value_infos[v.name()].CopyFrom(v);
}
// Collect graph outputs.
for (const auto& v : onnx_model.graph().output()) {
graph_value_infos[v.name()].CopyFrom(v);
}
// Collect graph values.
for (const auto& v : onnx_model.graph().value_info()) {
graph_value_infos[v.name()].CopyFrom(v);
}
// Collect graph initializers.
for (const auto& tensor : onnx_model.graph().initializer()) {
if (include_initializers) {
auto* op_def = init_graph->add_op();
......@@ -76,16 +74,18 @@ void ONNXBackend::ONNXToDragon(
}
graph_initializer[tensor.name()] = &tensor;
}
// Convert to graph defs.
auto converter = [&](const ModelProto& model, GraphDef* graph) mutable {
for (const auto& node : model.graph().node()) {
ValueInfoMap value_infos{};
InitializerMap initializer{};
for (const auto& name : node.input()) {
if (graph_value_infos.count(name))
if (graph_value_infos.count(name)) {
value_infos[name].CopyFrom(graph_value_infos[name]);
if (graph_initializer.count(name))
}
if (graph_initializer.count(name)) {
initializer[name] = graph_initializer[name];
}
}
auto onnx_node = ONNXNode(node);
auto returns = ONNXNodeToOps(
......@@ -98,23 +98,18 @@ void ONNXBackend::ONNXToDragon(
}
}
};
converter(pred_model, pred_graph);
// Set(Initializer) + Set(Placehoders) = Set(Inputs)
// Add external inputs.
Set<string> initializer;
for (const auto& e : onnx_model.graph().initializer()) {
initializer.insert(e.name());
for (const auto& v : onnx_model.graph().initializer()) {
initializer.insert(v.name());
}
// Add External Inputs
for (const auto& e : onnx_model.graph().input()) {
if (initializer.count(e.name()) == 0) {
pred_graph->add_input(e.name());
}
}
// Add External Outputs
// Add external outputs.
for (const auto& e : onnx_model.graph().output()) {
pred_graph->add_output(e.name());
}
......
#include "dragon/operators/array/channel_shuffle_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void ChannelShuffleOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
CHECK_EQ(X.dim(axis) % group_, 0)
<< "\nThe " << X.dim(axis) << " channels "
<< "can not be split into " << group_ << " groups.";
kernels::ChannelShuffle(
X.count(0, axis),
X.count(axis + 1),
X.dim(axis),
group_,
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void ChannelShuffleOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void ChannelShuffleGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0);
GET_OP_AXIS_ARG(axis, dY.ndim(), -1);
kernels::ChannelShuffle(
dY.count(0, axis),
dY.count(axis + 1),
dY.dim(axis),
dY.dim(axis) / group_,
dY.template data<T, Context>(),
dX->ReshapeLike(dY)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void ChannelShuffleGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(ChannelShuffle);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelShuffle);
#endif
DEPLOY_CPU_OPERATOR(ChannelShuffleGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelShuffleGradient);
#endif
OPERATOR_SCHEMA(ChannelShuffle)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(ChannelShuffleGradient)
/* dY */
.NumInputs(1)
/* dX */
.NumOutputs(1);
REGISTER_GRADIENT(ChannelShuffle, SimpleGradientMaker);
} // namespace dragon
#include "dragon/operators/array/shuffle_op.h"
#include "dragon/utils/math_functions.h"
namespace dragon {
template <class Context>
template <typename T>
void ChannelShuffleOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
CHECK_EQ(X.dim(axis) % group_, 0)
<< "\nThe " << X.dim(axis) << " channels "
<< "can not be split into " << group_ << " groups.";
auto G = group_, K = X.dim(axis) / group_;
if (def().type() == "ChannelShuffleGradient") std::swap(G, K);
math::Transpose(
4,
vec64_t({X.count(0, axis), G, K, X.count(axis + 1)}).data(),
vec64_t({0, 2, 1, 3}).data(),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
DEPLOY_CPU_OPERATOR(ChannelShuffle);
REGISTER_CPU_OPERATOR(ChannelShuffleGradient, ChannelShuffleOp<CPUContext>);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelShuffle);
REGISTER_CUDA_OPERATOR(ChannelShuffleGradient, ChannelShuffleOp<CUDAContext>);
#endif
OPERATOR_SCHEMA(ChannelShuffle).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(ChannelShuffleGradient).NumInputs(1).NumOutputs(1);
REGISTER_GRADIENT(ChannelShuffle, SimpleGradientMaker);
} // namespace dragon
......@@ -10,8 +10,8 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_
#define DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_
#ifndef DRAGON_OPERATORS_ARRAY_SHUFFLE_OP_H_
#define DRAGON_OPERATORS_ARRAY_SHUFFLE_OP_H_
#include "dragon/core/operator.h"
......@@ -25,7 +25,9 @@ class ChannelShuffleOp final : public Operator<Context> {
group_(OP_SINGLE_ARG(int64_t, "group", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -34,22 +36,6 @@ class ChannelShuffleOp final : public Operator<Context> {
int64_t group_;
};
template <class Context>
class ChannelShuffleGradientOp final : public Operator<Context> {
public:
ChannelShuffleGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
group_(OP_SINGLE_ARG(int64_t, "group", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
int64_t group_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_
#endif // DRAGON_OPERATORS_ARRAY_SHUFFLE_OP_H_
#include "dragon/operators/array/channel_affine_op.h"
#include "dragon/operators/math/affine_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void ChannelAffineOp<Context>::DoRunWithType() {
void AffineOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0, {0});
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
GET_OP_AXIS_ARG(end_axis, X.ndim(), axis);
vec64_t affine_dims(
{X.dims().begin() + axis, X.dims().begin() + end_axis + 1});
// Compute affine dimensions.
vec64_t affine_dims;
for (auto axis : axes_) {
axis = axis < 0 ? axis + X.ndim() : axis;
CHECK(axis >= 0 && axis < X.ndim())
<< "\nExcepted the axis in [-" << X.ndim() << ", " << X.ndim()
<< "), got " << axis << ".";
affine_dims.push_back(X.dim(axis));
}
CHECK(W.dims() == affine_dims)
<< "\nExcepted the weight shape is " << Tensor::DimString(affine_dims)
<< ", got " << W.DimString() << ".";
......@@ -23,10 +27,11 @@ void ChannelAffineOp<Context>::DoRunWithType() {
<< ", got " << Input(2).DimString() << ".";
}
kernels::ChannelAffine(
X.count(0, axis),
X.count(end_axis + 1),
X.count(axis, end_axis + 1),
math::Affine(
X.ndim(),
X.dims().data(),
axes_.size(),
axes_.data(),
X.template data<T, Context>(),
W.template data<T, Context>(),
InputSize() <= 2 ? nullptr : Input(2).template data<T, Context>(),
......@@ -35,28 +40,30 @@ void ChannelAffineOp<Context>::DoRunWithType() {
}
template <class Context>
void ChannelAffineOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Numerical>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void ChannelAffineGradientOp<Context>::DoRunWithType() {
void AffineGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
GET_OP_AXIS_ARG(end_axis, X.ndim(), axis);
vec64_t affine_dims = {X.count(0, axis),
X.count(axis, end_axis + 1),
X.count(end_axis + 1)},
affine_axes = {0, 2};
// Compute reduce axes.
vec64_t reduce_axes;
for (int i = 0; i < X.ndim(); ++i) {
bool keep = true;
for (auto axis : axes_) {
axis = axis < 0 ? axis + X.ndim() : axis;
if (axis == i) keep = false;
}
if (keep) reduce_axes.push_back(i);
}
// Scratch to save the intermediates.
T* data = nullptr;
if (dW->has_name() && X.count() != W.count()) {
data = ctx()->workspace()->template data<T, Context>(X.count());
}
// dW = dY * X
if (dW->has_name()) {
Output(1)->ReshapeLike(Input(1));
auto* x = Input(0).template data<T, Context>();
auto* dw = Output(1)->template mutable_data<T, Context>();
if (X.count() == W.count()) {
math::Mul(
X.count(),
......@@ -65,20 +72,19 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
dW->ReshapeLike(W)->template mutable_data<T, Context>(),
ctx());
} else {
T* scratch = ctx()->workspace()->template data<T, Context>(X.count());
math::Mul(
X.count(),
dY.template data<T, Context>(),
X.template data<T, Context>(),
scratch,
data,
ctx());
math::ReduceSum(
3,
affine_dims.data(),
2,
affine_axes.data(),
X.ndim(),
X.dims().data(),
reduce_axes.size(),
reduce_axes.data(),
1.f,
scratch,
data,
dW->ReshapeLike(W)->template mutable_data<T, Context>(),
ctx());
}
......@@ -90,10 +96,10 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
dB->ReshapeLike(W)->CopyFrom(dY, ctx());
} else {
math::ReduceSum(
3,
affine_dims.data(),
2,
affine_axes.data(),
X.ndim(),
X.dims().data(),
reduce_axes.size(),
reduce_axes.data(),
1.f,
dY.template data<T, Context>(),
dB->ReshapeLike(W)->template mutable_data<T, Context>(),
......@@ -103,11 +109,11 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
// dX = dY * W
if (dX->has_name()) {
Output(0)->ReshapeLike(Input(-1));
kernels::ChannelAffine(
X.count(0, axis),
X.count(end_axis + 1),
X.count(axis, end_axis + 1),
math::Affine(
X.ndim(),
X.dims().data(),
axes_.size(),
axes_.data(),
dY.template data<T, Context>(),
W.template data<T, Context>(),
(const T*)nullptr,
......@@ -116,22 +122,17 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
}
}
template <class Context>
void ChannelAffineGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(ChannelAffine);
DEPLOY_CPU_OPERATOR(Affine);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelAffine);
DEPLOY_CUDA_OPERATOR(Affine);
#endif
DEPLOY_CPU_OPERATOR(ChannelAffineGradient);
DEPLOY_CPU_OPERATOR(AffineGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelAffineGradient);
DEPLOY_CUDA_OPERATOR(AffineGradient);
#endif
OPERATOR_SCHEMA(ChannelAffine)
OPERATOR_SCHEMA(Affine)
/* X, W, B */
.NumInputs(2, 3)
/* Y */
......@@ -139,7 +140,7 @@ OPERATOR_SCHEMA(ChannelAffine)
/* X => Y */
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ChannelAffineGradient)
OPERATOR_SCHEMA(AffineGradient)
/* X, W, dY */
.NumInputs(3)
/* dX, dW, dB */
......@@ -163,6 +164,6 @@ class GradientMaker final : public GradientMakerBase {
} // namespace
REGISTER_GRADIENT(ChannelAffine, GradientMaker);
REGISTER_GRADIENT(Affine, GradientMaker);
} // namespace dragon
......@@ -10,37 +10,49 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_CHANNEL_AFFINE_OP_H_
#define DRAGON_OPERATORS_ARRAY_CHANNEL_AFFINE_OP_H_
#ifndef DRAGON_OPERATORS_MATH_AFFINE_OP_H_
#define DRAGON_OPERATORS_MATH_AFFINE_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class ChannelAffineOp final : public Operator<Context> {
class AffineOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(ChannelAffineOp);
AffineOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), axes_(OP_REPEATED_ARG(int64_t, "axes")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
vec64_t axes_;
};
template <class Context>
class ChannelAffineGradientOp final : public Operator<Context> {
class AffineGradientOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(ChannelAffineGradientOp);
AffineGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), axes_(OP_REPEATED_ARG(int64_t, "axes")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
vec64_t axes_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_CHANNEL_AFFINE_OP_H_
#endif // DRAGON_OPERATORS_MATH_AFFINE_OP_H_
......@@ -26,6 +26,7 @@ DISPATCH_WITH_TENSOR_TYPES(IsInf, dtypes::Floating, Input(0));
DISPATCH_WITH_TENSOR_TYPES(IsNaN, dtypes::Floating, Input(0));
DISPATCH_WITH_TENSOR_TYPES(IsFinite, dtypes::Floating, Input(0));
DISPATCH_WITH_TENSOR_TYPES(Pow, dtypes::Floating, Input(0));
DISPATCH_WITH_TENSOR_TYPES(Atan2, dtypes::Floating, Input(0));
DISPATCH_WITH_TENSOR_TYPES(Minimum, dtypes::Numerical, Input(0));
DISPATCH_WITH_TENSOR_TYPES(Maximum, dtypes::Numerical, Input(0));
DISPATCH_WITH_TENSOR_TYPES(BitwiseNot, dtypes::Bitwise, Input(0));
......@@ -120,6 +121,7 @@ DEFINE_INPLACE_UNARY_OP_IMPL(BitwiseNot, T);
}
DEFINE_SIMPLE_BINARY_OP_IMPL(Pow, T);
DEFINE_SIMPLE_BINARY_OP_IMPL(Atan2, T);
DEFINE_SIMPLE_BINARY_OP_IMPL(Minimum, T);
DEFINE_SIMPLE_BINARY_OP_IMPL(Maximum, T);
DEFINE_SIMPLE_BINARY_OP_IMPL(BitwiseAnd, T);
......@@ -152,6 +154,7 @@ DEPLOY_CPU_OPERATOR(IsInf);
DEPLOY_CPU_OPERATOR(IsNaN);
DEPLOY_CPU_OPERATOR(IsFinite);
DEPLOY_CPU_OPERATOR(Pow);
DEPLOY_CPU_OPERATOR(Atan2);
DEPLOY_CPU_OPERATOR(Minimum);
DEPLOY_CPU_OPERATOR(Maximum);
DEPLOY_CPU_OPERATOR(BitwiseNot);
......@@ -186,6 +189,7 @@ DEPLOY_CUDA_OPERATOR(IsInf);
DEPLOY_CUDA_OPERATOR(IsNaN);
DEPLOY_CUDA_OPERATOR(IsFinite);
DEPLOY_CUDA_OPERATOR(Pow);
DEPLOY_CUDA_OPERATOR(Atan2);
DEPLOY_CUDA_OPERATOR(Minimum);
DEPLOY_CUDA_OPERATOR(Maximum);
DEPLOY_CUDA_OPERATOR(BitwiseNot);
......@@ -222,6 +226,7 @@ OPERATOR_SCHEMA(IsNaN).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(IsFinite).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Not).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Pow).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(Atan2).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(Minimum).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(Maximum).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(BitwiseAnd)
......@@ -250,6 +255,7 @@ NO_GRADIENT(Round);
NO_GRADIENT(IsInf);
NO_GRADIENT(IsNaN);
NO_GRADIENT(IsFinite);
NO_GRADIENT(Atan2);
NO_GRADIENT(BitwiseNot);
NO_GRADIENT(BitwiseAnd);
NO_GRADIENT(BitwiseOr);
......
......@@ -70,7 +70,7 @@ inline vec32_t CheckOutputAliases(
return available_aliases;
}
// Unary ElementwiseOp
// Unary ElementwiseOp.
DECLARE_ELEMENTWISE_OP(Abs);
DECLARE_ELEMENTWISE_OP(Ceil);
DECLARE_ELEMENTWISE_OP(Cos);
......@@ -101,12 +101,13 @@ DECLARE_ELEMENTWISE_OP(SignGradient);
DECLARE_ELEMENTWISE_OP(SinGradient);
DECLARE_ELEMENTWISE_OP(SqrtGradient);
DECLARE_ELEMENTWISE_OP(SquareGradient);
// Binary ElementwiseOp
// Binary ElementwiseOp.
DECLARE_ELEMENTWISE_OP(Add);
DECLARE_ELEMENTWISE_OP(Sub);
DECLARE_ELEMENTWISE_OP(Mul);
DECLARE_ELEMENTWISE_OP(Div);
DECLARE_ELEMENTWISE_OP(Pow);
DECLARE_ELEMENTWISE_OP(Atan2);
DECLARE_ELEMENTWISE_OP(Minimum);
DECLARE_ELEMENTWISE_OP(Maximum);
DECLARE_ELEMENTWISE_OP(BitwiseAnd);
......@@ -128,7 +129,7 @@ DECLARE_ELEMENTWISE_OP(DivGradient);
DECLARE_ELEMENTWISE_OP(PowGradient);
DECLARE_ELEMENTWISE_OP(MinimumGradient);
DECLARE_ELEMENTWISE_OP(MaximumGradient);
// Trinary ElementwiseOp
// Trinary ElementwiseOp.
DECLARE_ELEMENTWISE_OP(Where);
DECLARE_ELEMENTWISE_OP(WhereGradient);
#undef DECLARE_ELEMENTWISE_OP
......
......@@ -199,11 +199,6 @@ void MatMulOp<Context>::DoRunWithType() {
}
template <class Context>
void MatMulOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void MatMulGradientOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), &dY = Input(2);
......@@ -590,11 +585,6 @@ void MatMulGradientOp<Context>::DoRunWithType() {
}
}
template <class Context>
void MatMulGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(MatMul);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(MatMul);
......
......@@ -23,7 +23,9 @@ class MatMulOp final : public Operator<Context> {
SIMPLE_CTOR_DTOR(MatMulOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -35,7 +37,9 @@ class MatMulGradientOp final : public Operator<Context> {
SIMPLE_CTOR_DTOR(MatMulGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......
#include "dragon/operators/array/channel_normalize_op.h"
#include "dragon/operators/normalization/channel_norm_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/op_kernels.h"
......@@ -6,7 +6,7 @@ namespace dragon {
template <class Context>
template <typename InputT, typename OutputT>
void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() {
void ChannelNormOp<Context>::DoRunWithTypeAndCast() {
auto &X = Input(0), *Y = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
......@@ -30,7 +30,7 @@ void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() {
<< "\nProviding " << X_mean_.count() << " values to normalize Dimension("
<< Y_dims[axis] << ").";
kernels::ChannelNormalize(
kernels::ChannelNorm(
axis,
num_dims,
X_strides.data(),
......@@ -44,7 +44,7 @@ void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() {
template <class Context>
template <typename T>
void ChannelNormalizeOp<Context>::DoRunWithType() {
void ChannelNormOp<Context>::DoRunWithType() {
if (data_type() == "float16") {
DoRunWithTypeAndCast<T, float16>();
} else if (data_type() == "float32") {
......@@ -58,21 +58,21 @@ void ChannelNormalizeOp<Context>::DoRunWithType() {
}
template <class Context>
void ChannelNormalizeOp<Context>::RunOnDevice() {
void ChannelNormOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Numerical>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(ChannelNormalize);
DEPLOY_CPU_OPERATOR(ChannelNorm);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ChannelNormalize);
DEPLOY_CUDA_OPERATOR(ChannelNorm);
#endif
OPERATOR_SCHEMA(ChannelNormalize)
OPERATOR_SCHEMA(ChannelNorm)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
NO_GRADIENT(ChannelNormalize);
NO_GRADIENT(ChannelNorm);
} // namespace dragon
......@@ -10,17 +10,17 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_CHANNEL_NORMALIZE_OP_H_
#define DRAGON_OPERATORS_ARRAY_CHANNEL_NORMALIZE_OP_H_
#ifndef DRAGON_OPERATORS_NORMALIZATION_CHANNEL_NORM_OP_H_
#define DRAGON_OPERATORS_NORMALIZATION_CHANNEL_NORM_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class ChannelNormalizeOp final : public Operator<Context> {
class ChannelNormOp final : public Operator<Context> {
public:
ChannelNormalizeOp(const OperatorDef& def, Workspace* ws)
ChannelNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
INITIALIZE_OP_REPEATED_ARG(int64_t, perm);
auto mean = OP_REPEATED_ARG(float, "mean");
......@@ -50,8 +50,8 @@ class ChannelNormalizeOp final : public Operator<Context> {
DECLARE_OP_REPEATED_ARG(int64_t, perm);
};
DEFINE_OP_REPEATED_ARG(int64_t, ChannelNormalizeOp, perm);
DEFINE_OP_REPEATED_ARG(int64_t, ChannelNormOp, perm);
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_CHANNEL_NORMALIZE_OP_H_
#endif // DRAGON_OPERATORS_NORMALIZATION_CHANNEL_NORM_OP_H_
#include "dragon/operators/normalization/lp_normalize_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h"
#include "dragon/operators/normalization/lp_norm_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void LpNormalizeOp<Context>::DoRunWithType() {
void LpNormOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
GET_OP_AXIS_ARG(end_axis, X.ndim(), axis);
auto reduce_dim = X.count(axis, end_axis + 1);
// Normalize input with a scaled Lp-norm
if (p_ == 1) {
kernels::L1Normalize(
kernels::L1Norm(
X.count(0, axis),
X.count(end_axis + 1),
reduce_dim,
......@@ -25,7 +22,7 @@ void LpNormalizeOp<Context>::DoRunWithType() {
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
} else if (p_ == 2) {
kernels::L2Normalize(
kernels::L2Norm(
X.count(0, axis),
X.count(end_axis + 1),
reduce_dim,
......@@ -40,20 +37,15 @@ void LpNormalizeOp<Context>::DoRunWithType() {
}
template <class Context>
void LpNormalizeOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void LpNormalizeGradientOp<Context>::DoRunWithType() {
void LpNormGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &dY = Input(1), *dX = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
GET_OP_AXIS_ARG(end_axis, X.ndim(), axis);
auto reduce_dim = X.count(axis, end_axis + 1);
if (p_ == 1) {
kernels::L1NormalizeGrad(
kernels::L1NormGrad(
X.count(0, axis),
X.count(end_axis + 1),
reduce_dim,
......@@ -64,7 +56,7 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
} else if (p_ == 2) {
kernels::L2NormalizeGrad(
kernels::L2NormGrad(
X.count(0, axis),
X.count(end_axis + 1),
reduce_dim,
......@@ -79,33 +71,28 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
}
}
template <class Context>
void LpNormalizeGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(LpNormalize);
DEPLOY_CPU_OPERATOR(LpNorm);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(LpNormalize);
DEPLOY_CUDA_OPERATOR(LpNorm);
#endif
DEPLOY_CPU_OPERATOR(LpNormalizeGradient);
DEPLOY_CPU_OPERATOR(LpNormGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(LpNormalizeGradient);
DEPLOY_CUDA_OPERATOR(LpNormGradient);
#endif
OPERATOR_SCHEMA(LpNormalize)
OPERATOR_SCHEMA(LpNorm)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(LpNormalizeGradient)
OPERATOR_SCHEMA(LpNormGradient)
/* X, dY */
.NumInputs(2)
/* dX */
.NumOutputs(1);
REGISTER_GRADIENT(LpNormalize, GenericGradientMaker);
REGISTER_GRADIENT(LpNorm, GenericGradientMaker);
} // namespace dragon
......@@ -10,24 +10,26 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_NORMALIZATION_LP_NORMALIZE_OP_H_
#define DRAGON_OPERATORS_NORMALIZATION_LP_NORMALIZE_OP_H_
#ifndef DRAGON_OPERATORS_NORMALIZATION_LP_NORM_OP_H_
#define DRAGON_OPERATORS_NORMALIZATION_LP_NORM_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class LpNormalizeOp final : public Operator<Context> {
class LpNormOp final : public Operator<Context> {
public:
LpNormalizeOp(const OperatorDef& def, Workspace* ws)
LpNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
p_(OP_SINGLE_ARG(int64_t, "p", 2)),
epsilon_(OP_SINGLE_ARG(double, "epsilon", 1e-12)),
reduction_(OP_SINGLE_ARG(string, "reduction", "SUM")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -39,16 +41,18 @@ class LpNormalizeOp final : public Operator<Context> {
};
template <class Context>
class LpNormalizeGradientOp final : public Operator<Context> {
class LpNormGradientOp final : public Operator<Context> {
public:
LpNormalizeGradientOp(const OperatorDef& def, Workspace* ws)
LpNormGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
p_(OP_SINGLE_ARG(int64_t, "p", 2)),
epsilon_(OP_SINGLE_ARG(double, "epsilon", 1e-12)),
reduction_(OP_SINGLE_ARG(string, "reduction", "SUM")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -61,4 +65,4 @@ class LpNormalizeGradientOp final : public Operator<Context> {
} // namespace dragon
#endif // DRAGON_OPERATORS_NORMALIZATION_LP_NORMALIZE_OP_H_
#endif // DRAGON_OPERATORS_NORMALIZATION_LP_NORM_OP_H_
......@@ -5,46 +5,41 @@
namespace dragon {
template <class Context>
void AdamOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) {
template <typename T, typename CopyT>
void AdamOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
kernels::Adam(
dX->count(),
lr_ * correction_,
beta1_,
beta2_,
eps_,
dX->template mutable_data<float, Context>(),
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
this->weight_decay_,
X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
GetState("v")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx());
}
template <class Context>
void AdamWOp<Context>::ComputeUpdate(Tensor* dX, Tensor* X) {
if (lambda_ > 0.f) {
kernels::AdamW(
dX->count(),
lr_ * correction_,
beta1_,
beta2_,
eps_,
this->lr_ * lambda_,
X->template data<float, Context>(),
dX->template mutable_data<float, Context>(),
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx());
} else {
kernels::Adam(
dX->count(),
lr_ * correction_,
beta1_,
beta2_,
eps_,
dX->template mutable_data<float, Context>(),
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
ctx());
}
template <typename T, typename CopyT>
void AdamWOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
kernels::AdamW(
dX->count(),
lr_ * correction_,
beta1_,
beta2_,
eps_,
lr_ * this->weight_decay_,
X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
GetState("v")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx());
}
DEPLOY_CPU_OPERATOR(Adam);
......
......@@ -5,16 +5,21 @@
namespace dragon {
template <class Context>
void RMSpropOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) {
template <typename T, typename CopyT>
void RMSpropOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
kernels::RMSprop(
dX->count(),
lr_,
momentum_,
decay_,
alpha_,
eps_,
dX->template mutable_data<float, Context>(),
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
Slot("v")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
this->weight_decay_,
X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
GetState("v")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx());
}
......
......@@ -6,33 +6,44 @@
namespace dragon {
template <class Context>
void MomentumSGDOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) {
template <typename T, typename CopyT>
void MomentumSGDOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
kernels::MomentumSGD(
dX->count(),
lr_,
momentum_,
dX->template mutable_data<float, Context>(),
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
this->weight_decay_,
X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx());
}
template <class Context>
void NesterovSGDOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) {
template <typename T, typename CopyT>
void NesterovSGDOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
kernels::NesterovSGD(
dX->count(),
lr_,
momentum_,
dX->template mutable_data<float, Context>(),
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
this->weight_decay_,
X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx());
}
template <class Context>
void LARSOp<Context>::ComputeUpdate(Tensor* dX, Tensor* X) {
template <typename T, typename CopyT>
void LARSOp<Context>::DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y) {
float trust_ratio = 0.f;
if (trust_coef_ > 0.f) {
auto* x = X->template data<float, Context>();
auto* dx = dX->template mutable_data<float, Context>();
auto* x = X->template data<T, Context>();
auto* dx = dX->template mutable_data<T, Context>();
float x_norm = std::sqrt(math::Dot(X->count(), x, x, ctx()));
float dx_norm = std::sqrt(math::Dot(dX->count(), dx, dx, ctx()));
if (x_norm > 0.f && dx_norm > 0.f) {
......@@ -43,16 +54,20 @@ void LARSOp<Context>::ComputeUpdate(Tensor* dX, Tensor* X) {
math::Scale(
dX->count(),
trust_ratio,
dX->template data<float, Context>(),
dX->template mutable_data<float, Context>(),
dX->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
ctx());
}
kernels::MomentumSGD(
dX->count(),
lr_,
momentum_,
dX->template mutable_data<float, Context>(),
Slot("m")->ReshapeLike(*dX)->template mutable_data<float, Context>(),
this->weight_decay_,
X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
GetState("m")->ReshapeLike(*dX)->template mutable_data<T, Context>(),
X->template mutable_data<T, Context>(),
Y ? Y->template mutable_data<CopyT, Context>() : (CopyT*)nullptr,
ctx());
}
......
......@@ -37,17 +37,14 @@ class UpdateOpBase : public Operator<Context> {
void RunOnDevice() override;
template <typename T>
void TransformGrad(Tensor* dX, Tensor* X);
void TransformGrad(Tensor* dX);
virtual void ComputeUpdate(Tensor* dX, Tensor* X) = 0;
template <typename T>
void ApplyUpdate(Tensor* dX, Tensor* X);
virtual void ApplyUpdate(Tensor* dX, Tensor* X, Tensor* Y) = 0;
template <typename T>
T GetHyper(const string& key);
Tensor* Slot(const string& key);
Tensor* GetState(const string& key);
protected:
int weight_index_;
......@@ -55,9 +52,26 @@ class UpdateOpBase : public Operator<Context> {
float clip_norm_, clip_value_;
};
#define USE_UPDATE_FUNCTIONS \
using UpdateOpBase<Context>::GetHyper; \
using UpdateOpBase<Context>::Slot
#define USE_UPDATE_FUNCTIONS \
using UpdateOpBase<Context>::GetHyper; \
using UpdateOpBase<Context>::GetState; \
void ApplyUpdate(Tensor* dX, Tensor* X, Tensor* Y) override { \
if (dX->template IsType<float>()) { \
if (Y == nullptr) { \
DoRunWithType<float, float>(dX, X, Y); \
} else if (Y->template IsType<float16>()) { \
DoRunWithType<float, float16>(dX, X, Y); \
} else { \
LOG(FATAL) << MessageForUnsupported( \
dtypes::to_string(Y->meta()), {"float16", "float32"}); \
} \
} else if (dX->template IsType<double>()) { \
DoRunWithType<double, double>(dX, X, Y); \
} else { \
LOG(FATAL) << MessageForUnsupported( \
dtypes::to_string(dX->meta()), {"float32", "float64"}); \
} \
}
template <class Context>
class MomentumSGDOp final : public UpdateOpBase<Context> {
......@@ -73,7 +87,8 @@ class MomentumSGDOp final : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments();
}
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override;
template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected:
float lr_, momentum_;
......@@ -93,7 +108,8 @@ class NesterovSGDOp final : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments();
}
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override;
template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected:
float lr_, momentum_;
......@@ -110,15 +126,16 @@ class RMSpropOp final : public UpdateOpBase<Context> {
void GetArguments() override {
lr_ = this->template GetHyper<float>("lr");
momentum_ = this->template GetHyper<float>("momentum");
decay_ = this->template GetHyper<float>("decay");
alpha_ = this->template GetHyper<float>("alpha");
eps_ = this->template GetHyper<float>("eps");
UpdateOpBase<Context>::GetArguments();
}
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override;
template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected:
float lr_, momentum_, decay_, eps_;
float lr_, momentum_, alpha_, eps_;
};
template <class Context>
......@@ -139,7 +156,8 @@ class AdamOp : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments();
}
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override;
template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected:
int64_t t_;
......@@ -163,16 +181,15 @@ class AdamWOp final : public UpdateOpBase<Context> {
t_++;
correction_ = sqrt(1.f - pow(beta2_, t_)) / (1.f - pow(beta1_, t_));
UpdateOpBase<Context>::GetArguments();
lambda_ = this->weight_decay_;
this->weight_decay_ = 0.f;
}
void ComputeUpdate(Tensor* dX, Tensor* X) override;
template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected:
int64_t t_;
float lr_, beta1_, beta2_;
float eps_, correction_, lambda_;
float eps_, correction_;
};
template <class Context>
......@@ -190,14 +207,13 @@ class LARSOp final : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments();
}
void ComputeUpdate(Tensor* dX, Tensor* X) override;
template <typename T, typename CopyT>
void DoRunWithType(Tensor* dX, Tensor* X, Tensor* Y);
protected:
float lr_, momentum_, trust_coef_;
};
#undef USE_UPDATE_FUNCTIONS
} // namespace dragon
#endif // DRAGON_OPERATORS_TRAINING_UPDATE_OP_H_
......@@ -13,67 +13,40 @@ T UpdateOpBase<Context>::GetHyper(const string& key) {
}
template <class Context>
Tensor* UpdateOpBase<Context>::Slot(const string& key) {
Tensor* UpdateOpBase<Context>::GetState(const string& key) {
const string& weight_name = Output(weight_index_)->name();
return workspace()->CreateTensor(name() + "/" + weight_name + "/" + key);
}
template <class Context>
template <typename T>
void UpdateOpBase<Context>::TransformGrad(Tensor* dX, Tensor* X) {
// Scale.
void UpdateOpBase<Context>::TransformGrad(Tensor* dX) {
if (grad_scale_ != 1.f) {
auto* dx = dX->template mutable_data<T, Context>();
math::Scale(dX->count(), grad_scale_, dx, dx, ctx());
}
// Clip.
if (clip_norm_ > 0.f) {
auto* dx = dX->template mutable_data<T, Context>();
float grad_norm = std::sqrt(math::Dot(dX->count(), dx, dx, ctx()));
if (grad_norm > clip_norm_) {
math::Scale(dX->count(), clip_norm_ / grad_norm, dx, dx, ctx());
float norm = std::sqrt(math::Dot(dX->count(), dx, dx, ctx()));
if (norm > clip_norm_) {
math::Scale(dX->count(), clip_norm_ / norm, dx, dx, ctx());
}
} else if (clip_value_ > 0.f) {
auto* dx = dX->template mutable_data<T, Context>();
kernels::Clip(dX->count(), -clip_value_, clip_value_, dx, dx, ctx());
}
// Penalty.
if (weight_decay_ > 0.f) {
math::Axpy(
X->count(),
weight_decay_,
X->template data<T, Context>(),
dX->template mutable_data<T, Context>(),
ctx());
}
}
template <class Context>
template <typename T>
void UpdateOpBase<Context>::ApplyUpdate(Tensor* dX, Tensor* X) {
math::Sub(
X->count(),
X->template data<T, Context>(),
dX->template data<T, Context>(),
X->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void UpdateOpBase<Context>::RunOnDevice() {
GetArguments();
for (int i = 0; i < InputSize(); ++i) {
weight_index_ = i;
auto &dX = Input(i), *X = Output(i);
if (dX.count() == 0 || X->count() == 0) return;
for (weight_index_ = 0; weight_index_ < InputSize(); ++weight_index_) {
auto &dX = Input(weight_index_), *X = Output(weight_index_);
if (dX.count() == 0 || X->count() == 0) continue;
CHECK(dX.dims() == X->dims())
<< "\nWeight and grad should have the same dimensions."
<< "\nGot" << X->DimString() << " and " << dX.DimString();
if (dX.template IsType<float>()) {
TransformGrad<float>(&dX, X);
ComputeUpdate(&dX, X);
ApplyUpdate<float>(&dX, X);
} else if (dX.template IsType<float16>()) {
if (dX.template IsType<float16>()) {
auto* X_master = workspace()->CreateTensor(X->name() + "_master");
auto* X_grad = ctx()->workspace()->CreateTensor("BufferShared");
if (X_master->count() != X->count()) {
......@@ -88,17 +61,17 @@ void UpdateOpBase<Context>::RunOnDevice() {
dX.template data<float16, Context>(),
X_grad->ReshapeLike(dX)->template mutable_data<float, Context>(),
ctx());
TransformGrad<float>(X_grad, X_master);
ComputeUpdate(X_grad, X_master);
ApplyUpdate<float>(X_grad, X_master);
math::Cast(
X->count(),
X_master->template data<float, Context>(),
X->template mutable_data<float16, Context>(),
ctx());
TransformGrad<float>(X_grad);
ApplyUpdate(X_grad, X_master, X);
} else if (dX.template IsType<float>()) {
TransformGrad<float>(&dX);
ApplyUpdate(&dX, X, nullptr);
} else if (dX.template IsType<double>()) {
TransformGrad<double>(&dX);
ApplyUpdate(&dX, X, nullptr);
} else {
LOG(FATAL) << MessageForUnsupported(
dtypes::to_string(dX.meta()), {"float16", "float32"});
dtypes::to_string(dX.meta()), {"float16", "float32", "float64"});
}
}
}
......
......@@ -58,9 +58,6 @@ from dragon.core.ops import tensor_ops as _
from dragon.core.ops.array_ops import assign
from dragon.core.ops.array_ops import boolean_mask
from dragon.core.ops.array_ops import broadcast_to
from dragon.core.ops.array_ops import channel_affine
from dragon.core.ops.array_ops import channel_normalize
from dragon.core.ops.array_ops import channel_shuffle
from dragon.core.ops.array_ops import concat
from dragon.core.ops.array_ops import expand_dims
from dragon.core.ops.array_ops import flatten
......
......@@ -21,6 +21,7 @@ from dragon.core.device.cuda import current_device
from dragon.core.device.cuda import get_device_capability
from dragon.core.device.cuda import is_available
from dragon.core.device.cuda import memory_allocated
from dragon.core.device.cuda import set_cublas_flags
from dragon.core.device.cuda import set_cudnn_flags
from dragon.core.device.cuda import set_default_device
from dragon.core.device.cuda import set_device
......
......@@ -17,8 +17,10 @@ from dragon.core.ops.activation_ops import sigmoid
from dragon.core.ops.activation_ops import tanh
from dragon.core.ops.math_ops import abs
from dragon.core.ops.math_ops import add
from dragon.core.ops.math_ops import affine
from dragon.core.ops.math_ops import argmax
from dragon.core.ops.math_ops import argmin
from dragon.core.ops.math_ops import atan2
from dragon.core.ops.math_ops import ceil
from dragon.core.ops.math_ops import clip
from dragon.core.ops.math_ops import cos
......@@ -60,7 +62,6 @@ from dragon.core.ops.math_ops import sqrt
from dragon.core.ops.math_ops import square
from dragon.core.ops.math_ops import sub
from dragon.core.ops.math_ops import sum
from dragon.core.ops.normalization_ops import lp_normalize
from dragon.core.ops.sort_ops import top_k
__all__ = [_s for _s in dir() if not _s.startswith('_')]
......@@ -34,12 +34,15 @@ from dragon.core.ops.activation_ops import relu6
from dragon.core.ops.activation_ops import selu
from dragon.core.ops.activation_ops import silu
from dragon.core.ops.activation_ops import softmax
from dragon.core.ops.array_ops import channel_shuffle
from dragon.core.ops.math_ops import moments
from dragon.core.ops.normalization_ops import batch_norm
from dragon.core.ops.normalization_ops import channel_norm
from dragon.core.ops.normalization_ops import group_norm
from dragon.core.ops.normalization_ops import instance_norm
from dragon.core.ops.normalization_ops import layer_norm
from dragon.core.ops.normalization_ops import local_response_norm
from dragon.core.ops.normalization_ops import lp_norm
from dragon.core.ops.normalization_ops import sync_batch_norm
from dragon.core.ops.vision_ops import bias_add
from dragon.core.ops.vision_ops import conv
......
......@@ -78,16 +78,13 @@ def cast_args(**kwargs):
return {'dtype': kwargs.get('dtype', 'float32')}
@register('ChannelAffine')
def channel_affine_args(**kwargs):
return {
'axis': kwargs.get('axis', -1),
'end_axis': kwargs.get('end_axis', kwargs.get('axis', -1)),
}
@register('Affine')
def affine_args(**kwargs):
return {'axes': kwargs.get('axes', None)}
@register('ChannelNormalize')
def channel_normalize_args(**kwargs):
@register('ChannelNorm')
def channel_norm_args(**kwargs):
return {
'axis': kwargs.get('axis', -1),
'mean': kwargs.get('mean', None),
......@@ -323,8 +320,8 @@ def loss_args(**kwargs):
return {'reduction': kwargs.get('reduction', 'MEAN')}
@register('LpNormalize')
def lp_normalize_args(**kwargs):
@register('LpNorm')
def lp_norm_args(**kwargs):
return {
'p': kwargs.get('p', 2),
'axis': kwargs.get('axis', -1),
......
......@@ -81,6 +81,7 @@ def binary_shape_spec(inputs, outputs):
@register([
'Add',
'Atan2',
'BitwiseAnd',
'BitwiseOr',
'BitwiseXor',
......@@ -403,7 +404,7 @@ def gemm_spec(args, inputs, outputs):
return outputs
@register('ChannelNormalize')
@register('ChannelNorm')
def channel_normalize_spec(args, inputs, outputs):
outputs[0]._dtype = args['dtype']
try:
......
......@@ -62,11 +62,23 @@ def current_device():
return backend.cudaGetDevice()
def set_cublas_flags(allow_tf32=None):
"""Set the flags of cuBLAS library.
Parameters
----------
allow_tf32 : bool, optional, default=False
Allow TF32 tensor core operation or not.
"""
backend.cublasSetFlags(-1 if allow_tf32 is None else allow_tf32)
def set_cudnn_flags(
enabled=True,
benchmark=False,
deterministic=False,
allow_tf32=False,
enabled=None,
benchmark=None,
deterministic=None,
allow_tf32=None,
):
"""Set the flags of cuDNN library.
......@@ -82,7 +94,11 @@ def set_cudnn_flags(
Allow TF32 tensor core operation or not.
"""
backend.cudnnSetFlags(enabled, benchmark, deterministic, allow_tf32)
backend.cudnnSetFlags(
-1 if enabled is None else enabled,
-1 if benchmark is None else benchmark,
-1 if deterministic is None else deterministic,
-1 if allow_tf32 is None else allow_tf32)
def get_device_capability(device_index=None):
......
......@@ -122,107 +122,16 @@ def broadcast_to(inputs, shape, **kwargs):
return OpLib.add('Expand', **args)
@OpSchema.num_inputs(2, 3)
def channel_affine(inputs, axis=-1, end_axis=None, **kwargs):
r"""Apply affine transformation to each channel of input.
Parameters
----------
inputs : Sequence[dragon.Tensor]
The input, weight and optional bias tensor.
axis : int, optional, default=-1
The first channel axis.
end_axis : int, optional
The last channel axis.
Returns
-------
dragon.Tensor
The output tensor.
"""
outputs = kwargs.pop('outputs', [None])
if context.executing_eagerly():
return OpLib.execute(
'ChannelAffine', inputs, outputs=outputs,
axis=axis, end_axis=end_axis)
return OpLib.add('ChannelAffine', inputs,
axis=axis, end_axis=end_axis, **kwargs)
@OpSchema.num_inputs(1)
@OpSchema.convert_arg('perm')
def channel_normalize(
inputs,
mean,
std,
axis=-1,
dtype='float32',
perm=None,
**kwargs
):
"""Apply normalization to each channel of input.
:attr:`axis` can be negative:
```python
m = s = (1., 1., 1.)
x = dragon.constant([1, 2, 3])
print(dragon.channel_normalize(x, m, s, axis=0)) # [0., 1., 2.]
print(dragon.channel_normalize(x, m, s, axis=-1)) # Equivalent
```
If :attr:`perm` provided, :attr:`axis` is selected from the output layout:
```python
m, s = (1., 2., 3.), (1., 1., 1.)
x = dragon.constant([[1, 2, 3]])
# Provided 3 values to normalize the last axis
# with length 1, only the first value will be taken
print(dragon.channel_normalize(x, m, s, perm=(1, 0))) # [[0.], [1.], [2.]]
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
mean : Sequence[float], required
The mean to subtract.
std : Sequence[float], required
The standard deviation to divide.
axis : int, optional, default=-1
The channel axis.
dtype : str, optional, default='float32'
The output data type.
perm : Sequence[Union[int, dragon.Tensor]], optional
The output permutation.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = OpSchema.parse_args(locals())
if context.executing_eagerly():
return OpLib.execute(
'ChannelNormalize', inputs,
axis=axis, mean=mean, std=std, dtype=dtype,
ndim=len(args['perm']) if perm is not None else 0,
perm=args['perm'])
return OpLib.add('ChannelNormalize', **args)
@OpSchema.num_inputs(1)
def channel_shuffle(inputs, axis=-1, group=1, **kwargs):
"""Apply group shuffle to each channel of input.
"""Apply the group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
Examples:
```python
x = dragon.constant([1, 2, 3, 4])
print(dragon.channel_shuffle(x, group=2)) # [1, 3, 2, 4]
print(dragon.nn.channel_shuffle(x, group=2)) # [1, 3, 2, 4]
```
Parameters
......
......@@ -82,6 +82,30 @@ def add(inputs, **kwargs):
return OpLib.add('Add', inputs, **kwargs)
@OpSchema.num_inputs(2, 3)
def affine(inputs, axis=-1, **kwargs):
"""Apply affine transformation to input.
Parameters
----------
inputs : Sequence[dragon.Tensor]
The input, scale and bias tensor.
axis : Union[int, Sequence[int]], optional, default=-1
The axis to apply.
Returns
-------
dragon.Tensor
The output tensor.
"""
axes = nest.flatten(axis)
outputs = kwargs.pop('outputs', [None])
if context.executing_eagerly():
return OpLib.execute('Affine', inputs, outputs=outputs, axes=axes)
return OpLib.add('Affine', inputs, axes=axes, **kwargs)
@OpSchema.num_inputs(1)
def argmax(inputs, axis=0, keepdims=False, **kwargs):
"""Compute the index of maximum elements along the given axis.
......@@ -149,6 +173,37 @@ def argmin(inputs, axis=0, keepdims=False, **kwargs):
@OpSchema.num_inputs(2)
def atan2(inputs, **kwargs):
r"""Compute the element-wise arc-tangent of two arguments.
.. math:: \text{out} = \text{arctan}(\frac{\text{input1}}{\text{input2}})
Examples:
```python
y = dragon.constant(1)
x = dragon.constant(2)
print(dragon.math.atan2([y, x])) # 0.46364761
```
Parameters
----------
inputs : Sequence[dragon.Tensor]
The input1 and input2 tensor.
Returns
-------
dragon.Tensor
The output tensor.
"""
inputs = constant_ops.remove_scalars(inputs)
if context.executing_eagerly():
return OpLib.execute('Atan2', inputs)
return OpLib.add('Atan2', inputs, **kwargs)
@OpSchema.num_inputs(2)
def bitwise_and(inputs, **kwargs):
r"""Compute the element-wise AND bitwise operation.
......
......@@ -72,6 +72,69 @@ def batch_norm(
return OpLib.add('BatchNorm', **args)
@OpSchema.num_inputs(1)
@OpSchema.convert_arg('perm')
def channel_norm(
inputs,
mean,
std,
axis=-1,
dtype='float32',
perm=None,
**kwargs
):
"""Apply the normalization to each channel of input.
:attr:`axis` can be negative:
```python
m = s = (1., 1., 1.)
x = dragon.constant([1, 2, 3])
print(dragon.nn.channel_norm(x, m, s, axis=0)) # [0., 1., 2.]
print(dragon.nn.channel_norm(x, m, s, axis=-1)) # Equivalent
```
If :attr:`perm` provided, :attr:`axis` is selected from the output layout:
```python
m, s = (1., 2., 3.), (1., 1., 1.)
x = dragon.constant([[1, 2, 3]])
# Provided 3 values to normalize the last axis
# with length 1, only the first value will be taken
print(dragon.nn.channel_norm(x, m, s, perm=(1, 0))) # [[0.], [1.], [2.]]
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
mean : Sequence[float], required
The mean to subtract.
std : Sequence[float], required
The standard deviation to divide.
axis : int, optional, default=-1
The channel axis.
dtype : str, optional, default='float32'
The output data type.
perm : Sequence[Union[int, dragon.Tensor]], optional
The output permutation.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = OpSchema.parse_args(locals())
if context.executing_eagerly():
return OpLib.execute(
'ChannelNorm', inputs,
axis=axis, mean=mean, std=std, dtype=dtype,
ndim=len(args['perm']) if perm is not None else 0,
perm=args['perm'])
return OpLib.add('ChannelNorm', **args)
@OpSchema.num_inputs(3)
def group_norm(inputs, axis=-1, group=0, epsilon=1e-5, **kwargs):
r"""Apply the group normalization.
......@@ -180,7 +243,7 @@ def layer_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
@OpSchema.num_inputs(1)
def lp_normalize(
def lp_norm(
inputs,
axis=-1,
end_axis=None,
......@@ -200,15 +263,15 @@ def lp_normalize(
```python
x = dragon.constant([[1, 2, 3], [4, 5, 6]], 'float32')
# A negative axis is the last-k axis
print(dragon.math.lp_normalize(x, 1))
print(dragon.math.lp_normalize(x, -1)) # Equivalent
print(dragon.nn.lp_norm(x, 1))
print(dragon.nn.lp_norm(x, -1)) # Equivalent
```
More than one axis could be specified to reduce:
```python
# Along the continuous axes: [axis, end_axis]
print(dragon.math.lp_normalize(x, axis=0, end_axis=1))
print(dragon.nn.lp_norm(x, axis=0, end_axis=1))
```
Parameters
......@@ -236,9 +299,9 @@ def lp_normalize(
reduction = reduction.upper()
if context.executing_eagerly():
return OpLib.execute(
'LpNormalize', inputs, p=p, axis=axis, end_axis=end_axis,
'LpNorm', inputs, p=p, axis=axis, end_axis=end_axis,
epsilon=epsilon, reduction=reduction)
return OpLib.add('LpNormalize', inputs, p=p, axis=axis, end_axis=end_axis,
return OpLib.add('LpNorm', inputs, p=p, axis=axis, end_axis=end_axis,
epsilon=epsilon, reduction=reduction, **kwargs)
......
......@@ -24,9 +24,11 @@ class Adam(optimizer.Optimizer):
The **Adam** update is defined as:
.. math::
\text{Adam}(g) = \frac{\text{lr} * m_{t}}{\sqrt{v_{t}} + \epsilon} \\
\text{Adam}(g) = \text{lr} * (\frac{\text{correction}* m_{t}}
{\sqrt{v_{t}} + \epsilon}) \\
\quad \\ \text{where}\quad
\begin{cases}
\text{correction} = \sqrt{1 - \beta_{2}^{t}} / (1 - \beta_{1}^{t}) \\
m_{t} = \beta_{1} * m_{t-1} + (1 - \beta_{1}) * g \\
v_{t} = \beta_{2} * v_{t-1} + (1 - \beta_{2}) * g^{2}
\end{cases}
......@@ -62,12 +64,13 @@ class AdamW(Adam):
The **AdamW** update is defined as:
.. math::
\text{AdamW}(g, p) = \text{lr} * (\frac{m_{t}}{\sqrt{v_{t}} + \epsilon}
+ \lambda p) \\
\text{AdamW}(g, p) = \text{lr} * (\frac{\text{correction} * m_{t}}
{\sqrt{v_{t}} + \epsilon} + \lambda p) \\
\quad \\ \text{where}\quad
\begin{cases}
\text{correction} = \sqrt{1 - \beta_{2}^{t}} / (1 - \beta_{1}^{t}) \\
m_{t} = \beta_{1} * m_{t-1} + (1 - \beta_{1}) * g \\
v_{t} = \beta_{2} * v_{t-1} + (1 - \beta_{2}) * g^{2}
v_{t} = \beta_{2} * v_{t-1} + (1 - \beta_{2}) * g^{2} \\
\end{cases}
"""
......
......@@ -27,13 +27,13 @@ class RMSprop(optimizer.Optimizer):
\text{RMSprop}(g) = \text{lr} * m_{t} \\
\quad \\ \text{where} \quad
\begin{cases}
v_{t} = \text{decay} * v_{t-1} + (1 - \text{decay}) * g^{2} \\
v_{t} = \alpha * v_{t-1} + (1 - \alpha) * g^{2} \\
m_{t} = \text{momentum} * m_{t-1} + \frac{g}{\sqrt{v_{t}} + \epsilon}
\end{cases}
"""
def __init__(self, lr=0.01, momentum=0, decay=0.9, eps=1e-8, **kwargs):
def __init__(self, lr=0.01, momentum=0, alpha=0.9, eps=1e-8, **kwargs):
r"""Create a ``RMSProp`` optimizer.
Parameters
......@@ -42,8 +42,8 @@ class RMSprop(optimizer.Optimizer):
The initial value to :math:`\text{lr}`.
momentum : float, optional, default=0
The initial value to :math:`\text{momentum}`.
decay : float, optional, default=0.9
The initial value to :math:`\text{decay}`.
alpha : float, optional, default=0.9
The initial value to :math:`\alpha`.
eps : float, optional, default=1e-8
The initial value to :math:`\epsilon`.
......@@ -51,5 +51,5 @@ class RMSprop(optimizer.Optimizer):
super(RMSprop, self).__init__(**kwargs)
self._set_hyper('lr', lr)
self._set_hyper('momentum', momentum)
self._set_hyper('decay', decay)
self._set_hyper('alpha', alpha)
self._set_hyper('eps', eps)
......@@ -51,41 +51,6 @@ def cast_exporter(op_def, context):
return node, const_tensors
@export_util.register('ChannelAffine')
def channel_affine_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'ChannelAffine')
for arg in op_def.arg:
if arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i)
elif arg.name == 'end_axis':
helper.add_attribute(node, 'end_axis', arg.i)
return node, const_tensors
@export_util.register('ChannelNormalize')
def channel_normalize_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'ChannelNormalize')
for arg in op_def.arg:
if arg.name == 'mean':
helper.add_attribute(node, 'mean', arg.floats)
elif arg.name == 'std':
helper.add_attribute(node, 'std', arg.floats)
elif arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i)
elif arg.name == 'dtype':
helper.add_attribute(node, 'dtype', arg.s)
elif arg.name == 'perm':
helper.add_attribute(node, 'perm', arg.ints)
elif arg.name == 'perm_desc':
values = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'perm', values)
return node, const_tensors
@export_util.register('Concat')
def concat_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
......
......@@ -31,6 +31,17 @@ def add_exporter(op_def, context):
return node, const_tensors
@export_util.register('Affine')
def affine_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'Affine')
for arg in op_def.arg:
if arg.name == 'axes':
helper.add_attribute(node, 'axes', arg.ints)
return node, const_tensors
@export_util.register('Div')
def div_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
......
......@@ -32,6 +32,28 @@ def batch_norm_exporter(op_def, context):
return node, const_tensors
@export_util.register('ChannelNorm')
def channel_norm_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'ChannelNorm')
for arg in op_def.arg:
if arg.name == 'mean':
helper.add_attribute(node, 'mean', arg.floats)
elif arg.name == 'std':
helper.add_attribute(node, 'std', arg.floats)
elif arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i)
elif arg.name == 'dtype':
helper.add_attribute(node, 'dtype', arg.s)
elif arg.name == 'perm':
helper.add_attribute(node, 'perm', arg.ints)
elif arg.name == 'perm_desc':
values = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'perm', values)
return node, const_tensors
@export_util.register('GroupNorm')
def group_norm_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
......@@ -49,8 +71,8 @@ def group_norm_exporter(op_def, context):
return node, const_tensors
@export_util.register('LpNormalize')
def lp_normalize_exporter(op_def, context):
@export_util.register('LpNorm')
def lp_norm_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'LpNormalization'
axis, end_axis = None, None
......
......@@ -33,9 +33,6 @@ constexpr int CUDA_WARP_SIZE = 32;
/*! \brief The number of cuda threads in a block */
constexpr int CUDA_THREADS = 256;
/*! \brief The maximum number of blocks to use in a default kernel call */
constexpr int CUDA_MAX_BLOCKS = 4096;
/*! \brief The maximum number of devices in a single machine */
constexpr int CUDA_MAX_DEVICES = 16;
......@@ -82,12 +79,15 @@ constexpr int CUDA_TENSOR_MAX_DIMS = 8;
for (size_t j = threadIdx.x; j < m; j += blockDim.x)
inline int CUDA_BLOCKS(const int N) {
return std::max(
std::min((N + CUDA_THREADS - 1) / CUDA_THREADS, CUDA_MAX_BLOCKS), 1);
}
inline int CUDA_2D_BLOCKS(const int N) {
return std::max(std::min(N, CUDA_MAX_BLOCKS), 1);
int device, sm_count, threads_per_sm;
CUDA_CHECK(cudaGetDevice(&device));
CUDA_CHECK(cudaDeviceGetAttribute(
&sm_count, cudaDevAttrMultiProcessorCount, device));
CUDA_CHECK(cudaDeviceGetAttribute(
&threads_per_sm, cudaDevAttrMaxThreadsPerMultiProcessor, device));
const auto num_blocks = (N + CUDA_THREADS - 1) / CUDA_THREADS;
const auto max_blocks = sm_count * threads_per_sm / CUDA_THREADS * 32;
return std::max(1, std::min(num_blocks, max_blocks));
}
#if CUDA_VERSION_MAX(9, 0)
......
......@@ -84,6 +84,7 @@ DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Sub, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Mul, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Div, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Pow, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Atan2, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Minimum, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Maximum, T);
DECLARE_ROWWISE_COLWISE_BINARY_FUNC(Equal, bool);
......@@ -434,6 +435,7 @@ DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(And, bool, std::logical_and);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Or, bool, std::logical_or);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Xor, bool, math::XorFunctor);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Pow, T, math::PowFunctor);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Atan2, T, math::Atan2Functor);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Minimum, T, math::MinFunctor);
DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Maximum, T, math::MaxFunctor);
#undef DEFINE_ROWWISE_COLWISE_BIANRY_FUNC
......@@ -469,6 +471,7 @@ DEFINE_BROADCAST_BINARY_FUNC(And, bool, std::logical_and);
DEFINE_BROADCAST_BINARY_FUNC(Or, bool, std::logical_or);
DEFINE_BROADCAST_BINARY_FUNC(Xor, bool, math::XorFunctor);
DEFINE_BROADCAST_BINARY_FUNC(Pow, T, math::PowFunctor);
DEFINE_BROADCAST_BINARY_FUNC(Atan2, T, math::Atan2Functor);
DEFINE_BROADCAST_BINARY_FUNC(Minimum, T, math::MinFunctor);
DEFINE_BROADCAST_BINARY_FUNC(Maximum, T, math::MaxFunctor);
#undef DEFINE_BROADCAST_BINARY_FUNC
......@@ -612,6 +615,9 @@ DEFINE_BINARY_FUNC(Div, float, float);
DEFINE_BINARY_FUNC(Div, double, double);
DEFINE_BINARY_FUNC(Pow, float, float);
DEFINE_BINARY_FUNC(Pow, double, double);
DEFINE_BINARY_FUNC(Atan2, float16, float16);
DEFINE_BINARY_FUNC(Atan2, float, float);
DEFINE_BINARY_FUNC(Atan2, double, double);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t);
DEFINE_BINARY_FUNC(Minimum, int, int);
......
......@@ -388,6 +388,9 @@ DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, float16, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Atan2, float16, float16, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, float, float, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, double, double, math::Atan2Functor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor);
......
......@@ -122,6 +122,17 @@ DRAGON_API void Pow(
Context* ctx);
template <typename T, class Context>
DRAGON_API void Atan2(
const int A_ndim,
const int64_t* A_dims,
const int B_ndim,
const int64_t* B_dims,
const T* a,
const T* b,
T* y,
Context* ctx);
template <typename T, class Context>
DRAGON_API void Minimum(
const int A_ndim,
const int64_t* A_dims,
......
......@@ -550,6 +550,9 @@ DEFINE_BINARY_FUNC(Maximum, double, double, max);
_SimpleBinaryFunc(N, Functor<InputT>(), a, b, y); \
}
DEFINE_BINARY_FUNC(Atan2, float16, float16, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, float, float, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, double, double, math::Atan2Functor);
DEFINE_BINARY_FUNC(BitwiseAnd, bool, bool, std::bit_and);
DEFINE_BINARY_FUNC(BitwiseAnd, uint8_t, uint8_t, std::bit_and);
DEFINE_BINARY_FUNC(BitwiseAnd, int8_t, int8_t, std::bit_and);
......
......@@ -342,7 +342,10 @@ _Where(const int N, const T* a, const T* b, const bool* c, T* y) {
DRAGON_API void name<InputT, CUDAContext>( \
const int N, const InputT* x, OutputT* y, CUDAContext* ctx) { \
_SimpleUnaryFunc<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, Functor<InputT>(), x, y); \
N, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(x), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
}
DEFINE_UNARY_FUNC(BitwiseNot, bool, bool, math::BitNotFunctor);
......@@ -706,6 +709,87 @@ DEFINE_APPLY_MASK_FUNC(float, float);
DEFINE_APPLY_MASK_FUNC(double, double);
#undef DEFINE_APPLY_MASK_FUNC
#define DEFINE_BINARY_FUNC(name, T, Functor) \
template <> \
DRAGON_API void name<T, CUDAContext>( \
const int N, const T* a, const T* b, T* y, CUDAContext* ctx) { \
using ScalarT = typename math::ScalarType<T>::type; \
using ScalarT2 = typename math::ScalarType<T>::type2; \
if ((N & 1) == 0 && sizeof(ScalarT) != sizeof(ScalarT2)) { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(N >> 1), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
N >> 1, \
Functor<ScalarT2>(), \
reinterpret_cast<const ScalarT2*>(a), \
reinterpret_cast<const ScalarT2*>(b), \
reinterpret_cast<ScalarT2*>(y)); \
} else { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(N), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
N, \
Functor<ScalarT>(), \
reinterpret_cast<const ScalarT*>(a), \
reinterpret_cast<const ScalarT*>(b), \
reinterpret_cast<ScalarT*>(y)); \
} \
}
DEFINE_BINARY_FUNC(Add, uint8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int64_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float16, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, double, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, uint8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int64_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, double, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, uint8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int64_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float16, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, double, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, uint8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int64_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float16, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Atan2, float16, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, float, math::Atan2Functor);
DEFINE_BINARY_FUNC(Atan2, double, math::Atan2Functor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int64_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, double, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, uint8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int64_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, double, math::MaxFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, InputT, OutputT, Functor) \
template <> \
DRAGON_API void name<InputT, CUDAContext>( \
......@@ -726,51 +810,6 @@ DEFINE_APPLY_MASK_FUNC(double, double);
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
}
DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int8_t, int8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int, int, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int64_t, int64_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float16, float16, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float, float, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, double, double, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int, int, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, float16, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float, float, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, double, double, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int, int, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float16, float16, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float, float, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, double, double, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int8_t, int8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int, int, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int64_t, int64_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float16, float16, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float, float, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, float16, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, float16, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float, float, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, double, double, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int, int, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, float16, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float, float, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, double, double, math::MaxFunctor);
DEFINE_BINARY_FUNC(BitwiseAnd, bool, bool, math::BitAndFunctor);
DEFINE_BINARY_FUNC(BitwiseAnd, uint8_t, uint8_t, math::BitAndFunctor);
DEFINE_BINARY_FUNC(BitwiseAnd, int8_t, int8_t, math::BitAndFunctor);
......
......@@ -126,6 +126,9 @@ template <typename T, class Context>
DRAGON_API void Pow(const int N, const T* a, const T* b, T* y, Context* ctx);
template <typename T, class Context>
DRAGON_API void Atan2(const int N, const T* a, const T* b, T* y, Context* ctx);
template <typename T, class Context>
DRAGON_API void
Minimum(const int N, const T* a, const T* b, T* y, Context* ctx);
......
......@@ -108,13 +108,11 @@ void _GenericReduce(
} \
if (math::utils::IsRowwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_RowwiseReduce##name(rows, cols, scale, x, y); \
return; \
return _RowwiseReduce##name(rows, cols, scale, x, y); \
} \
if (math::utils::IsColwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_ColwiseReduce##name(rows, cols, scale, x, y); \
return; \
return _ColwiseReduce##name(rows, cols, scale, x, y); \
} \
vec64_t transpose_axes(num_dims); \
vec64_t transpose_strides(num_dims); \
......
#include "dragon/utils/math/transform.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/math/reduce.h"
namespace dragon {
namespace math {
namespace {
template <typename T>
void _AffineChannel(
const int N,
const int C,
const T* x,
const T* scale,
const T* bias,
T* y) {
EigenArrayMap<T> Y(y, C, N);
ConstEigenArrayMap<T> X(x, C, N);
Y = X.colwise() * ConstEigenVectorArrayMap<T>(scale, C);
if (bias != nullptr) {
Y.colwise() += ConstEigenVectorArrayMap<T>(bias, C);
}
}
template <typename T>
void _AffineChannel(
const int N,
const int C,
const int S,
const T* x,
const T* scale,
const T* bias,
T* y) {
const auto CxS = C * S;
for (int i = 0; i < N; ++i) {
EigenArrayMap<T> Y(y + i * CxS, S, C);
ConstEigenArrayMap<T> X(x + i * CxS, S, C);
Y = X.rowwise() * ConstEigenVectorArrayMap<T>(scale, C).transpose();
if (bias != nullptr) {
Y.rowwise() += ConstEigenVectorArrayMap<T>(bias, C).transpose();
}
}
}
template <typename T>
void _AffineImpl(
const int num_dims,
const int64_t* dims,
const int num_axes,
const int64_t* axes,
const T* x,
const T* scale,
const T* bias,
T* y) {
if (num_dims == 2 && num_axes == 1 && axes[0] == 1) {
_AffineChannel(dims[0], dims[1], x, scale, bias, y);
} else if (num_dims == 3 && num_axes == 1 && axes[0] == 1) {
_AffineChannel(dims[0], dims[1], dims[2], x, scale, bias, y);
} else {
LOG(FATAL) << "Unsupported affine dimensions.";
}
}
} // namespace
template <>
void Affine<float16, CPUContext>(
const int num_dims,
const int64_t* dims,
const int num_axes,
const int64_t* axes,
const float16* x,
const float16* scale,
const float16* bias,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_AFFINE_FUNC(T) \
template <> \
void Affine<T, CPUContext>( \
const int num_dims, \
const int64_t* dims, \
const int num_axes, \
const int64_t* axes, \
const T* x, \
const T* scale, \
const T* bias, \
T* y, \
CPUContext* ctx) { \
vec64_t new_dims, new_axes; \
math::utils::CollapseReduceAxes( \
num_dims, dims, num_axes, axes, new_dims, new_axes); \
_AffineImpl( \
new_dims.size(), \
new_dims.data(), \
new_axes.size(), \
new_axes.data(), \
x, \
scale, \
bias, \
y); \
}
DEFINE_AFFINE_FUNC(float);
DEFINE_AFFINE_FUNC(double);
#undef DEFINE_AFFINE_FUNC
} // namespace math
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math/functional.h"
#include "dragon/utils/math/reduce.h"
#include "dragon/utils/math/transform.h"
#include "dragon/utils/math/types.h"
#include "dragon/utils/math/utils.h"
namespace dragon {
namespace math {
namespace {
template <typename T>
__global__ void _AffineChannel(
const int NxC,
const int C,
const T* x,
const T* scale,
const T* bias,
T* y) {
auto op3 = math::FMAFunctor<T>();
auto op2 = math::MultipliesFunctor<T>();
CUDA_1D_KERNEL_LOOP(i, NxC) {
if (bias != nullptr) {
y[i] = op3(x[i], __ldg(scale + i % C), __ldg(bias + i % C));
} else {
y[i] = op2(x[i], __ldg(scale + i % C));
}
}
}
template <typename T>
__global__ void _AffineChannel(
const int NxCxS,
const int C,
const int S,
const T* x,
const T* scale,
const T* bias,
T* y) {
auto op3 = math::FMAFunctor<T>();
auto op2 = math::MultipliesFunctor<T>();
CUDA_1D_KERNEL_LOOP(i, NxCxS) {
const int j = (i / S) % C;
if (bias != nullptr) {
y[i] = op3(x[i], __ldg(scale + j), __ldg(bias + j));
} else {
y[i] = op2(x[i], __ldg(scale + j));
}
}
}
template <typename T>
void _AffineImpl(
const int num_dims,
const int64_t* dims,
const int num_axes,
const int64_t* axes,
const T* x,
const T* scale,
const T* bias,
T* y,
CUDAContext* ctx) {
const auto N = math::utils::Prod(num_dims, dims);
if (num_dims == 2 && num_axes == 1 && axes[0] == 1) {
_AffineChannel<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, dims[1], x, scale, bias, y);
} else if (num_dims == 3 && num_axes == 1 && axes[0] == 1) {
_AffineChannel<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, dims[1], dims[2], x, scale, bias, y);
} else {
LOG(FATAL) << "Unsupported affine dimensions.";
}
}
} // namespace
#define DEFINE_AFFINE_FUNC(T) \
template <> \
void Affine<T, CUDAContext>( \
const int num_dims, \
const int64_t* dims, \
const int num_axes, \
const int64_t* axes, \
const T* x, \
const T* scale, \
const T* bias, \
T* y, \
CUDAContext* ctx) { \
vec64_t new_dims, new_axes; \
math::utils::CollapseReduceAxes( \
num_dims, dims, num_axes, axes, new_dims, new_axes); \
_AffineImpl( \
new_dims.size(), \
new_dims.data(), \
new_axes.size(), \
new_axes.data(), \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<const math::ScalarType<T>::type*>(scale), \
reinterpret_cast<const math::ScalarType<T>::type*>(bias), \
reinterpret_cast<math::ScalarType<T>::type*>(y), \
ctx); \
}
DEFINE_AFFINE_FUNC(float);
DEFINE_AFFINE_FUNC(float16);
DEFINE_AFFINE_FUNC(double);
#undef DEFINE_AFFINE_FUNC
} // namespace math
} // namespace dragon
#endif // USE_CUDA
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_MATH_TRANSFORM_H_
#define DRAGON_UTILS_MATH_TRANSFORM_H_
#include "dragon/core/context.h"
namespace dragon {
namespace math {
template <typename T, class Context>
DRAGON_API void Affine(
const int num_dims,
const int64_t* dims,
const int num_axes,
const int64_t* axes,
const T* x,
const T* scale,
const T* bias,
T* y,
Context* ctx);
} // namespace math
} // namespace dragon
#endif // DRAGON_UTILS_MATH_TRANSFORM_H_
......@@ -141,8 +141,7 @@ void _TransposeImpl(
CUDAContext* ctx) {
auto aligned_size = sizeof(T);
if (axes.back() == D - 1) {
const auto N = math::utils::Prod(D, dims.data());
aligned_size = utils::GetAlignedSize<T, 16>(N, x, y);
aligned_size = utils::GetAlignedSize<T, 16>(dims[D - 1], x, y);
}
SimpleArray<int, D> X_dims, X_strides, Y_dims;
for (int i = 0; i < D; ++i) {
......
......@@ -27,6 +27,7 @@ template <typename T>
class ScalarType {
public:
typedef T type;
typedef T type2;
};
#if defined(__CUDACC__)
......@@ -34,6 +35,7 @@ template <>
class ScalarType<float16> {
public:
typedef half type;
typedef half2 type2;
};
#endif
......
......@@ -16,9 +16,9 @@
#include "dragon/utils/conversions.h"
#if defined(__CUDACC__)
#define MATH_UTILS_DECL inline __host__ __device__
#define HOSTDEVICE_DECL inline __host__ __device__
#else
#define MATH_UTILS_DECL inline
#define HOSTDEVICE_DECL inline
#endif
#define FIXED_DIVISOR_DIV_MOD(d, n, q, r) \
......@@ -41,28 +41,28 @@ namespace utils {
template <
typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
MATH_UTILS_DECL T IsInf(const T x) {
HOSTDEVICE_DECL T IsInf(const T x) {
return false;
}
template <
typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
MATH_UTILS_DECL T IsNaN(const T x) {
HOSTDEVICE_DECL T IsNaN(const T x) {
return false;
}
template <
typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
MATH_UTILS_DECL T IsFinite(const T x) {
HOSTDEVICE_DECL T IsFinite(const T x) {
return true;
}
template <
typename T,
typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
MATH_UTILS_DECL bool IsInf(T x) {
HOSTDEVICE_DECL bool IsInf(T x) {
#if defined(__CUDACC__)
return isinf(x);
#else
......@@ -73,7 +73,7 @@ MATH_UTILS_DECL bool IsInf(T x) {
template <
typename T,
typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
MATH_UTILS_DECL bool IsNaN(T x) {
HOSTDEVICE_DECL bool IsNaN(T x) {
#if defined(__CUDACC__)
return isnan(x);
#else
......@@ -84,7 +84,7 @@ MATH_UTILS_DECL bool IsNaN(T x) {
template <
typename T,
typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
MATH_UTILS_DECL bool IsFinite(T x) {
HOSTDEVICE_DECL bool IsFinite(T x) {
#if defined(__CUDACC__)
return isfinite(x);
#else
......@@ -106,27 +106,27 @@ inline bool IsFinite(float16 x) {
}
template <typename T>
MATH_UTILS_DECL bool IsAGeZeroAndALtB(const T a, const T b) {
HOSTDEVICE_DECL bool IsAGeZeroAndALtB(const T a, const T b) {
return static_cast<unsigned int>(a) < static_cast<unsigned int>(b);
}
template <typename T>
MATH_UTILS_DECL T Sign(const T x) {
HOSTDEVICE_DECL T Sign(const T x) {
return x > T(0) ? T(1) : (x < T(0) ? T(-1) : T(0));
}
template <typename T>
MATH_UTILS_DECL T Identity(const T x) {
HOSTDEVICE_DECL T Identity(const T x) {
return x;
}
template <typename T>
MATH_UTILS_DECL T Square(const T x) {
HOSTDEVICE_DECL T Square(const T x) {
return x * x;
}
template <typename T>
MATH_UTILS_DECL T Cube(const T x) {
HOSTDEVICE_DECL T Cube(const T x) {
return x * x * x;
}
......@@ -247,4 +247,6 @@ void IncreaseIndexInDims(const int num_dims, const DimT* dims, IndexT* index) {
} // namespace dragon
#undef HOSTDEVICE_DECL
#endif // DRAGON_UTILS_MATH_UTILS_H_
......@@ -21,6 +21,7 @@
#include "dragon/utils/math/random.h"
#include "dragon/utils/math/reduce.h"
#include "dragon/utils/math/sort.h"
#include "dragon/utils/math/transform.h"
#include "dragon/utils/math/transpose.h"
#include "dragon/utils/math/types.h"
#include "dragon/utils/math/utils.h"
......
......@@ -284,39 +284,6 @@ void BooleanMaskGrad(
Context* ctx);
template <typename T, class Context>
void ChannelAffine(
const int N,
const int S,
const int C,
const T* x,
const T* scale,
const T* bias,
T* y,
Context* ctx);
template <typename InputT, typename OutputT, class Context>
void ChannelNormalize(
const int axis,
const int num_dims,
const int64_t* x_strides,
const int64_t* y_dims,
const InputT* x,
const float* mean,
const float* std,
OutputT* y,
Context* ctx);
template <typename T, class Context>
void ChannelShuffle(
const int N,
const int S,
const int C,
const int G,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
void ConstPad(
const int num_dims,
const int64_t* x_dims,
......@@ -813,6 +780,18 @@ void TopK(
* NormalizationOp Kernels
*/
template <typename InputT, typename OutputT, class Context>
void ChannelNorm(
const int axis,
const int num_dims,
const int64_t* x_strides,
const int64_t* y_dims,
const InputT* x,
const float* mean,
const float* std,
OutputT* y,
Context* ctx);
template <typename T, typename AccT, class Context>
void BatchNormExpectation(
const int N,
......@@ -923,7 +902,7 @@ void GroupNormGrad(
Context* ctx);
template <typename T, class Context>
void L1Normalize(
void L1Norm(
const int N,
const int S,
const int C,
......@@ -934,7 +913,7 @@ void L1Normalize(
Context* ctx);
template <typename T, class Context>
void L1NormalizeGrad(
void L1NormGrad(
const int N,
const int S,
const int C,
......@@ -946,7 +925,7 @@ void L1NormalizeGrad(
Context* ctx);
template <typename T, class Context>
void L2Normalize(
void L2Norm(
const int N,
const int S,
const int C,
......@@ -957,7 +936,7 @@ void L2Normalize(
Context* ctx);
template <typename T, class Context>
void L2NormalizeGrad(
void L2NormGrad(
const int N,
const int S,
const int C,
......@@ -1012,19 +991,23 @@ void LSTMCellGrad(
* TrainingOp Kernels
*/
template <typename T, class Context>
template <typename T, typename CopyT, class Context>
void Adam(
const int N,
const float lr,
const float beta1,
const float beta2,
const float eps,
T* g,
const float wd,
const T* x,
const T* g,
T* m,
T* v,
T* y,
CopyT* y_copy,
Context* ctx);
template <typename T, class Context>
template <typename T, typename CopyT, class Context>
void AdamW(
const int N,
const float lr,
......@@ -1033,39 +1016,53 @@ void AdamW(
const float eps,
const float wd,
const T* x,
T* g,
const T* g,
T* m,
T* v,
T* y,
CopyT* y_copy,
Context* ctx);
template <typename T, class Context>
template <typename T, typename CopyT, class Context>
void MomentumSGD(
const int N,
const float lr,
const float momentum,
T* g,
const float wd,
const T* x,
const T* g,
T* m,
T* y,
CopyT* y_copy,
Context* ctx);
template <typename T, class Context>
template <typename T, typename CopyT, class Context>
void NesterovSGD(
const int N,
const float lr,
const float momentum,
T* g,
const float wd,
const T* x,
const T* g,
T* m,
T* y,
CopyT* y_copy,
Context* ctx);
template <typename T, class Context>
template <typename T, typename CopyT, class Context>
void RMSprop(
const int N,
const float lr,
const float momentum,
const float decay,
const float alpha,
const float eps,
T* g,
const float wd,
const T* x,
const T* g,
T* m,
T* v,
T* y,
CopyT* y_copy,
Context* ctx);
/*
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!