Commit bdf4e10f by Ting PAN

Add GELU operator

Summary:
This commit adds GELU activation to compute output
via approximate or naive mode.
1 parent 43a82e77
Showing with 1860 additions and 329 deletions
......@@ -136,6 +136,9 @@ dragon
`reshape(...) <dragon/reshape.html>`_
: Change the dimensions of input.
`roll(...) <dragon/roll.html>`_
: Roll elements along the given axis.
`scatter_add(...) <dragon/scatter_add.html>`_
: Add elements along the given axis of index.
......@@ -234,6 +237,7 @@ dragon
dragon/repeat
dragon/reset_workspace
dragon/reshape
dragon/roll
dragon/scatter_add
dragon/scatter_elements
dragon/set_num_threads
......
......@@ -74,6 +74,10 @@ dragon.nn
: Apply the exponential linear unit.
`[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
`gelu(...) <nn/gelu.html>`_
: Apply the gaussian error linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
`group_norm(...) <nn/group_norm.html>`_
: Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
......@@ -131,15 +135,15 @@ dragon.nn
: Apply the scaled exponential linear unit.
`[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`silu(...) <nn/silu.html>`_
: Apply the sigmoid linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
`softmax(...) <nn/softmax.html>`_
: Compute the softmax result.
`space_to_depth(...) <nn/space_to_depth.html>`_
: Rearrange blocks of spatial data into depth.
`swish(...) <nn/swish.html>`_
: Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`sync_batch_norm(...) <nn/sync_batch_norm.html>`_
: Apply the batch normalization with synced statistics.
......@@ -167,6 +171,7 @@ dragon.nn
nn/drop_block
nn/drop_path
nn/elu
nn/gelu
nn/group_norm
nn/hardsigmoid
nn/hardswish
......@@ -184,9 +189,9 @@ dragon.nn
nn/relu
nn/relu6
nn/selu
nn/silu
nn/softmax
nn/space_to_depth
nn/swish
nn/sync_batch_norm
.. raw:: html
......
swish
=====
gelu
====
.. autofunction:: dragon.nn.swish
.. autofunction:: dragon.nn.gelu
.. raw:: html
......
silu
====
.. autofunction:: dragon.nn.silu
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
......@@ -10,6 +10,10 @@ dragon.optimizers
: The optimizer to apply Adam algorithm.
`[Kingma & Ba, 2014] <https://arxiv.org/abs/1412.6980>`_.
`class AdamW <optimizers/AdamW.html>`_
: The optimizer to apply AdamW algorithm.
`[Loshchilov & Hutter, 2017] <https://arxiv.org/abs/1711.05101>`_.
`class Nesterov <optimizers/Nesterov.html>`_
: The optimizer to apply NesterovSGD algorithm.
`[Sutskever et.al, 2013] <http://www.cs.toronto.edu/~hinton/absps/momentum.pdf>`_.
......@@ -23,19 +27,20 @@ dragon.optimizers
`[Polyak, 1964] <https://doi.org/10.1016/0041-5553(64)90137-5>`_.
.. toctree::
:hidden:
:hidden:
optimizers/Adam
optimizers/Nesterov
optimizers/Optimizer
optimizers/RMSprop
optimizers/SGD
optimizers/Adam
optimizers/AdamW
optimizers/Nesterov
optimizers/Optimizer
optimizers/RMSprop
optimizers/SGD
.. raw:: html
<style>
h1:before {
h1:before {
content: "Module: ";
color: #103d3e;
}
}
</style>
AdamW
=====
.. autoclass:: dragon.optimizers.AdamW
__init__
--------
.. automethod:: dragon.optimizers.AdamW.__init__
Methods
-------
apply_gradients
################
.. automethod:: dragon.optimizers.Optimizer.apply_gradients
:noindex:
.. raw:: html
<style>
h1:before {
content: "dragon.optimizers.";
color: #103d3e;
}
</style>
roll
====
.. autofunction:: dragon.roll
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
......@@ -87,6 +87,9 @@ vm.tensorflow
`reshape(...) <tensorflow/reshape.html>`_
: Change the dimensions of input.
`roll(...) <tensorflow/roll.html>`_
: Roll elements along the given axis.
`shape(...) <tensorflow/shape.html>`_
: Return the shape of input.
......@@ -149,6 +152,7 @@ vm.tensorflow
tensorflow/pad
tensorflow/range
tensorflow/reshape
tensorflow/roll
tensorflow/shape
tensorflow/slice
tensorflow/sort
......
......@@ -64,6 +64,10 @@ vm.tensorflow.nn
: Apply the exponential exponential linear unit to input.
`[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
`gelu(...) <nn/gelu.html>`_
: Apply the gaussian error linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
`leaky_relu(...) <nn/leaky_relu.html>`_
: Apply the leaky rectified linear unit.
......@@ -101,6 +105,10 @@ vm.tensorflow.nn
: Apply the scaled exponential linear unit.
`[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`silu(...) <nn/silu.html>`_
: Apply the sigmoid linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
`softmax(...) <nn/softmax.html>`_
: Apply the softmax function.
......@@ -138,6 +146,7 @@ vm.tensorflow.nn
nn/depth_to_space
nn/dropout
nn/elu
nn/gelu
nn/leaky_relu
nn/local_response_normalization
nn/log_softmax
......@@ -149,6 +158,7 @@ vm.tensorflow.nn
nn/relu
nn/relu6
nn/selu
nn/silu
nn/softmax
nn/softmax_cross_entropy_with_logits
nn/space_to_depth
......
gelu
====
.. autofunction:: dragon.vm.tensorflow.nn.gelu
.. raw:: html
<style>
h1:before {
content: "tf.nn.";
color: #103d3e;
}
</style>
silu
====
.. autofunction:: dragon.vm.tensorflow.nn.silu
.. raw:: html
<style>
h1:before {
content: "tf.nn.";
color: #103d3e;
}
</style>
roll
====
.. autofunction:: dragon.vm.tensorflow.roll
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
......@@ -81,10 +81,6 @@ vm.torch
`channel_normalize(...) <torch/channel_normalize.html>`_
: Apply normalization to each channel of input.
`channel_shuffle(...) <torch/channel_shuffle.html>`_
: Apply group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`chunk(...) <torch/chunk.html>`_
: Split input into a specific number of chunks.
......@@ -244,6 +240,9 @@ vm.torch
`reshape(...) <torch/reshape.html>`_
: Change the shape of input.
`roll(...) <torch/roll.html>`_
: Roll elements along the given dimension.
`round(...) <torch/round.html>`_
: Compute the nearest integer of input.
......@@ -338,7 +337,6 @@ vm.torch
torch/ceil
torch/channel_affine
torch/channel_normalize
torch/channel_shuffle
torch/chunk
torch/clamp
torch/cos
......@@ -396,6 +394,7 @@ vm.torch
torch/randperm
torch/reciprocal
torch/reshape
torch/roll
torch/round
torch/rsqrt
torch/scatter
......
......@@ -473,6 +473,10 @@ retain_grad
###########
.. automethod:: dragon.vm.torch.Tensor.retain_grad
roll
####
.. automethod:: dragon.vm.torch.Tensor.roll
round
#####
.. automethod:: dragon.vm.torch.Tensor.round
......@@ -675,6 +679,7 @@ zero\_
.. _torch.pow(...): pow.html
.. _torch.reciprocal(...): reciprocal.html
.. _torch.reshape(...): reshape.html
.. _torch.roll(...): roll.html
.. _torch.round(...): round.html
.. _torch.rsqrt(...): rsqrt.html
.. _torch.scatter(...): scatter.html
......
......@@ -51,6 +51,10 @@ vm.torch.nn
`class BCEWithLogitsLoss <nn/BCEWithLogitsLoss.html>`_
: Compute the sigmoid cross entropy with contiguous targets.
`class ChannelShuffle <nn/ChannelShuffle.html>`_
: Apply group shuffle to each channel.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`class ConstantPad1d <nn/ConstantPad1d.html>`_
: Pad input according to the last dimension with a constant.
......@@ -108,6 +112,10 @@ vm.torch.nn
`class Flatten <nn/Flatten.html>`_
: Flatten the dimensions of input.
`class GELU <nn/GELU.html>`_
: Apply the gaussian error linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
`class GroupNorm <nn/GroupNorm.html>`_
: Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
......@@ -237,6 +245,10 @@ vm.torch.nn
: Compute the sigmoid focal loss with sparse labels.
`[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`__.
`class SiLU <nn/SiLU.html>`_
: Apply the sigmoid linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
`class SmoothL1Loss <nn/SmoothL1Loss.html>`_
: Compute the element-wise error transited from L1 and L2.
`[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
......@@ -244,10 +256,6 @@ vm.torch.nn
`class Softmax <nn/Softmax.html>`_
: Apply the softmax function.
`class Swish <nn/Swish.html>`_
: Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`class Tanh <nn/Tanh.html>`_
: Apply the tanh function.
......@@ -300,6 +308,7 @@ vm.torch.nn
nn/BatchNorm2d
nn/BatchNorm3d
nn/BCEWithLogitsLoss
nn/ChannelShuffle
nn/ConstantPad1d
nn/ConstantPad2d
nn/ConstantPad3d
......@@ -317,6 +326,7 @@ vm.torch.nn
nn/DropPath
nn/ELU
nn/Flatten
nn/GELU
nn/GroupNorm
nn/GRU
nn/GumbelSoftmax
......@@ -355,9 +365,9 @@ vm.torch.nn
nn/Sequential
nn/Sigmoid
nn/SigmoidFocalLoss
nn/SiLU
nn/SmoothL1Loss
nn/Softmax
nn/Swish
nn/Tanh
nn/TransformerDecoder
nn/TransformerDecoderLayer
......
Swish
=====
ChannelShuffle
==============
.. autoclass:: dragon.vm.torch.nn.Swish
.. autoclass:: dragon.vm.torch.nn.ChannelShuffle
__init__
--------
.. automethod:: dragon.vm.torch.nn.Swish.__init__
.. automethod:: dragon.vm.torch.nn.ChannelShuffle.__init__
.. _torch.nn.functional.swish(...): functional/swish.html
.. _torch.nn.functional.channel_shuffle(...): functional/channel_shuffle.html
.. raw:: html
......
GELU
====
.. autoclass:: dragon.vm.torch.nn.GELU
__init__
--------
.. automethod:: dragon.vm.torch.nn.GELU.__init__
.. _torch.nn.functional.gelu(...): functional/gelu.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
SiLU
====
.. autoclass:: dragon.vm.torch.nn.SiLU
__init__
--------
.. automethod:: dragon.vm.torch.nn.SiLU.__init__
.. _torch.nn.functional.silu(...): functional/silu.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
......@@ -40,6 +40,10 @@ vm.torch.nn.functional
`binary_cross_entropy_with_logits(...) <functional/binary_cross_entropy_with_logits.html>`_
: Compute the sigmoid cross entropy with contiguous target.
`channel_shuffle(...) <functional/channel_shuffle.html>`_
: Apply group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
`conv1d(...) <functional/conv1d.html>`_
: Apply the 1d convolution to input.
......@@ -85,6 +89,10 @@ vm.torch.nn.functional
: Apply the exponential linear unit to input.
`[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
`gelu(...) <functional/gelu.html>`_
: Apply the gaussian error linear unit to input.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
`group_norm(...) <functional/group_norm.html>`_
: Apply the group normalization to input.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
......@@ -163,6 +171,17 @@ vm.torch.nn.functional
: Compute the sigmoid focal loss with sparse labels.
`[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`__.
`sigmoid(...) <functional/sigmoid.html>`_
: Apply the sigmoid function to input.
`sigmoid_focal_loss(...) <functional/sigmoid_focal_loss.html>`_
: Compute the sigmoid focal loss with sparse labels.
`[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`__.
`silu(...) <functional/silu.html>`_
: Apply the sigmoid linear unit to input.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
`smooth_l1_loss(...) <functional/smooth_l1_loss.html>`_
: Compute the element-wise error transited from L1 and L2.
`[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
......@@ -170,10 +189,6 @@ vm.torch.nn.functional
`softmax(...) <functional/softmax.html>`_
: Apply the softmax function to input.
`swish(...) <functional/swish.html>`_
: Apply the swish function to input.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`sync_batch_norm(...) <functional/sync_batch_norm.html>`_
: Apply the sync batch normalization to input.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
......@@ -204,6 +219,7 @@ vm.torch.nn.functional
functional/avg_pool3d
functional/batch_norm
functional/binary_cross_entropy_with_logits
functional/channel_shuffle
functional/conv1d
functional/conv2d
functional/conv3d
......@@ -217,6 +233,7 @@ vm.torch.nn.functional
functional/drop_path
functional/dropout
functional/elu
functional/gelu
functional/group_norm
functional/hardsigmoid
functional/hardswish
......@@ -242,9 +259,9 @@ vm.torch.nn.functional
functional/selu
functional/sigmoid
functional/sigmoid_focal_loss
functional/silu
functional/smooth_l1_loss
functional/softmax
functional/swish
functional/sync_batch_norm
functional/tanh
functional/upsample
......
channel_shuffle
===============
.. autofunction:: dragon.vm.torch.channel_shuffle
.. autofunction:: dragon.vm.torch.nn.functional.channel_shuffle
.. _torch.nn.ChannelShuffle(...): ../ChannelShuffle.html
.. raw:: html
<style>
h1:before {
content: "torch.";
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
swish
=====
gelu
====
.. autofunction:: dragon.vm.torch.nn.functional.swish
.. autofunction:: dragon.vm.torch.nn.functional.gelu
.. _torch.nn.Swish(...): ../Swish.html
.. _torch.nn.GELU(...): ../GELU.html
.. raw:: html
......
silu
====
.. autofunction:: dragon.vm.torch.nn.functional.silu
.. _torch.nn.SiLU(...): ../SiLU.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
......@@ -8,6 +8,11 @@ vm.torch.optim
`class Adam <optim/Adam.html>`_
: The optimizer to apply Adam algorithm.
`[Kingma & Ba, 2014] <https://arxiv.org/abs/1412.6980>`_.
`class AdamW <optim/AdamW.html>`_
: The optimizer to apply AdamW algorithm.
`[Loshchilov & Hutter, 2017] <https://arxiv.org/abs/1711.05101>`_.
`class Optimizer <optim/Optimizer.html>`_
: The base class of optimizers.
......@@ -23,6 +28,7 @@ vm.torch.optim
:hidden:
optim/Adam
optim/AdamW
optim/Optimizer
optim/RMSprop
optim/SGD
......
AdamW
=====
.. autoclass:: dragon.vm.torch.optim.AdamW
__init__
--------
.. automethod:: dragon.vm.torch.optim.AdamW.__init__
Methods
-------
add_param_group
###############
.. automethod:: dragon.vm.torch.optim.Optimizer.add_param_group
:noindex:
step
####
.. automethod:: dragon.vm.torch.optim.Optimizer.step
:noindex:
sum_grad
########
.. automethod:: dragon.vm.torch.optim.Optimizer.sum_grad
:noindex:
zero_grad
#########
.. automethod:: dragon.vm.torch.optim.Optimizer.zero_grad
:noindex:
.. raw:: html
<style>
h1:before {
content: "torch.optim.";
color: #103d3e;
}
</style>
roll
====
.. autofunction:: dragon.vm.torch.roll
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
void _Gelu(const int N, const T* x, T* y) {
const T kRsqrt2 = 0.7071067811865475;
for (int i = 0; i < N; ++i) {
const T val = x[i];
y[i] = val * (T(1) + erf(val * kRsqrt2)) * T(0.5);
}
}
template <>
void _Gelu<float16>(const int N, const float16* x, float16* y) {
const float kRsqrt2 = 0.7071067811865475;
for (int i = 0; i < N; ++i) {
const float val = convert::To<float>(x[i]);
y[i] = convert::To<float16>(val * (1.f + erf(val * kRsqrt2)) * 0.5f);
}
}
template <typename T>
void _GeluGrad(const int N, const T* dy, const T* x, T* dx) {
const T kAlpha = 0.3989422804014327; // 0.5 * Sqrt(2/Pi)
const T kRsqrt2 = 0.7071067811865475;
ConstEigenVectorArrayMap<T> dY(dy, N);
ConstEigenVectorArrayMap<T> X(x, N);
EigenVectorArrayMap<T> dX(dx, N);
for (int i = 0; i < N; ++i) {
dx[i] = (T(1) + erf(x[i] * kRsqrt2)) * T(0.5);
}
dX = dY * (dX + X * ((T(-0.5) * X.square()).exp() * kAlpha));
}
template <>
void _GeluGrad<float16>(
const int N,
const float16* dy,
const float16* x,
float16* dx) {
CPU_FP16_NOT_SUPPORTED;
}
template <typename T>
void _ApproxGelu(const int N, const T* x, T* y) {
const T kAlpha = 0.7978845608028654; // Sqrt(2/Pi)
const T kBeta = 0.035677408136300125; // Sqrt(2/Pi) * 0.044715
ConstEigenVectorArrayMap<T> X(x, N);
EigenVectorArrayMap<T> Y(y, N);
Y = X * ((X * kAlpha + X.cube() * kBeta).tanh() + T(1)) * T(0.5);
}
template <>
void _ApproxGelu<float16>(const int N, const float16* x, float16* y) {
CPU_FP16_NOT_SUPPORTED;
}
template <typename T>
void _ApproxGeluGrad(const int N, const T* dy, const T* x, T* dx) {
const T kAlpha = 0.7978845608028654; // Sqrt(2/Pi)
const T kBeta = 0.035677408136300125; // Sqrt(2/Pi) * 0.044715
const T kGamma = 0.10703222440890037; // Sqrt(2/Pi) * 0.044715 * 3
ConstEigenVectorArrayMap<T> dY(dy, N);
ConstEigenVectorArrayMap<T> X(x, N);
EigenVectorArrayMap<T> Y(dx, N);
EigenVectorArrayMap<T> dX(dx, N);
Y = (X * kAlpha + X.cube() * kBeta).tanh();
dX = T(0.5) * dY *
(T(1) + Y + (X - X * Y.square()) * (kGamma * X.square() + kAlpha));
}
template <>
void _ApproxGeluGrad<float16>(
const int N,
const float16* dy,
const float16* x,
float16* dx) {
CPU_FP16_NOT_SUPPORTED;
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>(const int N, const T* x, T* y, CPUContext* ctx) { \
_##name(N, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int N, const T* dy, const T* x, T* dx, CPUContext* ctx) { \
_##name(N, dy, x, dx); \
}
DEFINE_KERNEL_LAUNCHER(Gelu, float16);
DEFINE_KERNEL_LAUNCHER(Gelu, float);
DEFINE_KERNEL_LAUNCHER(Gelu, double);
DEFINE_KERNEL_LAUNCHER(ApproxGelu, float16);
DEFINE_KERNEL_LAUNCHER(ApproxGelu, float);
DEFINE_KERNEL_LAUNCHER(ApproxGelu, double);
DEFINE_GRAD_KERNEL_LAUNCHER(GeluGrad, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(GeluGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(GeluGrad, double);
DEFINE_GRAD_KERNEL_LAUNCHER(ApproxGeluGrad, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(ApproxGeluGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(ApproxGeluGrad, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T, typename AccT>
__global__ void _Gelu(const int N, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, N) {
const AccT val = convert::To<AccT>(x[i]);
y[i] = convert::To<T>(val * normcdf(val));
}
}
template <typename T, typename AccT>
__global__ void _GeluGrad(const int N, const T* dy, const T* x, T* dx) {
CUDA_1D_KERNEL_LOOP(i, N) {
const AccT val = convert::To<AccT>(x[i]);
dx[i] = convert::To<T>(
convert::To<AccT>(dy[i]) *
fma(AccT(0.3989422804014327) * val,
exp(val * val * AccT(-0.5)),
normcdf(val)));
}
}
template <typename T, typename AccT>
__global__ void _ApproxGelu(const int N, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, N) {
const AccT val = convert::To<AccT>(x[i]);
y[i] = fma(val,
tanh(
AccT(0.7978845608028654) *
fma(AccT(0.044715), val * val * val, val)),
val) *
AccT(0.5);
}
}
template <typename T, typename AccT>
__global__ void _ApproxGeluGrad(const int N, const T* dy, const T* x, T* dx) {
CUDA_1D_KERNEL_LOOP(i, N) {
const AccT val = convert::To<AccT>(x[i]);
const AccT val2 = tanh(
AccT(0.7978845608028654) * fma(AccT(0.044715), val * val * val, val));
dx[i] = convert::To<T>(
convert::To<AccT>(dy[i]) * AccT(0.5) *
fma(fma(-val, val2 * val2, val),
fma(AccT(0.10703222440890037), val * val, AccT(0.7978845608028654)),
val2 + AccT(1)));
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CUDAContext>(const int N, const T* x, T* y, CUDAContext* ctx) { \
_##name<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CUDAContext>( \
const int N, const T* dy, const T* x, T* dx, CUDAContext* ctx) { \
_##name<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
reinterpret_cast<const math::ScalarType<T>::type*>(dy), \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<math::ScalarType<T>::type*>(dx)); \
}
DEFINE_KERNEL_LAUNCHER(Gelu, float16);
DEFINE_KERNEL_LAUNCHER(Gelu, float);
DEFINE_KERNEL_LAUNCHER(Gelu, double);
DEFINE_KERNEL_LAUNCHER(ApproxGelu, float16);
DEFINE_KERNEL_LAUNCHER(ApproxGelu, float);
DEFINE_KERNEL_LAUNCHER(ApproxGelu, double);
DEFINE_GRAD_KERNEL_LAUNCHER(GeluGrad, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(GeluGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(GeluGrad, double);
DEFINE_GRAD_KERNEL_LAUNCHER(ApproxGeluGrad, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(ApproxGeluGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(ApproxGeluGrad, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
void _Roll(
const int num_dims,
const int64_t* x_shifts,
const int64_t* x_strides,
const int64_t* y_dims,
const T* x,
T* y) {
const auto N =
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t index(num_dims, 0);
for (int yi = 0; yi < N; ++yi) {
int64_t xi = 0, r;
for (int d = num_dims - 1; d >= 0; --d) {
r = index[d] - x_shifts[d];
r = (r < 0 ? r + y_dims[d] : r) % y_dims[d];
xi += r * x_strides[d];
}
y[yi] = x[xi];
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Roll<T, CPUContext>( \
const int num_dims, \
const int64_t* x_shifts, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_Roll(num_dims, x_shifts, x_strides, y_dims, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T, int D>
__global__ void _Roll(
const int N,
const int num_dims,
const SimpleArray<int, D> X_shifts,
const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_dims,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, N) {
int xi = 0, tmp = yi;
for (int d = num_dims - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], tmp, &tmp, &r);
r -= X_shifts.data[d];
r = (r < 0 ? r + Y_dims.data[d] : r) % Y_dims.data[d];
xi += r * X_strides.data[d];
}
y[yi] = x[xi];
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Roll<T, CUDAContext>( \
const int num_dims, \
const int64_t* x_shifts, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_shifts; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> Y_dims; \
const auto N = std::accumulate( \
y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \
X_shifts.data[i] = x_shifts[i]; \
X_strides.data[i] = x_strides[i]; \
Y_dims.data[i] = y_dims[i]; \
} \
_Roll<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, num_dims, X_shifts, X_strides, Y_dims, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
......@@ -16,8 +16,8 @@ namespace {
template <typename T>
__global__ void _GroupNormFusedParams(
const int N,
const int G,
const int NxC,
const int C,
const int D,
const T* mu,
const T* rsig,
......@@ -25,58 +25,29 @@ __global__ void _GroupNormFusedParams(
const T* beta,
T* scale,
T* bias) {
const int NxG = N * G;
CUDA_2D_KERNEL_LOOP1(i, NxG) {
const int g = i % G;
const T mu_val = LDG(mu, i);
const T rsig_val = LDG(rsig, i);
CUDA_2D_KERNEL_LOOP2(j, D) {
const int c = g * D + j;
const int nc = i * D + j;
const T scale_val = LDG(gamma, c) * rsig_val;
scale[nc] = scale_val;
bias[nc] = fma(-scale_val, mu_val, LDG(beta, c));
}
CUDA_1D_KERNEL_LOOP(i, NxC) {
const int c = i % C;
const int ng = i / D;
const T scale_val = LDG(gamma, c) * LDG(rsig, ng);
scale[i] = scale_val;
bias[i] = fma(-scale_val, LDG(mu, ng), LDG(beta, c));
}
}
template <typename T, typename AccT>
__global__ void _GroupNormAffineNCHW(
const int N,
const int C,
const int S,
const T* x,
const AccT* scale,
const AccT* bias,
T* y) {
const int NxC = N * C;
CUDA_2D_KERNEL_LOOP1(i, NxC) {
const AccT w = LDG(scale, i);
const AccT b = LDG(bias, i);
CUDA_2D_KERNEL_LOOP2(j, S) {
const int idx = i * S + j;
y[idx] = convert::To<AccT>(fma(LDG2(x, idx), w, b));
}
}
}
template <typename T, typename AccT>
__global__ void _GroupNormAffineNHWC(
const int N,
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _GroupNormAffine(
const int NxCxS,
const int C,
const int S,
const T* x,
const AccT* scale,
const AccT* bias,
T* y) {
const int NxS = N * S;
CUDA_2D_KERNEL_LOOP1(i, NxS) {
const int n = i / S;
CUDA_2D_KERNEL_LOOP2(j, C) {
const int nc = n * C + j;
const int idx = i * C + j;
y[idx] = convert::To<T>(fma(LDG2(x, idx), LDG(scale, nc), LDG(bias, nc)));
}
CUDA_1D_KERNEL_LOOP(i, NxCxS) {
const int nc =
kOrder == StorageOrder::NCHW ? i / S : i / (C * S) * C + i % C;
y[i] = convert::To<T>(
fma(convert::To<AccT>(x[i]), LDG(scale, nc), LDG(bias, nc)));
}
}
......@@ -195,56 +166,44 @@ __global__ void _GroupNormGrad(
LOG(FATAL) << "Unknown DataFormat: " << data_format; \
}
#define DEFINE_KERNEL_LAUNCHER(T, AccT) \
template <> \
void GroupNorm<T, AccT, CUDAContext>( \
const int N, \
const int G, \
const int D, \
const int S, \
const string& data_format, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const AccT* beta, \
AccT* scale, \
AccT* bias, \
T* y, \
CUDAContext* ctx) { \
const auto C = G * D; \
_GroupNormFusedParams<<< \
CUDA_2D_BLOCKS(N* G), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(N, G, D, mu, rsig, gamma, beta, scale, bias); \
if (data_format == "NCHW") { \
_GroupNormAffineNCHW<<< \
CUDA_2D_BLOCKS(N* C), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
N, \
C, \
S, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
scale, \
bias, \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
} else if (data_format == "NHWC") { \
_GroupNormAffineNHWC<<< \
CUDA_2D_BLOCKS(N* S), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
N, \
C, \
S, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
scale, \
bias, \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
} \
#define DEFINE_KERNEL_LAUNCHER(T, AccT) \
template <> \
void GroupNorm<T, AccT, CUDAContext>( \
const int N, \
const int G, \
const int D, \
const int S, \
const string& data_format, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const AccT* beta, \
AccT* scale, \
AccT* bias, \
T* y, \
CUDAContext* ctx) { \
const auto C = G * D; \
const auto NxC = N * C; \
const auto NxCxS = NxC * S; \
_GroupNormFusedParams<<< \
CUDA_BLOCKS(NxC), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(NxC, C, D, mu, rsig, gamma, beta, scale, bias); \
DISPATCH_GROUPNORM_KERNEL( \
_GroupNormAffine, \
math::ScalarType<T>::type, \
AccT, \
CUDA_BLOCKS(NxCxS), \
CUDA_THREADS, \
NxCxS, \
C, \
S, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
scale, \
bias, \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T, AccT) \
......@@ -266,7 +225,7 @@ __global__ void _GroupNormGrad(
AccT* dbeta, \
T* dx, \
CUDAContext* ctx) { \
auto NxCxS = N * G * D * S; \
const auto NxCxS = N * G * D * S; \
DISPATCH_GROUPNORM_KERNEL( \
_GroupNormWGrad, \
math::ScalarType<T>::type, \
......
#include "dragon/operators/activation/gelu_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void GeluOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
if (approximate_) {
kernels::ApproxGelu(
X.count(),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
} else {
kernels::Gelu(
X.count(),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
}
template <class Context>
template <typename T>
void GeluGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &dY = Input(1), *dX = Output(0);
if (approximate_) {
kernels::ApproxGeluGrad(
X.count(),
dY.template data<T, Context>(),
X.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
} else {
kernels::GeluGrad(
X.count(),
dY.template data<T, Context>(),
X.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
}
DEPLOY_CPU_OPERATOR(Gelu);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Gelu);
#endif
DEPLOY_CPU_OPERATOR(GeluGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(GeluGradient);
#endif
OPERATOR_SCHEMA(Gelu)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(GeluGradient)
/* X, dY */
.NumInputs(2)
/* dX */
.NumOutputs(1);
REGISTER_GRADIENT(Gelu, GenericGradientMaker);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_GELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_GELU_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class GeluOp : public Operator<Context> {
public:
GeluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
approximate_(OP_SINGLE_ARG(int64_t, "approximate", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
int64_t approximate_;
};
template <class Context>
class GeluGradientOp : public Operator<Context> {
public:
GeluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
approximate_(OP_SINGLE_ARG(int64_t, "approximate", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
int64_t approximate_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_GELU_OP_H_
#include "dragon/operators/array/roll_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void RollOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
auto* X_ref = Buffer("X_ref")->ReshapeLike(X);
if (axes_.empty()) X_ref->Reshape({X.count()});
int num_shifts, num_dims = X_ref->ndim();
vec64_t X_shifts(num_dims, 0);
shifts(0, &num_shifts);
if (axes_.empty()) {
X_shifts[0] = shifts(0);
} else {
CHECK_EQ(num_shifts, int(axes_.size()))
<< "\nProviding " << axes_.size() << " dimensions and " << num_shifts
<< " shifts to roll.";
for (int i = 0; i < axes_.size(); ++i) {
int axis = axes_[i];
axis = axis < 0 ? axis + num_dims : axis;
CHECK(axis >= 0 && axis < num_dims)
<< "\nExcepted the <axis> in [-" << num_dims << ", " << num_dims
<< "), got " << axes_[i] << ".";
X_shifts[axis] += shifts(i);
}
}
Buffer("X_shifts")->template CopyFrom<int64_t>(X_shifts);
kernels::Roll(
num_dims,
X_shifts.data(),
X_ref->strides().data(),
X_ref->dims().data(),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
template <typename T>
void RollGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0);
auto* X_ref = Buffer("X_ref");
vec64_t Y_shifts;
Buffer("X_shifts")->template CopyTo<int64_t>(Y_shifts);
for (int i = 0; i < Y_shifts.size(); ++i) {
Y_shifts[i] *= -1; // Reverse the shifts.
}
kernels::Roll(
X_ref->ndim(),
Y_shifts.data(),
X_ref->strides().data(),
X_ref->dims().data(),
dY.template data<T, Context>(),
dX->ReshapeLike(dY)->template mutable_data<T, Context>(),
ctx());
}
DEPLOY_CPU_OPERATOR(Roll);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Roll);
#endif
DEPLOY_CPU_OPERATOR(RollGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(RollGradient);
#endif
OPERATOR_SCHEMA(Roll)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(RollGradient)
/* dY */
.NumInputs(1)
/* dX */
.NumOutputs(1);
REGISTER_GRADIENT(Roll, SimpleGradientMaker);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_ROLL_OP_H_
#define DRAGON_OPERATORS_ARRAY_ROLL_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class RollOp final : public Operator<Context> {
public:
RollOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), axes_(OP_REPEATED_ARG(int64_t, "axes")) {
INITIALIZE_OP_REPEATED_ARG(int64_t, shifts);
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
vec64_t axes_;
DECLARE_OP_REPEATED_ARG(int64_t, shifts);
};
template <class Context>
class RollGradientOp : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(RollGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
};
DEFINE_OP_REPEATED_ARG(int64_t, RollOp, shifts);
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_ROLL_OP_H_
#include "dragon/core/workspace.h"
#include "dragon/operators/training/update_ops.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
void AdamUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
void AdamUpdateOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) {
kernels::AdamUpdate(
dX->count(),
lr_,
lr_ * correction_,
beta1_,
beta2_,
eps_,
......@@ -18,13 +19,30 @@ void AdamUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
ctx());
}
template <class Context>
void AdamWUpdateOp<Context>::ComputeUpdate(Tensor* dX, Tensor* X) {
AdamUpdateOp<Context>::ComputeUpdate(dX, X);
if (lambda_ > 0.f) {
math::Axpy(
X->count(),
this->lr_ * lambda_,
X->template data<float, Context>(),
dX->template mutable_data<float, Context>(),
ctx());
}
}
DEPLOY_CPU_OPERATOR(AdamUpdate);
DEPLOY_CPU_OPERATOR(AdamWUpdate);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(AdamUpdate);
DEPLOY_CUDA_OPERATOR(AdamWUpdate);
#endif
OPERATOR_SCHEMA(AdamUpdate).NumInputs(1, INT_MAX).NumOutputs(1, INT_MAX);
OPERATOR_SCHEMA(AdamWUpdate).NumInputs(1, INT_MAX).NumOutputs(1, INT_MAX);
NO_GRADIENT(AdamUpdate);
NO_GRADIENT(AdamWUpdate);
} // namespace dragon
......@@ -5,7 +5,7 @@
namespace dragon {
template <class Context>
void NesterovUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
void NesterovUpdateOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) {
kernels::NesterovUpdate(
dX->count(),
lr_,
......
......@@ -5,7 +5,7 @@
namespace dragon {
template <class Context>
void RMSpropUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
void RMSpropUpdateOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) {
kernels::RMSPropUpdate(
dX->count(),
lr_,
......
......@@ -5,7 +5,7 @@
namespace dragon {
template <class Context>
void SGDUpdateOp<Context>::ComputeUpdate(Tensor* dX) {
void SGDUpdateOp<Context>::ComputeUpdate(Tensor* dX, Tensor* /* X */) {
kernels::SGDUpdate(
dX->count(),
lr_,
......
......@@ -67,10 +67,10 @@ void UpdateOpBase<Context>::RunOnDevice() {
input_index_ = i;
if (dX.template IsType<float>()) {
AdjustGradient<float>(&dX, X);
ComputeUpdate(&dX);
ComputeUpdate(&dX, X);
ApplyUpdate<float>(&dX, X);
} else if (dX.template IsType<float16>()) {
auto* X_master = workspace()->CreateTensor(X->name() + "/float32");
auto* X_master = workspace()->CreateTensor(X->name() + "_master");
auto* dX_copy = ctx()->workspace()->CreateTensor("shared/buffer/data:0");
if (X_master->count() != X->count()) {
math::Cast(
......@@ -85,7 +85,7 @@ void UpdateOpBase<Context>::RunOnDevice() {
dX_copy->ReshapeLike(dX)->template mutable_data<float, Context>(),
ctx());
AdjustGradient<float>(dX_copy, X_master);
ComputeUpdate(dX_copy);
ComputeUpdate(dX_copy, X_master);
ApplyUpdate<float>(dX_copy, X_master);
math::Cast(
X->count(),
......
......@@ -35,7 +35,7 @@ class UpdateOpBase : public Operator<Context> {
void RunOnDevice() override;
virtual void ComputeUpdate(Tensor* dX) = 0;
virtual void ComputeUpdate(Tensor* dX, Tensor* X) = 0;
template <typename T>
void AdjustGradient(Tensor* dX, Tensor* X);
......@@ -75,7 +75,7 @@ class SGDUpdateOp final : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments();
}
void ComputeUpdate(Tensor* dX) override;
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override;
protected:
float lr_, last_lr_;
......@@ -96,7 +96,7 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments();
}
void ComputeUpdate(Tensor* dX) override;
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override;
protected:
float lr_, momentum_;
......@@ -118,14 +118,14 @@ class RMSpropUpdateOp final : public UpdateOpBase<Context> {
UpdateOpBase<Context>::GetArguments();
}
void ComputeUpdate(Tensor* dX) override;
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override;
protected:
float lr_, momentum_, decay_, eps_;
};
template <class Context>
class AdamUpdateOp final : public UpdateOpBase<Context> {
class AdamUpdateOp : public UpdateOpBase<Context> {
public:
AdamUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws), t_(0) {}
......@@ -133,19 +133,40 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
USE_UPDATE_FUNCTIONS;
void GetArguments() override {
t_++;
lr_ = Hyper("lr");
beta1_ = Hyper("beta1");
beta2_ = Hyper("beta2");
auto correction = sqrt(1.f - pow(beta2_, t_)) / (1.f - pow(beta1_, t_));
lr_ = Hyper("lr") * correction;
eps_ = Hyper("eps");
t_++;
correction_ = sqrt(1.f - pow(beta2_, t_)) / (1.f - pow(beta1_, t_));
UpdateOpBase<Context>::GetArguments();
}
void ComputeUpdate(Tensor* dX) override;
void ComputeUpdate(Tensor* dX, Tensor* /* X */) override;
protected:
int64_t t_;
float lr_, beta1_, beta2_, eps_, correction_;
};
template <class Context>
class AdamWUpdateOp final : public AdamUpdateOp<Context> {
public:
AdamWUpdateOp(const OperatorDef& def, Workspace* ws)
: AdamUpdateOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_UPDATE_FUNCTIONS;
void GetArguments() override {
AdamUpdateOp<Context>::GetArguments();
lambda_ = this->weight_decay_;
this->weight_decay_ = 0.f;
}
void ComputeUpdate(Tensor* dX, Tensor* X) override;
protected:
float lr_, beta1_, beta2_, eps_, t_;
float lambda_;
};
#undef USE_UPDATE_FUNCTIONS
......
......@@ -76,6 +76,7 @@ from dragon.core.ops.array_ops import pad
from dragon.core.ops.array_ops import range
from dragon.core.ops.array_ops import repeat
from dragon.core.ops.array_ops import reshape
from dragon.core.ops.array_ops import roll
from dragon.core.ops.array_ops import scatter_add
from dragon.core.ops.array_ops import scatter_elements
from dragon.core.ops.array_ops import shape
......
......@@ -23,6 +23,7 @@ from dragon.core.ops.activation_ops import dropout
from dragon.core.ops.activation_ops import drop_block
from dragon.core.ops.activation_ops import drop_path
from dragon.core.ops.activation_ops import elu
from dragon.core.ops.activation_ops import gelu
from dragon.core.ops.activation_ops import hardsigmoid
from dragon.core.ops.activation_ops import hardswish
from dragon.core.ops.activation_ops import leaky_relu
......@@ -31,8 +32,8 @@ from dragon.core.ops.activation_ops import prelu
from dragon.core.ops.activation_ops import relu
from dragon.core.ops.activation_ops import relu6
from dragon.core.ops.activation_ops import selu
from dragon.core.ops.activation_ops import silu
from dragon.core.ops.activation_ops import softmax
from dragon.core.ops.activation_ops import swish
from dragon.core.ops.array_ops import moments
from dragon.core.ops.normalization_ops import batch_norm
from dragon.core.ops.normalization_ops import group_norm
......
......@@ -14,6 +14,7 @@ from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.core.training.adam import Adam
from dragon.core.training.adam import AdamW
from dragon.core.training.optimizer import Optimizer
from dragon.core.training.rmsprop import RMSprop
from dragon.core.training.sgd import Nesterov
......
......@@ -234,6 +234,11 @@ def gather_args(**kwargs):
}
@register('Gelu')
def gelu_args(**kwargs):
return {'approximate': kwargs.get('approximate', False)}
@register('Gemm')
def gemm_args(**kwargs):
return {
......@@ -498,6 +503,14 @@ def roi_pool_args(**kwargs):
}
@register('Roll')
def roll_args(**kwargs):
return {
'axes': kwargs.get('axes', None),
'shifts_desc': 'int64',
}
@register(['ScatterElements', 'ScatterAdd', 'GatherElements'])
def scatter_gather_elements_args(**kwargs):
return {'axis': kwargs.get('axis', 0)}
......@@ -609,6 +622,10 @@ def unsqueeze_args(**kwargs):
return {'axes': kwargs.get('axes', [0])}
@register(['AdamUpdate', 'RMSpropUpdate', 'SGDUpdate', 'NesterovUpdate'])
@register(['AdamUpdate',
'AdamWUpdate',
'RMSpropUpdate',
'SGDUpdate',
'NesterovUpdate'])
def update_args(**kwargs):
return {'no_grad': True, 'weight_decay': kwargs.get('weight_decay', None)}
......@@ -178,8 +178,8 @@ def elu(inputs, alpha=1.0, inplace=False, **kwargs):
Examples:
```python
x = dragon.constant([-1, 0, 1], 'float32')
print(dragon.nn.elu(x, inplace=False))
x = dragon.constant([-1., 0., 1.])
print(dragon.nn.elu(x))
```
Parameters
......@@ -205,6 +205,40 @@ def elu(inputs, alpha=1.0, inplace=False, **kwargs):
@OpSchema.num_inputs(1)
def gelu(inputs, approximate=False, **kwargs):
r"""Apply the gaussian error linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
The **GELU** function is defined as:
.. math:: \text{GELU}(x) = x\cdot\frac{1}{2}[1 + \text{erf}(x / \sqrt{2})]
Examples:
```python
x = dragon.constant([-1., 0., 1.])
print(dragon.nn.gelu(x))
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
approximate : bool, optional, default=False
Whether to approximate the computation.
Returns
-------
dragon.Tensor
The output tensor.
"""
if context.executing_eagerly():
return OpLib.execute('Gelu', inputs)
return OpLib.add('Gelu', inputs, approximate=approximate, **kwargs)
@OpSchema.num_inputs(1)
def hardsigmoid(inputs, alpha=0.2, beta=0.5, inplace=False, **kwargs):
r"""Apply the hard sigmoid function.
......@@ -216,7 +250,7 @@ def hardsigmoid(inputs, alpha=0.2, beta=0.5, inplace=False, **kwargs):
```python
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(dragon.nn.hardsigmoid(x, inplace=False))
print(dragon.nn.hardsigmoid(x))
```
Parameters
......@@ -297,8 +331,8 @@ def leaky_relu(inputs, alpha=0.2, inplace=False, **kwargs):
Examples:
```python
x = dragon.constant([-1, 0, 1], 'float32')
print(dragon.nn.leaky_relu(x, inplace=False))
x = dragon.constant([-1., 0., 1.])
print(dragon.nn.leaky_relu(x))
```
Parameters
......@@ -376,8 +410,8 @@ def prelu(inputs, data_format='NCHW', **kwargs):
Examples:
```python
x = dragon.constant([[-1, 0, 1]], 'float32')
w = dragon.fill([3], value=0.25, dtype='float32')
x = dragon.constant([[-1., 0., 1.]])
w = dragon.fill((3,), value=0.25, dtype=x.dtype)
print(dragon.nn.prelu([x, w]))
```
......@@ -456,7 +490,7 @@ def relu6(inputs, inplace=False, **kwargs):
Examples:
```python
x = dragon.constant([-1, 0, 7], 'float32')
x = dragon.constant([-1., 0., 7.])
print(dragon.nn.relu6(x))
```
......@@ -561,29 +595,25 @@ def sigmoid(inputs, inplace=False, **kwargs):
@OpSchema.num_inputs(1)
def softmax(inputs, axis=-1, inplace=False, **kwargs):
r"""Compute the softmax result.
def silu(inputs, **kwargs):
r"""Apply the sigmoid linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
The **Softmax** function is defined as:
The **SiLU** function is defined as:
.. math:: \text{Softmax}(x_{i}) = \frac{\exp(x_{i})}{\sum_{j} \exp(x_{j})}
.. math:: \text{SiLU}(x) = x \cdot \frac{1}{1 + \exp(-x)}
The argument ``axis`` could be negative:
Examples:
```python
x = dragon.ones((1, 4), dtype='float32')
print(dragon.nn.softmax(x, 1)) # [[0.25 0.25 0.25 0.25]]
print(dragon.nn.softmax(x, -1)) # Equivalent
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(dragon.nn.silu(x))
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
axis : int, optional, default=-1
The axis to reduce.
inplace : bool, optional, default=False
Call in-place or return a new tensor.
Returns
-------
......@@ -592,30 +622,32 @@ def softmax(inputs, axis=-1, inplace=False, **kwargs):
"""
if context.executing_eagerly():
return OpLib.execute(
'Softmax', inputs, outputs=inputs if inplace else [None], axis=axis)
return OpLib.add('Softmax', inputs, axis=axis, **kwargs)
return OpLib.execute('Swish', inputs)
return OpLib.add('Swish', inputs, **kwargs)
@OpSchema.num_inputs(1)
def tanh(inputs, inplace=False, **kwargs):
r"""Compute the tanh of input.
def softmax(inputs, axis=-1, inplace=False, **kwargs):
r"""Compute the softmax result.
The **Tanh** function is defined as:
The **Softmax** function is defined as:
.. math:: \text{Tanh}(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}
.. math:: \text{Softmax}(x_{i}) = \frac{\exp(x_{i})}{\sum_{j} \exp(x_{j})}
Examples:
The argument ``axis`` could be negative:
```python
x = dragon.constant([0.2, 0.4, 0.6, 0.8, 1.0], 'float32')
print(dragon.math.tanh(x))
x = dragon.ones((1, 4), dtype='float32')
print(dragon.nn.softmax(x, 1)) # [[0.25 0.25 0.25 0.25]]
print(dragon.nn.softmax(x, -1)) # Equivalent
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
axis : int, optional, default=-1
The axis to reduce.
inplace : bool, optional, default=False
Call in-place or return a new tensor.
......@@ -627,30 +659,31 @@ def tanh(inputs, inplace=False, **kwargs):
"""
if context.executing_eagerly():
return OpLib.execute(
'Tanh', inputs, outputs=inputs if inplace else [None])
return OpLib.add('Tanh', inputs, **kwargs)
'Softmax', inputs, outputs=inputs if inplace else [None], axis=axis)
return OpLib.add('Softmax', inputs, axis=axis, **kwargs)
@OpSchema.num_inputs(1)
def swish(inputs, **kwargs):
r"""Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
def tanh(inputs, inplace=False, **kwargs):
r"""Compute the tanh of input.
The **Swish** function is defined as:
The **Tanh** function is defined as:
.. math:: \text{Swish}(x) = x \cdot \frac{1}{1 + \exp(-x)}
.. math:: \text{Tanh}(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}
Examples:
```python
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(dragon.nn.swish(x))
x = dragon.constant([0.2, 0.4, 0.6, 0.8, 1.0])
print(dragon.math.tanh(x))
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
inplace : bool, optional, default=False
Call in-place or return a new tensor.
Returns
-------
......@@ -659,5 +692,6 @@ def swish(inputs, **kwargs):
"""
if context.executing_eagerly():
return OpLib.execute('Swish', inputs)
return OpLib.add('Swish', inputs, **kwargs)
return OpLib.execute(
'Tanh', inputs, outputs=inputs if inplace else [None])
return OpLib.add('Tanh', inputs, **kwargs)
......@@ -20,6 +20,7 @@ from dragon.core.autograph.op_impl import OpSchema
from dragon.core.framework import types
from dragon.core.ops import constant_ops
from dragon.core.util import nest
from dragon.core.util import six
@OpSchema.num_inputs(1)
......@@ -1227,6 +1228,55 @@ def reshape(inputs, shape, copy=True, **kwargs):
return OpLib.add('Reshape', **args)
@OpSchema.num_inputs(1)
@OpSchema.convert_arg('shift', name_v2='shifts')
def roll(inputs, shift, axis=None, **kwargs):
"""Roll elements along the given axis.
:attr:`axis` could be negative or ``None``:
```python
x = dragon.constant([[1, 2, 3], [4, 5, 6]])
# A negative axis is the last-k axis
print(dragon.roll(x, shift=1, axis=1)) # [[3, 1, 2], [6, 4, 5]]
print(dragon.roll(x, shift=1, axis=-1)) # Equivalent
# If axis is None, roll input as a vector
print(dragon.roll(x, shift=1)) # [[6, 1, 2], [3, 4, 5]]
# Also, axis could be a sequence of integers
print(dragon.roll(x, shift=(1, 1), axis=(0, 1))) # [[6, 4, 5], [3, 1, 2]]
print(dragon.roll(x, shift=(1, -1), axis=(0, 1))) # [[5, 6, 4], [2, 3, 1]]
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
shift : Union[int, Sequence[int], dragon.Tensor]
The rolling offset of each axis.
axis : Union[int, Sequence[int]], optional
The axis to roll.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = OpSchema.parse_args(locals())
axes = nest.flatten(axis) if axis is not None else axis
if isinstance(shift, six.integer_types):
args['shifts'] = nest.flatten(shift)
if context.executing_eagerly():
return OpLib.execute(
'Roll', inputs, num_shifts=len(args['shifts']),
shifts=args['shifts'], axes=axes)
args.pop('axis')
return OpLib.add('Roll', axes=axes, **args)
@OpSchema.num_inputs(3)
def scatter_add(inputs, axis=0, copy=True, **kwargs):
"""Add elements along the given axis of index.
......
......@@ -53,3 +53,49 @@ class Adam(optimizer.Optimizer):
self._set_hyper('beta1', beta1)
self._set_hyper('beta2', beta2)
self._set_hyper('eps', eps)
class AdamW(Adam):
r"""The optimizer to apply AdamW algorithm.
`[Loshchilov & Hutter, 2017] <https://arxiv.org/abs/1711.05101>`_.
The **AdamW** update is defined as:
.. math::
\text{AdamW}(g, p) = -\text{lr} * (\frac{m_{t}}{\sqrt{v_{t}} + \epsilon}
+ \lambda p) \\
\quad \\ \text{where}\quad
\begin{cases}
m_{t} = \beta_{1} * m_{t-1} + (1 - \beta_{1}) * g \\
v_{t} = \beta_{2} * v_{t-1} + (1 - \beta_{2}) * g^{2}
\end{cases}
"""
def __init__(
self,
lr=0.001,
beta1=0.9,
beta2=0.999,
eps=1e-8,
weight_decay=0.01,
**kwargs
):
r"""Create an ``AdamW`` updater.
Parameters
----------
lr : float, optional, default=0.001
The initial value to :math:`\text{lr}`.
beta1 : float, optional, default=0.9
The initial value to :math:`\beta_{1}`.
beta2 : float, optional, default=0.999
The initial value to :math:`\beta_{2}`.
eps : float, optional, default=1e-8
The initial value to :math:`\epsilon`
weight_decay : float, optional, default=0.01
The initial value to :math:`\lambda`.
"""
super(AdamW, self).__init__(
lr, beta1, beta2, eps, weight_decay=weight_decay, **kwargs)
......@@ -85,6 +85,18 @@ void EluGrad(
Context* ctx);
template <typename T, class Context>
void Gelu(const int N, const T* x, T* y, Context* ctx);
template <typename T, class Context>
void GeluGrad(const int N, const T* dy, const T* x, T* dx, Context* ctx);
template <typename T, class Context>
void ApproxGelu(const int N, const T* x, T* y, Context* ctx);
template <typename T, class Context>
void ApproxGeluGrad(const int N, const T* dy, const T* x, T* dx, Context* ctx);
template <typename T, class Context>
void HardSigmoid(
const int N,
const float alpha,
......@@ -490,6 +502,16 @@ void RepeatGrad(
Context* ctx);
template <typename T, class Context>
void Roll(
const int num_dims,
const int64_t* x_shifts,
const int64_t* x_strides,
const int64_t* y_dims,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
void ScatterElements(
const int axis,
const int num_dims,
......
......@@ -80,6 +80,7 @@ from dragon.vm.tensorflow.core.ops.array_ops import one_hot
from dragon.vm.tensorflow.core.ops.array_ops import pad
from dragon.vm.tensorflow.core.ops.array_ops import placeholder
from dragon.vm.tensorflow.core.ops.array_ops import reshape
from dragon.vm.tensorflow.core.ops.array_ops import roll
from dragon.vm.tensorflow.core.ops.array_ops import shape
from dragon.vm.tensorflow.core.ops.array_ops import slice
from dragon.vm.tensorflow.core.ops.array_ops import split
......
......@@ -32,6 +32,7 @@ from dragon.vm.tensorflow.core.ops.nn import conv_transpose
from dragon.vm.tensorflow.core.ops.nn import depthwise_conv2d
from dragon.vm.tensorflow.core.ops.nn import dropout
from dragon.vm.tensorflow.core.ops.nn import elu
from dragon.vm.tensorflow.core.ops.nn import gelu
from dragon.vm.tensorflow.core.ops.nn import l2_loss
from dragon.vm.tensorflow.core.ops.nn import l2_normalize
from dragon.vm.tensorflow.core.ops.nn import leaky_relu
......@@ -45,6 +46,7 @@ from dragon.vm.tensorflow.core.ops.nn import moments
from dragon.vm.tensorflow.core.ops.nn import relu
from dragon.vm.tensorflow.core.ops.nn import relu6
from dragon.vm.tensorflow.core.ops.nn import selu
from dragon.vm.tensorflow.core.ops.nn import silu
from dragon.vm.tensorflow.core.ops.nn import sigmoid_cross_entropy_with_logits
from dragon.vm.tensorflow.core.ops.nn import softmax
from dragon.vm.tensorflow.core.ops.nn import softmax_cross_entropy_with_logits
......
......@@ -503,6 +503,46 @@ def reshape(tensor, shape, name=None):
return array_ops.reshape(tensor, shape=shape, name=name)
def roll(input, shift, axis, name=None):
"""Roll elements along the given axis.
:attr:`axis` could be negative or ``None``:
```python
x = tf.constant([[1, 2, 3], [4, 5, 6]])
# A negative axis is the last-k axis
print(tf.roll(x, shift=1, axis=1)) # [[3, 1, 2], [6, 4, 5]]
print(tf.roll(x, shift=1, axis=-1)) # Equivalent
# If axis is None, roll input as a vector
print(tf.roll(x, shift=1)) # [[6, 1, 2], [3, 4, 5]]
# Also, axis could be a sequence of integers
print(tf.roll(x, shift=(1, 1), axis=(0, 1))) # [[6, 4, 5], [3, 1, 2]]
print(tf.roll(x, shift=(1, -1), axis=(0, 1))) # [[5, 6, 4], [2, 3, 1]]
```
Parameters
----------
input : dragon.Tensor
The input tensor.
shift : Union[int, Sequence[int], dragon.Tensor]
The rolling offset of each axis.
axis : Union[int, Sequence[int]], optional
The axis to roll.
name : str, optional
The operation name.
Returns
-------
dragon.Tensor
The output tensor.
"""
return array_ops.roll(input, shift=shift, axis=axis, name=name)
def shape(input, name=None):
"""Return the shape of input.
......
......@@ -958,6 +958,13 @@ def elu(features, alpha=1., name=None, **kwargs):
\alpha * (\exp(x) - 1), & \text{ otherwise }
\end{cases}
Examples:
```python
x = tf.constant([-1., 0., 1.])
print(tf.nn.elu(x))
```
Parameters
----------
features : dragon.Tensor
......@@ -976,6 +983,39 @@ def elu(features, alpha=1., name=None, **kwargs):
return activation_ops.elu(features, alpha=alpha, name=name, **kwargs)
def gelu(features, approximate=False, name=None):
r"""Apply the gaussian error linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
The **GELU** function is defined as:
.. math:: \text{GELU}(x) = 0.5x(1 + \tanh[\sqrt{2/\pi}(x + 0.044715x^{3})])
Examples:
```python
x = tf.constant([-1., 0., 1.])
print(tf.nn.gelu(x))
```
Parameters
----------
features : dragon.Tensor
The input tensor.
approximate : bool, optional, default=False
Whether to approximate the computation.
name : str, optional
The operation name.
Returns
-------
dragon.Tensor
The output tensor.
"""
return activation_ops.gelu(features, approximate=approximate, name=name)
def l2_loss(t, name=None):
return loss_ops.l2_loss(t, normalization='NONE', name=name)
......@@ -1552,6 +1592,35 @@ def sparse_softmax_cross_entropy_with_logits(labels, logits, name=None):
)
def silu(features):
r"""Apply the sigmoid linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
The **SiLU** function is defined as:
.. math:: \text{SiLU}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Examples:
```python
x = tf.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(tf.nn.silu(x))
```
Parameters
----------
features : dragon.Tensor
The input tensor.
Returns
-------
dragon.Tensor
The output tensor.
"""
return activation_ops.silu(features)
def top_k(input, k=1, sorted=True, name=None):
"""Return the top-K largest elements along the last axis.
......
......@@ -190,6 +190,30 @@ class TestActivationOps(OpTestCase):
with dragon.device('cuda'), self.cudnn_ws.as_default():
self.test_elu()
def test_gelu(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data = np.array([-1., 0., 1.], 'float32')
cdf = data.copy()
pdf = 0.3989422804014327 * np.exp(-0.5 * np.square(data))
for i in range(data.size):
cdf[i] = 0.5 * (1 + math.erf(data[i] * 0.7071067811865475))
for approximate in (False, True):
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.nn.gelu(x, approximate=approximate)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
self.assertEqual(
[y, dx], [data * cdf, data * (cdf + data * pdf)],
prec=0.001 if approximate else None)
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_gelu_cuda(self):
dragon.cuda.enable_cudnn(False)
with dragon.device('cuda'):
self.test_gelu()
def test_hardsigmoid(self):
alpha, beta = 0.2, 0.5
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
......@@ -390,6 +414,24 @@ class TestActivationOps(OpTestCase):
with dragon.device('cuda'), self.cudnn_ws.as_default():
self.test_sigmoid()
def test_silu(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], 'float32')
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.nn.silu(x)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
result = data * (1. / (1. + np.exp(-data)))
result2 = data * (result + (1. / (1. + np.exp(-data))) * (1. - result))
self.assertEqual([y, dx], [result, result2])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_silu_cuda(self):
with dragon.device('cuda'):
self.test_silu()
def test_softmax(self):
grad = np.array([[-0.11596, -0.0523, 0.16825],
[-0.15008, 0.3116, -0.16152]], dtype='float32')
......@@ -415,24 +457,6 @@ class TestActivationOps(OpTestCase):
with dragon.device('cuda'), self.cudnn_ws.as_default():
self.test_softmax()
def test_swish(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], 'float32')
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.nn.swish(x)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
result = data * (1. / (1. + np.exp(-data)))
result2 = data * (result + (1. / (1. + np.exp(-data))) * (1. - result))
self.assertEqual([y, dx], [result, result2])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_swish_cuda(self):
with dragon.device('cuda'):
self.test_swish()
def test_tanh(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
......@@ -2260,6 +2284,26 @@ class TestMathOps(OpTestCase):
with dragon.device('cuda'):
self.test_reciprocal()
def test_roll(self):
entries = [(0, 0), ((0, 0), (0, 1)), ((-1, 1), (0, 1)), (1, None)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for shift, axis in entries:
data = arange((2, 3))
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.roll(x, shift, axis)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
self.assertEqual(
[y, dx], [np.roll(data, shift, axis),
np.roll(data, [-v for v in nest.flatten(shift)], axis)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_roll_cuda(self):
with dragon.device('cuda'):
self.test_roll()
def test_round(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
......@@ -2762,6 +2806,7 @@ class TestTrainingOps(OpTestCase):
def __init__(self, method_name='runTest'):
super(TestTrainingOps, self).__init__(method_name)
self.adam = dragon.optimizers.Adam()
self.adam_w = dragon.optimizers.AdamW()
self.nesterov = dragon.optimizers.Nesterov()
self.rmsprop = dragon.optimizers.RMSprop()
self.sgd = dragon.optimizers.SGD()
......@@ -2790,6 +2835,30 @@ class TestTrainingOps(OpTestCase):
with dragon.device('cuda'):
self.test_adam_update()
def test_adam_w_update(self):
with execution_context().mode('EAGER_MODE'):
lr, eps = self.adam_w.lr, self.adam_w.eps
beta1, beta2 = self.adam_w.beta1, self.adam_w.beta2
wd = self.adam_w.weight_decay
data1 = uniform((2, 3))
data2, data3 = np.zeros((2, 3), 'float32'), np.zeros((2, 3), 'float32')
param = new_tensor(data1)
for i in range(2):
t = i + 1
coef = math.sqrt(1 - math.pow(beta2, t)) / (1 - math.pow(beta1, t))
data4 = uniform((2, 3))
grad = new_tensor(data4)
self.adam_w.apply_gradients([[grad, param]])
data2 = beta1 * data2 + (1 - beta1) * data4
data3 = beta2 * data3 + (1 - beta2) * np.square(data4)
data1 -= lr * (coef * data2 / (np.sqrt(data3) + eps) + wd * data1)
self.assertEqual(param, data1)
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_adam_w_update_cuda(self):
with dragon.device('cuda'):
self.test_adam_w_update()
def test_nesterov_update(self):
with execution_context().mode('EAGER_MODE'):
momentum, lr = self.nesterov.momentum, self.nesterov.lr
......
......@@ -15,6 +15,7 @@ from __future__ import division
from __future__ import print_function
import collections
import math
import os
import unittest
......@@ -251,6 +252,18 @@ class TestModules(OpTestCase):
y, _ = m(x), repr(m)
self.assertEqual(y, result)
def test_channel_shuffle(self):
entries = [(1, 4)]
for axis, group in entries:
data = arange((2, 8))
g, k = group, data.shape[axis] // group
shape = data.shape[:axis] + (g, k) + data.shape[axis + 1:]
perm = list(range(0, axis)) + [axis + 1, axis] + list(range(axis + 2, len(shape)))
x = new_tensor(data)
m = torch.nn.ChannelShuffle(group)
y, _ = m(x), repr(m)
self.assertEqual(y, data.reshape(shape).transpose(perm).reshape(data.shape))
def test_conv1d(self):
entries = [((2, 2, 2), (3, 2, 1), (3,), 1, 1, 0, 1, 1),
((2, 2, 2), (3, 2, 3), (3,), 3, 1, 1, 1, 1)]
......@@ -467,6 +480,16 @@ class TestModules(OpTestCase):
new_shape += data.shape[end_dim + 1:]
self.assertEqual(y, data.reshape(new_shape))
def test_gelu(self):
data = np.array([-1., 0., 1.], 'float32')
cdf = data.copy()
for i in range(data.size):
cdf[i] = 0.5 * (1 + math.erf(data[i] * 0.7071067811865475))
x = new_tensor(data)
m = torch.nn.GELU()
y, _ = m(x), repr(m)
self.assertEqual(y, data * cdf)
def test_group_norm(self):
eps = 1e-5
entries = [((1, 4), (1, 4), 2, (2,)),
......@@ -783,6 +806,14 @@ class TestModules(OpTestCase):
result = reduce(pos_term + neg_term, reduction=reduction)
self.assertEqual(y, result)
def test_silu(self):
data = np.array([-3., -2., -1., 0., 1., 2., 3], 'float32')
x = new_tensor(data)
m = torch.nn.SiLU()
y, _ = m(x), repr(m)
result = data * (1. / (1. + np.exp(-data)))
self.assertEqual(y, result)
def test_softmax(self):
data = np.array([[0.2, 0.3, 0.5], [0.1, 0.7, 0.2]], 'float32')
x = new_tensor(np.log(data))
......@@ -790,14 +821,6 @@ class TestModules(OpTestCase):
y, _ = m(x), repr(m)
self.assertEqual(y, data)
def test_swish(self):
data = np.array([-3., -2., -1., 0., 1., 2., 3], 'float32')
x = new_tensor(data)
m = torch.nn.Swish()
y, _ = m(x), repr(m)
result = data * (1. / (1. + np.exp(-data)))
self.assertEqual(y, result)
def test_tanh(self):
data = np.array([0.2, 0.4, 0.6, 0.8, 1.], 'float32')
x = new_tensor(data)
......
......@@ -605,6 +605,14 @@ class TestTensorOps(OpTestCase):
self.assertEqual(x, data)
self.assertEqual(x.view_as(x), data)
def test_roll(self):
entries = [(0, 0), ((0, 0), (0, 1)), ((-1, 1), (0, 1)), (1, None)]
for shift, axis in entries:
data = arange((2, 3))
x = new_tensor(data)
y = x.roll(shift, axis)
self.assertEqual(y, np.roll(data, shift, axis))
def test_round(self):
data = np.array([0.9, 1.4, 1.9], 'float32')
x = new_tensor(data)
......@@ -889,17 +897,6 @@ class TestTorchOps(OpTestCase):
y = torch.channel_normalize(x, *args, **kwargs)
self.assertEqual(y, (data - mean) / std)
def test_channel_shuffle(self):
entries = [(0, 2), (1, 4)]
for axis, group in entries:
data = arange((2, 8))
g, k = group, data.shape[axis] // group
shape = data.shape[:axis] + (g, k) + data.shape[axis + 1:]
perm = list(range(0, axis)) + [axis + 1, axis] + list(range(axis + 2, len(shape)))
x = new_tensor(data)
y = torch.channel_shuffle(x, axis, group)
self.assertEqual(y, data.reshape(shape).transpose(perm).reshape(data.shape))
def test_linspace(self):
entries = [([[0., 5.], [10., 40.], 5], {'dim': 0, 'dtype': 'float32'}),
([[0., 5.], [10., 40.], 5], {'dim': 1, 'dtype': 'float32'}),
......
......@@ -64,6 +64,7 @@ class TestOptimizer(unittest.TestCase):
for lr, betas, eps, amsgrad in entries:
try:
_ = torch.optim.Adam([weight], lr=lr, betas=betas, eps=eps, amsgrad=amsgrad)
_ = torch.optim.AdamW([weight], lr=lr, betas=betas, eps=eps, amsgrad=amsgrad)
except (ValueError, NotImplementedError):
pass
......
......@@ -52,7 +52,6 @@ from dragon.vm.torch.core.ops.array_ops import broadcast_to
from dragon.vm.torch.core.ops.array_ops import cat
from dragon.vm.torch.core.ops.array_ops import channel_affine
from dragon.vm.torch.core.ops.array_ops import channel_normalize
from dragon.vm.torch.core.ops.array_ops import channel_shuffle
from dragon.vm.torch.core.ops.array_ops import chunk
from dragon.vm.torch.core.ops.array_ops import cumsum
from dragon.vm.torch.core.ops.array_ops import flatten
......@@ -69,6 +68,7 @@ from dragon.vm.torch.core.ops.array_ops import nonzero
from dragon.vm.torch.core.ops.array_ops import one_hot
from dragon.vm.torch.core.ops.array_ops import permute
from dragon.vm.torch.core.ops.array_ops import reshape
from dragon.vm.torch.core.ops.array_ops import roll
from dragon.vm.torch.core.ops.array_ops import scatter
from dragon.vm.torch.core.ops.array_ops import scatter_add
from dragon.vm.torch.core.ops.array_ops import sort
......
......@@ -20,6 +20,7 @@ from dragon.vm.torch._api.nn import init
# Classes
from dragon.vm.torch.core.nn.modules.activation import ELU
from dragon.vm.torch.core.nn.modules.activation import GELU
from dragon.vm.torch.core.nn.modules.activation import GumbelSoftmax
from dragon.vm.torch.core.nn.modules.activation import Hardsigmoid
from dragon.vm.torch.core.nn.modules.activation import Hardswish
......@@ -31,13 +32,14 @@ from dragon.vm.torch.core.nn.modules.activation import ReLU
from dragon.vm.torch.core.nn.modules.activation import ReLU6
from dragon.vm.torch.core.nn.modules.activation import SELU
from dragon.vm.torch.core.nn.modules.activation import Sigmoid
from dragon.vm.torch.core.nn.modules.activation import SiLU
from dragon.vm.torch.core.nn.modules.activation import Softmax
from dragon.vm.torch.core.nn.modules.activation import Swish
from dragon.vm.torch.core.nn.modules.activation import Tanh
from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm1d
from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm2d
from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm3d
from dragon.vm.torch.core.nn.modules.batchnorm import SyncBatchNorm
from dragon.vm.torch.core.nn.modules.channelshuffle import ChannelShuffle
from dragon.vm.torch.core.nn.modules.container import Container
from dragon.vm.torch.core.nn.modules.container import ModuleList
from dragon.vm.torch.core.nn.modules.container import Sequential
......
......@@ -25,6 +25,7 @@ from dragon.vm.torch.core.nn.functional import avg_pool2d
from dragon.vm.torch.core.nn.functional import avg_pool3d
from dragon.vm.torch.core.nn.functional import batch_norm
from dragon.vm.torch.core.nn.functional import binary_cross_entropy_with_logits
from dragon.vm.torch.core.nn.functional import channel_shuffle
from dragon.vm.torch.core.nn.functional import conv1d
from dragon.vm.torch.core.nn.functional import conv2d
from dragon.vm.torch.core.nn.functional import conv3d
......@@ -39,6 +40,7 @@ from dragon.vm.torch.core.nn.functional import drop_path
from dragon.vm.torch.core.nn.functional import dropout
from dragon.vm.torch.core.nn.functional import elu
from dragon.vm.torch.core.nn.functional import embedding
from dragon.vm.torch.core.nn.functional import gelu
from dragon.vm.torch.core.nn.functional import group_norm
from dragon.vm.torch.core.nn.functional import hardsigmoid
from dragon.vm.torch.core.nn.functional import hardswish
......@@ -64,9 +66,9 @@ from dragon.vm.torch.core.nn.functional import relu6
from dragon.vm.torch.core.nn.functional import selu
from dragon.vm.torch.core.nn.functional import sigmoid
from dragon.vm.torch.core.nn.functional import sigmoid_focal_loss
from dragon.vm.torch.core.nn.functional import silu
from dragon.vm.torch.core.nn.functional import smooth_l1_loss
from dragon.vm.torch.core.nn.functional import softmax
from dragon.vm.torch.core.nn.functional import swish
from dragon.vm.torch.core.nn.functional import sync_batch_norm
from dragon.vm.torch.core.nn.functional import tanh
from dragon.vm.torch.core.nn.functional import upsample
......
......@@ -15,6 +15,7 @@ from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.vm.torch.core.optim.adam import Adam
from dragon.vm.torch.core.optim.adam import AdamW
from dragon.vm.torch.core.optim.optimizer import Optimizer
from dragon.vm.torch.core.optim.rmsprop import RMSprop
from dragon.vm.torch.core.optim.sgd import SGD
......
......@@ -353,6 +353,31 @@ def binary_cross_entropy_with_logits(
[input, target], reduction=reduction.upper())
def channel_shuffle(input, groups):
"""Apply group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
groups : int
The number of shuffle groups.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.ChannelShuffle(...)`_
"""
return FunctionLib.apply(
'ChannelShuffle', input.device, [input], axis=1, group=groups)
def conv1d(
input,
weight,
......@@ -879,8 +904,34 @@ def embedding(input, weight, padding_idx=None):
return weight.index_select(0, input)
def gelu(input):
r"""Apply the gaussian error linear unit to input.
`[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
The **GELU** function is defined as:
.. math:: \text{GELU}(x) = x\cdot\frac{1}{2}[1 + \text{erf}(x / \sqrt{2})]
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.GELU(...)`_
"""
return FunctionLib.apply('Gelu', input.device, [input], approximate=False)
def group_norm(input, num_groups, weight, bias, eps=1e-5):
r"""Apply the group normalization to input.
"""Apply the group normalization to input.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
Parameters
......@@ -1920,6 +1971,32 @@ def sigmoid_focal_loss(
start_index=start_index, reduction=reduction.upper())
def silu(input):
r"""Apply the sigmoid linear unit to input.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
The **SiLU** function is defined as:
.. math:: \text{SiLU}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.SiLU(...)`_
"""
return FunctionLib.apply('Swish', input.device, [input])
def smooth_l1_loss(
input,
target,
......@@ -2005,32 +2082,6 @@ def softmax(input, dim, inplace=False):
outputs=[input if inplace else None], axis=dim)
def swish(input):
r"""Apply the swish function to input.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
The **Swish** function is defined as:
.. math:: \text{Swish}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.Swish(...)`_
"""
return FunctionLib.apply('Swish', input.device, [input])
def sync_batch_norm(
input,
running_mean,
......
......@@ -73,6 +73,36 @@ class ELU(Module):
return F.elu(input, self.alpha, self.inplace)
class GELU(Module):
r"""Apply the gaussian error linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
The **GELU** function is defined as:
.. math:: \text{GELU}(x) = x\cdot\frac{1}{2}[1 + \text{erf}(x / \sqrt{2})]
Examples:
```python
m = torch.nn.GELU()
x = torch.randn(2, 3)
y = m(x)
```
See Also
--------
`torch.nn.functional.gelu(...)`_
"""
def __init__(self):
"""Create a ``GELU`` module."""
super(GELU, self).__init__()
def forward(self, input):
return F.gelu(input)
class GumbelSoftmax(Module):
r"""Apply the gumbel softmax function.
`[Jang et.al, 2016] <https://arxiv.org/abs/1611.01144>`_.
......@@ -637,6 +667,36 @@ class Sigmoid(Module):
return F.sigmoid(input, self.inplace)
class SiLU(Module):
r"""Apply the sigmoid linear unit.
`[Hendrycks & Gimpel, 2016] <https://arxiv.org/abs/1606.08415>`_.
The **SiLU** function is defined as:
.. math:: \text{SiLU}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Examples:
```python
m = torch.nn.So()
x = torch.randn(2, 3)
y = m(x)
```
See Also
--------
`torch.nn.functional.silu(...)`_
"""
def __init__(self):
"""Create a ``SiLU`` module."""
super(SiLU, self).__init__()
def forward(self, input):
return F.silu(input)
class Softmax(Module):
r"""Apply the softmax function.
......@@ -681,36 +741,6 @@ class Softmax(Module):
return F.softmax(input, self.dim, self.inplace)
class Swish(Module):
r"""Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
The **Swish** function is defined as:
.. math:: \text{Swish}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Examples:
```python
m = torch.nn.Swish()
x = torch.randn(2, 3)
y = m(x)
```
See Also
--------
`torch.nn.functional.swish(...)`_
"""
def __init__(self):
"""Create a ``Swish`` module."""
super(Swish, self).__init__()
def forward(self, input):
return F.swish(input)
class Tanh(Module):
r"""Apply the tanh function.
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Shuffle modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch.core.nn import functional as F
from dragon.vm.torch.core.nn.modules.module import Module
class ChannelShuffle(Module):
"""Apply group shuffle to each channel.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
Examples:
```python
m = torch.nn.ChannelShuffle(2)
x = torch.tensor([1, 2, 3, 4])
print(m(x)) # [1, 3, 2, 4]
```
See Also
--------
`torch.nn.functional.channel_shuffle(...)`_
"""
def __init__(self, groups):
"""Create a ``ChannelShuffle`` module.
Parameters
----------
groups : int
The number of shuffle groups.
"""
super(ChannelShuffle, self).__init__()
self.groups = groups
def extra_repr(self):
return 'groups={}'.format(self.groups)
def forward(self, input):
return F.channel_shuffle(input, self.groups)
......@@ -308,4 +308,6 @@ def _get_activation_fn(activation):
"""Return the activation function."""
if activation == 'relu':
return F.relu
elif activation == 'gelu':
return F.gelu
raise RuntimeError('Unknown activation: {}'.format(activation))
......@@ -250,39 +250,6 @@ def channel_normalize(input, mean, std, dim=-1, dtype='float32', dims=None):
ndim=len(dims) if dims is not None else 0, perm=dims)
def channel_shuffle(input, dim=0, groups=1, out=None):
"""Apply group shuffle to each channel of input.
`[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
Examples:
```python
x = torch.tensor([1, 2, 3, 4])
print(torch.channel_shuffle(x, groups=2)) # [1, 3, 2, 4]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
dim : int, optional, default=0
The channel dimension.
groups : int, optional, default=1
The number of shuffle groups.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return FunctionLib.apply(
'ChannelShuffle', input.device, [input], outputs=[out],
axis=dim, group=groups)
def chunk(tensor, chunks, dim=0):
"""Split input into a specific number of chunks.
......@@ -898,6 +865,48 @@ def reshape(input, shape, out=None):
ndim=len(shape), dims=shape)
def roll(input, shifts, dims=None):
"""Roll elements along the given dimension.
:attr:`dims` could be negative or ``None``:
```python
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# A negative dimension is the last-k dimension
print(torch.roll(x, shifts=1, dims=1)) # [[3, 1, 2], [6, 4, 5]]
print(torch.roll(x, shifts=1, dims=-1)) # Equivalent
# If dimension is None, roll input as a vector
print(torch.roll(x, shifts=1)) # [[6, 1, 2], [3, 4, 5]]
# Also, dimension could be a sequence of integers
print(torch.roll(x, shifts=(1, 1), dims=(0, 1))) # [[6, 4, 5], [3, 1, 2]]
print(torch.roll(x, shifts=(1, -1), dims=(0, 1))) # [[5, 6, 4], [2, 3, 1]]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
shifts : Union[int, Sequence[int]]
The rolling offset of each dimension.
dims : Union[int, Sequence[int]], optional
The dimension to roll.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
shifts = nest.flatten(shifts)
dims = nest.flatten(dims) if dims is not None else dims
return FunctionLib.apply(
'Roll', input.device, [input],
num_shifts=len(shifts), shifts=shifts, axes=dims)
def scatter(input, dim, index, src, out=None):
"""Update elements along the given dimension of index.
......
......@@ -1997,6 +1997,29 @@ def reshape_(self, shape):
return array_ops.reshape(self, shape, self)
def roll(self, shifts, dims=None):
"""Return a tensor of rolled elements.
Parameters
----------
shifts : Union[int, Sequence[int]]
The rolling offset of each dimension.
dims : Union[int, Sequence[int]], optional
The dimension to roll.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.roll(...)`_
"""
return array_ops.roll(self, shifts, dims)
def round(self):
r"""Return a tensor taken the round of elements.
......@@ -2897,6 +2920,7 @@ Tensor.reciprocal_ = reciprocal_
Tensor.repeat = repeat
Tensor.reshape = reshape
Tensor.reshape_ = reshape_
Tensor.roll = roll
Tensor.round = round
Tensor.round_ = round_
Tensor.rsqrt = rsqrt
......
......@@ -95,3 +95,58 @@ class Adam(Optimizer):
'scale': ('scale', collections.defaultdict(str)),
'clip_norm': ('clip_norm', collections.defaultdict(str)),
}
class AdamW(Adam):
r"""The optimizer to apply AdamW algorithm.
`[Loshchilov & Hutter, 2017] <https://arxiv.org/abs/1711.05101>`_.
The **AdamW** update is defined as:
.. math::
\text{AdamW}(g, p) = -\text{lr} * (\frac{m_{t}}{\sqrt{v_{t}} + \epsilon}
+ \lambda p) \\
\quad \\ \text{where}\quad
\begin{cases}
m_{t} = \beta_{1} * m_{t-1} + (1 - \beta_{1}) * g \\
v_{t} = \beta_{2} * v_{t-1} + (1 - \beta_{2}) * g^{2}
\end{cases}
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.01,
amsgrad=False,
scale=1,
clip_norm=0,
):
r"""Create an ``AdamW`` optimizer.
Parameters
----------
params : Sequence[dragon.vm.torch.nn.Parameter]
The parameters to optimize.
lr : float, required
The initial value to :math:`\text{lr}`.
betas : Tuple[float, float], optional, default=(0.9, 0.999)
The initial value to :math:`\beta_{1}` and :math:`\beta_{2}`.
eps : float, optional, default=1e-8
The initial value to :math:`\epsilon`.
weight_decay : float, optional, default=0.01
The initial value to :math:`\lambda`.
amsgrad : bool, optional, default=False
``True`` to switch to **AMSGrad** optimizer.
scale : float, optional, default=1
The scaling factor to gradient.
clip_norm : float, optional, default=0
The maximum L2 norm to clip gradient.
"""
super(AdamW, self).__init__(params, lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad,
scale=scale, clip_norm=clip_norm)
......@@ -2150,6 +2150,27 @@ class Tensor(object):
raise RuntimeError('Retain grad for a tensor that does not require.')
self._retains_grad = True
def roll(self, shifts, dims=None):
"""Return a tensor of rolled elements.
Parameters
----------
shifts : Union[int, Sequence[int]]
The rolling offset of each dimension.
dims : Union[int, Sequence[int]], optional
The dimension to roll.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.roll(...)`_
"""
def round(self):
r"""Return a tensor taken the round of elements.
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!