Commit 9ca4b60f by Ting PAN

Add HardSigmoid && HardSwish && Swish Operators

Summary:
This commit adds the hardsigmoid, hardswish and swish ops with specialized kernels;
they are widely used in MobileNetV3 and EfficientNet.
1 parent f76c693e
Showing with 3110 additions and 409 deletions
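For reference, the three activations added by this commit compute the following (as implemented by the kernels further below, where alpha and beta are the slope and offset arguments of the hard variants; MobileNetV3 uses alpha = 1/6, beta = 1/2):

\mathrm{hardsigmoid}(x) = \max\bigl(0,\ \min(1,\ \alpha x + \beta)\bigr)
\mathrm{hardswish}(x) = x \cdot \max\bigl(0,\ \min(1,\ \alpha x + \beta)\bigr)
\mathrm{swish}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}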
......@@ -18,7 +18,7 @@ from dragon.core.util import tls
def device(device_type, device_index=0):
"""Context-manager to nest the the device spec.
"""Context-manager to nest the device spec.
Examples:
......
......@@ -16,7 +16,7 @@ vm.dali
#########
`device(...) <dali/device.html>`_
: Context-manager to nest the the device spec.
: Context-manager to nest the device spec.
`get_device_type(...) <dali/get_device_type.html>`_
: Return the current nesting device type.
......
......@@ -55,7 +55,7 @@ dragon
: Create a callable graph from the specified outputs.
`device(...) <dragon/device.html>`_
: Context-manager to nest the the device spec.
: Context-manager to nest the device spec.
`eager_mode(...) <dragon/eager_mode.html>`_
: Context-manager set the eager execution mode.
......
......@@ -63,6 +63,13 @@ dragon.nn
: Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
`hardsigmoid(...) <nn/hardsigmoid.html>`_
: Apply the hard sigmoid function.
`hardswish(...) <nn/hardswish.html>`_
: Apply the hard swish function.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
`instance_norm(...) <nn/instance_norm.html>`_
: Apply the instance normalization.
`[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_
......@@ -79,7 +86,7 @@ dragon.nn
`[Krizhevsky et.al, 2012] <http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf>`_.
`log_softmax(...) <nn/log_softmax.html>`_
: Apply the composite of logarithm and softmax.
: Compute the composite of logarithm and softmax.
`prelu(...) <nn/prelu.html>`_
: Apply the parametric rectified linear unit.
......@@ -101,11 +108,15 @@ dragon.nn
`[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`softmax(...) <nn/softmax.html>`_
: Apply the softmax function.
: Compute the softmax result.
`space_to_depth(...) <nn/space_to_depth.html>`_
: Rearrange blocks of spatial data into depth.
`swish(...) <nn/swish.html>`_
: Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`sync_batch_norm(...) <nn/sync_batch_norm.html>`_
: Apply the batch normalization with synced statistics.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
......@@ -128,6 +139,8 @@ dragon.nn
nn/elu
nn/fully_connected
nn/group_norm
nn/hardsigmoid
nn/hardswish
nn/instance_norm
nn/layer_norm
nn/leaky_relu
......@@ -140,6 +153,7 @@ dragon.nn
nn/selu
nn/softmax
nn/space_to_depth
nn/swish
nn/sync_batch_norm
.. raw:: html
......
hardsigmoid
===========
.. autofunction:: dragon.nn.hardsigmoid
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
hardswish
=========
.. autofunction:: dragon.nn.hardswish
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
swish
=====
.. autofunction:: dragon.nn.swish
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
color: #103d3e;
}
</style>
......@@ -63,7 +63,7 @@ Name Supported Reference
`GlobalLpPool`_
`GlobalMaxPool`_ |v| :func:`dragon.nn.pool2d`
`Greater`_ |v| :func:`dragon.math.greater`
`HardSigmoid`_
`HardSigmoid`_ |v| :func:`dragon.nn.hardsigmoid`
`Hardmax`_
`Identity`_
`If`_
......
......@@ -37,7 +37,7 @@ vm.tensorflow
: Return a tensor initialized from the value.
`device(...) <tensorflow/device.html>`_
: Context-manager to nest the the device spec.
: Context-manager to nest the device spec.
`expand_dims(...) <tensorflow/expand_dims.html>`_
: Expand the dimensions of input with size 1.
......
......@@ -16,6 +16,9 @@ activations
`get(...) <activations/get.html>`_
: Return the activation callable by identifier.
`hard_sigmoid(...) <activations/hard_sigmoid.html>`_
: Apply the hard sigmoid function to input.
`linear(...) <activations/linear.html>`_
: Apply the linear activation to input.
......@@ -33,6 +36,10 @@ activations
`softmax(...) <activations/softmax.html>`_
: Apply the softmax function to input.
`swish(...) <activations/swish.html>`_
: Apply the swish function to input.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`tanh(...) <activations/tanh.html>`_
: Apply the tanh function to input.
......@@ -42,11 +49,13 @@ activations
activations/elu
activations/exponential
activations/get
activations/hard_sigmoid
activations/linear
activations/relu
activations/selu
activations/sigmoid
activations/softmax
activations/swish
activations/tanh
.. raw:: html
......
hard_sigmoid
============
.. autofunction:: dragon.vm.tensorflow.keras.activations.hard_sigmoid
.. raw:: html
<style>
h1:before {
content: "tf.keras.activations.";
color: #103d3e;
}
</style>
swish
=====
.. autofunction:: dragon.vm.tensorflow.keras.activations.swish
.. raw:: html
<style>
h1:before {
content: "tf.keras.activations.";
color: #103d3e;
}
</style>
......@@ -86,6 +86,10 @@ vm.tensorflow.nn
`sparse_softmax_cross_entropy_with_logits(...) <nn/sparse_softmax_cross_entropy_with_logits.html>`_
: Compute the softmax cross entropy with sparse labels.
`swish(...) <nn/swish.html>`_
: Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
.. toctree::
:hidden:
......@@ -113,6 +117,7 @@ vm.tensorflow.nn
nn/softmax_cross_entropy_with_logits
nn/space_to_depth
nn/sparse_softmax_cross_entropy_with_logits
nn/swish
.. raw:: html
......
swish
=====
.. autofunction:: dragon.vm.tensorflow.nn.swish
.. raw:: html
<style>
h1:before {
content: "tf.nn.";
color: #103d3e;
}
</style>
......@@ -109,6 +109,12 @@ vm.torch
`from_numpy(...) <torch/from_numpy.html>`_
: Create a tensor from the given numpy array.
`full(...) <torch/full.html>`_
: Return a tensor filled with a scalar.
`full_like(...) <torch/full_like.html>`_
: Return a tensor filled with a scalar with size as input.
`ge(...) <torch/ge.html>`_
: Compute the element-wise greater-equal comparison.
......@@ -172,6 +178,9 @@ vm.torch
`ne(...) <torch/ne.html>`_
: Compute the element-wise not-equal comparison.
`neg(...) <torch/neg.html>`_
: Compute the element-wise negative.
`nonzero(...) <torch/nonzero.html>`_
: Return the index of non-zero elements.
......@@ -299,6 +308,8 @@ vm.torch
torch/flatten
torch/floor
torch/from_numpy
torch/full
torch/full_like
torch/ge
torch/gt
torch/index_select
......@@ -320,6 +331,7 @@ vm.torch
torch/multinomial
torch/narrow
torch/ne
torch/neg
torch/no_grad
torch/nonzero
torch/ones
......
......@@ -314,9 +314,33 @@ ndimension
.. automethod:: dragon.vm.torch.Tensor.ndimension
ne
###
##
.. automethod:: dragon.vm.torch.Tensor.ne
neg
###
.. automethod:: dragon.vm.torch.Tensor.neg
neg\_
#####
.. automethod:: dragon.vm.torch.Tensor.neg_
new_ones
########
.. automethod:: dragon.vm.torch.Tensor.new_ones
new_empty
#########
.. automethod:: dragon.vm.torch.Tensor.new_empty
new_full
########
.. automethod:: dragon.vm.torch.Tensor.new_full
new_zeros
#########
.. automethod:: dragon.vm.torch.Tensor.new_zeros
nonzero
#######
.. automethod:: dragon.vm.torch.Tensor.nonzero
......@@ -497,10 +521,12 @@ zero\_
.. _torch.cos(...): cos.html
.. _torch.cumsum(...): cumsum.html
.. _torch.div(...): div.html
.. _torch.empty(...): empty.html
.. _torch.eq(...): eq.html
.. _torch.exp(...): exp.html
.. _torch.flatten(...): flatten.html
.. _torch.floor(...): floor.html
.. _torch.full(...): full.html
.. _torch.ge(...): ge.html
.. _torch.gt(...): gt.html
.. _torch.le(...): le.html
......@@ -509,6 +535,7 @@ zero\_
.. _torch.ne(...): ne.html
.. _torch.neg(...): neg.html
.. _torch.nonzero(...): nonzero.html
.. _torch.ones(...): ones.html
.. _torch.pow(...): pow.html
.. _torch.reciprocal(...): reciprocal.html
.. _torch.reshape(...): reshape.html
......@@ -526,6 +553,7 @@ zero\_
.. _torch.unique(...): unique.html
.. _torch.unsqueeze(...): unsqueeze.html
.. _torch.where(...): where.html
.. _torch.zeros(...): zeros.html
.. raw:: html
......
full
====
.. autofunction:: dragon.vm.torch.full
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
full_like
=========
.. autofunction:: dragon.vm.torch.full_like
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
neg
===
.. autofunction:: dragon.vm.torch.neg
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
......@@ -84,6 +84,13 @@ vm.torch.nn
: Apply the gumbel softmax with a temperature.
`[Jang et.al, 2016] <https://arxiv.org/abs/1611.01144>`_.
`class Hardsigmoid <nn/Hardsigmoid.html>`_
: Apply the hard sigmoid function.
`class Hardswish <nn/Hardswish.html>`_
: Apply the hard swish function.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
`class KLDivLoss <nn/KLDivLoss.html>`_
: Compute the Kullback-Leibler divergence.
......@@ -178,6 +185,10 @@ vm.torch.nn
`class Softmax <nn/Softmax.html>`_
: Apply the softmax function.
`class Swish <nn/Swish.html>`_
: Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`class Tanh <nn/Tanh.html>`_
: Apply the tanh function.
......@@ -222,6 +233,8 @@ vm.torch.nn
nn/GroupNorm
nn/GRU
nn/GumbelSoftmax
nn/Hardsigmoid
nn/Hardswish
nn/KLDivLoss
nn/L1Loss
nn/LeakyReLU
......@@ -250,6 +263,7 @@ vm.torch.nn
nn/SigmoidFocalLoss
nn/SmoothL1Loss
nn/Softmax
nn/Swish
nn/Tanh
nn/SyncBatchNorm
nn/Upsample
......
Hardsigmoid
===========
.. autoclass:: dragon.vm.torch.nn.Hardsigmoid
__init__
--------
.. automethod:: dragon.vm.torch.nn.Hardsigmoid.__init__
.. _torch.nn.functional.hardsigmoid(...): functional/hardsigmoid.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
Hardswish
=========
.. autoclass:: dragon.vm.torch.nn.Hardswish
__init__
--------
.. automethod:: dragon.vm.torch.nn.Hardswish.__init__
.. _torch.nn.functional.hardswish(...): functional/hardswish.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
Swish
=====
.. autoclass:: dragon.vm.torch.nn.Swish
__init__
--------
.. automethod:: dragon.vm.torch.nn.Swish.__init__
.. _torch.nn.functional.swish(...): functional/swish.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
......@@ -53,6 +53,13 @@ vm.torch.nn.functional
: Apply the group normalization to input.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
`hardsigmoid(...) <functional/hardsigmoid.html>`_
: Apply the hard sigmoid function to input.
`hardswish(...) <functional/hardswish.html>`_
: Apply the hard swish function to input.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
`kl_div(...) <functional/kl_div.html>`_
: Compute the Kullback-Leibler divergence.
......@@ -113,6 +120,10 @@ vm.torch.nn.functional
`softmax(...) <functional/softmax.html>`_
: Apply the softmax function to input.
`swish(...) <functional/swish.html>`_
: Apply the swish function to input.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
`sync_batch_norm(...) <functional/sync_batch_norm.html>`_
: Apply the sync batch normalization to input.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
......@@ -145,6 +156,8 @@ vm.torch.nn.functional
functional/dropout
functional/elu
functional/group_norm
functional/hardsigmoid
functional/hardswish
functional/kl_div
functional/l1_loss
functional/leaky_relu
......@@ -165,6 +178,7 @@ vm.torch.nn.functional
functional/sigmoid_focal_loss
functional/smooth_l1_loss
functional/softmax
functional/swish
functional/sync_batch_norm
functional/tanh
functional/upsample
......
hardsigmoid
===========
.. autofunction:: dragon.vm.torch.nn.functional.hardsigmoid
.. _torch.nn.Hardsigmoid(...): ../Hardsigmoid.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
hardswish
=========
.. autofunction:: dragon.vm.torch.nn.functional.hardswish
.. _torch.nn.Hardswish(...): ../Hardswish.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
swish
=====
.. autofunction:: dragon.vm.torch.nn.functional.swish
.. _torch.nn.Swish(...): ../Swish.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
......@@ -25,9 +25,12 @@ template <>
__global__ void
_Elu<half>(const int nthreads, const float alpha, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
#if __CUDA_ARCH__ >= 350
const float val = __half2float(__ldg(x + i));
y[i] = val > 0.f ? __ldg(x + i) : __float2half(alpha * (exp(val) - 1.f));
#else
const float val = __half2float(x[i]);
y[i] = val > 0.f ? x[i] : __float2half(alpha * (exp(val) - 1.f));
#endif
}
}
......@@ -36,12 +39,10 @@ template <>
__global__ void
_Elu<half2>(const int nthreads, const float alpha, const half2* x, half2* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float2 val = __half22float2(x[i]);
y[i] = __floats2half2_rn(
val.x > 0.f ? val.x : alpha * (exp(val.x) - 1.f),
val.y > 0.f ? val.y : alpha * (exp(val.y) - 1.f));
#endif
}
}
......@@ -69,10 +70,9 @@ __global__ void _EluGrad<half>(
const half* y,
half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float val = __half2float(y[i]);
dx[i] = __hmul(dy[i], __float2half(val > 0.f ? 1.f : (alpha + val)));
#endif
dx[i] =
__float2half(__half2float(dy[i]) * (val > 0.f ? 1.f : (alpha + val)));
}
} // EluGrad
......@@ -84,14 +84,11 @@ __global__ void _EluGrad<half2>(
const half2* y,
half2* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float2 val = __half22float2(y[i]);
dx[i] = __hmul2(
dy[i],
__floats2half2_rn(
val.x > 0.f ? 1.f : (alpha + val.x),
val.y > 0.f ? 1.f : (alpha + val.y)));
#endif
const float2 grad = __half22float2(dy[i]);
dx[i] = __floats2half2_rn(
grad.x * (val.x > 0.f ? 1.f : (alpha + val.x)),
grad.y * (val.y > 0.f ? 1.f : (alpha + val.y)));
}
} // EluGrad
......
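A note on the guard change above: __ldg() (the read-only cached load) only requires compute capability 3.5, whereas half-precision arithmetic intrinsics such as __hmul/__hmul2 require 5.3. Lowering the guard to 350 and doing the arithmetic in float therefore keeps the fp16 ELU kernels functional on pre-sm_53 GPUs. A minimal standalone sketch of the pattern (hypothetical kernel, not part of this commit):

#include <cuda_fp16.h>

// Hypothetical example: scale an fp16 buffer, doing the math in float so that
// no sm_53 half intrinsics are needed, and using __ldg only where available.
__global__ void ScaleHalf(const int n, const float k, const half* x, half* y) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
#if __CUDA_ARCH__ >= 350
    const float val = __half2float(__ldg(x + i)); // cached read-only load (sm_35+)
#else
    const float val = __half2float(x[i]); // plain load on older architectures
#endif
    y[i] = __float2half(k * val); // float math, convert back at the end
  }
}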
#include "dragon/utils/cast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _HardSigmoid(
const int count,
const T alpha,
const T beta,
const T* x,
T* y) {
EigenVectorArrayMap<T>(y, count) =
(ConstEigenVectorArrayMap<T>(x, count) * alpha + beta)
.cwiseMin(T(1))
.cwiseMax(T(0));
}
template <>
void _HardSigmoid<float16>(
const int count,
const float16 alpha,
const float16 beta,
const float16* x,
float16* y) {
CPU_FP16_NOT_SUPPORTED;
}
template <typename T>
void _HardSigmoidGrad(
const int count,
const T alpha,
const T* dy,
const T* y,
T* dx) {
ConstEigenVectorArrayMap<T> Y(y, count);
EigenVectorArrayMap<T>(dx, count) =
(Y > T(0) && Y < T(1))
.select(ConstEigenVectorArrayMap<T>(dy, count) * alpha, T(0));
}
template <>
void _HardSigmoidGrad<float16>(
const int count,
const float16 alpha,
const float16* dy,
const float16* y,
float16* dx) {
CPU_FP16_NOT_SUPPORTED;
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void HardSigmoid<T, CPUContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_HardSigmoid(count, cast::to<T>(alpha), cast::to<T>(beta), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void HardSigmoidGrad<T, CPUContext>( \
const int count, \
const float alpha, \
const T* dy, \
const T* y, \
T* dx, \
CPUContext* ctx) { \
_HardSigmoidGrad(count, cast::to<T>(alpha), dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
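The backward kernel above only needs the cached output y: with y = clip(alpha * x + beta, 0, 1), the derivative is alpha wherever the clip is inactive (0 < y < 1) and zero otherwise, which is exactly the (Y > 0 && Y < 1).select(dy * alpha, 0) expression:

\frac{\partial\,\mathrm{hardsigmoid}(x)}{\partial x} =
\begin{cases}
\alpha, & 0 < \alpha x + \beta < 1 \\
0, & \text{otherwise.}
\end{cases}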
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void _HardSigmoid(
const int nthreads,
const T alpha,
const T beta,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = max(T(0), min(T(1), fma(x[i], alpha, beta)));
}
}
__global__ void _HardSigmoid(
const int nthreads,
const float alpha,
const float beta,
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] =
__float2half(max(0.f, min(1.f, fma(__half2float(x[i]), alpha, beta))));
}
}
template <typename T>
__global__ void _HardSigmoidGrad(
const int nthreads,
const float alpha,
const T* dy,
const T* y,
T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = (y[i] > T(0) && y[i] < T(1)) ? dy[i] * alpha : T(0);
}
}
template <>
__global__ void _HardSigmoidGrad<half>(
const int nthreads,
const float alpha,
const half* dy,
const half* y,
half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float val = __half2float(y[i]);
dx[i] = __float2half(
(val > 0.f && val < 1.f) ? __half2float(dy[i]) * alpha : 0.f);
}
} // HardSigmoidGrad
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void HardSigmoid<float16, CUDAContext>(
const int count,
const float alpha,
const float beta,
const float16* x,
float16* y,
CUDAContext* ctx) {
_HardSigmoid<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
alpha,
beta,
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
template <>
void HardSigmoidGrad<float16, CUDAContext>(
const int count,
const float alpha,
const float16* dy,
const float16* y,
float16* dx,
CUDAContext* ctx) {
_HardSigmoidGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
alpha,
reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(y),
reinterpret_cast<half*>(dx));
} // HardSigmoidGrad
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void HardSigmoid<T, CUDAContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
_HardSigmoid<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, T(alpha), T(beta), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void HardSigmoidGrad<T, CUDAContext>( \
const int count, \
const float alpha, \
const T* dy, \
const T* y, \
T* dx, \
CUDAContext* ctx) { \
_HardSigmoidGrad<<< \
CUDA_BLOCKS(count), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(count, alpha, dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
#include "dragon/utils/cast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _HardSwish(
const int count,
const T alpha,
const T beta,
const T* x,
T* y) {
ConstEigenVectorArrayMap<T> X(x, count);
EigenVectorArrayMap<T>(y, count) =
X * ((X * alpha + beta).cwiseMin(T(1)).cwiseMax(T(0)));
}
template <>
void _HardSwish<float16>(
const int count,
const float16 alpha,
const float16 beta,
const float16* x,
float16* y) {
CPU_FP16_NOT_SUPPORTED;
}
template <typename T>
void _HardSwishGrad(
const int count,
const T alpha,
const T beta,
const T* dy,
const T* x,
T* dx) {
const auto bound = beta / alpha;
const auto alpha2x = alpha * T(2);
EigenVectorArrayMap<T>(dx, count) = ConstEigenVectorArrayMap<T>(dy, count) *
ConstEigenVectorArrayMap<T>(x, count).unaryExpr([&](T a) {
return (a < -bound) ? T(0) : (a < bound ? a * alpha2x + beta : T(1));
});
}
template <>
void _HardSwishGrad<float16>(
const int count,
const float16 alpha,
const float16 beta,
const float16* dy,
const float16* x,
float16* dx) {
CPU_FP16_NOT_SUPPORTED;
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void HardSwish<T, CPUContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_HardSwish(count, cast::to<T>(alpha), cast::to<T>(beta), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void HardSwishGrad<T, CPUContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* dy, \
const T* x, \
T* dx, \
CPUContext* ctx) { \
_HardSwishGrad(count, cast::to<T>(alpha), cast::to<T>(beta), dy, x, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
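For the gradient above: on the linear part of the clip, hardswish(x) = x(alpha * x + beta), so the derivative is 2 * alpha * x + beta (the alpha2x term); it is 0 where the clip saturates at 0 and 1 where it saturates at 1. The kernel folds both clip thresholds into a single bound = beta / alpha, which coincides with the upper threshold (1 - beta) / alpha for the usual default beta = 1/2:

\frac{\partial\,\mathrm{hardswish}(x)}{\partial x} =
\begin{cases}
0, & \alpha x + \beta \le 0 \\
2\alpha x + \beta, & 0 < \alpha x + \beta < 1 \\
1, & \alpha x + \beta \ge 1.
\end{cases}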
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void
_HardSwish(const int nthreads, const T alpha, const T beta, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = __ldg(x + i) * max(T(0), min(T(1), fma(__ldg(x + i), alpha, beta)));
#else
y[i] = x[i] * max(T(0), min(T(1), fma(x[i], alpha, beta)));
#endif
}
}
__global__ void _HardSwish(
const int nthreads,
const float alpha,
const float beta,
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = __float2half(
__half2float(__ldg(x + i)) *
max(0.f, min(1.f, fma(__half2float(__ldg(x + i)), alpha, beta))));
#else
y[i] = __float2half(
__half2float(x[i]) *
max(0.f, min(1.f, fma(__half2float(x[i]), alpha, beta))));
#endif
}
}
template <typename T>
__global__ void _HardSwishGrad(
const int nthreads,
const T alpha,
const T beta,
const T* dy,
const T* x,
T* dx) {
const T bound = beta / alpha;
const T alpha2x = alpha * T(2);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
dx[i] = (__ldg(x + i) < -bound)
? T(0)
: (__ldg(x + i) < bound) ? dy[i] * fma(__ldg(x + i), alpha2x, beta)
: dy[i];
#else
dx[i] = (x[i] < -bound)
? T(0)
: (x[i] < bound) ? dy[i] * fma(x[i], alpha2x, beta) : dy[i];
#endif
}
}
__global__ void _HardSwishGrad(
const int nthreads,
const float alpha,
const float beta,
const half* dy,
const half* x,
half* dx) {
const float bound = beta / alpha;
const float alpha2x = alpha * 2.f;
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float val = __half2float(x[i]);
dx[i] = (val < -bound) ? kZero
: (val < bound)
? __float2half(__half2float(dy[i]) * fma(val, alpha2x, beta))
: dy[i];
}
} // HardSwishGrad
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void HardSwish<float16, CUDAContext>(
const int count,
const float alpha,
const float beta,
const float16* x,
float16* y,
CUDAContext* ctx) {
_HardSwish<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
alpha,
beta,
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
template <>
void HardSwishGrad<float16, CUDAContext>(
const int count,
const float alpha,
const float beta,
const float16* dy,
const float16* x,
float16* dx,
CUDAContext* ctx) {
_HardSwishGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
alpha,
beta,
reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(dx));
} // HardSwishGrad
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void HardSwish<T, CUDAContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
_HardSwish<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, T(alpha), T(beta), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void HardSwishGrad<T, CUDAContext>( \
const int count, \
const float alpha, \
const float beta, \
const T* dy, \
const T* x, \
T* dx, \
CUDAContext* ctx) { \
_HardSwishGrad<<< \
CUDA_BLOCKS(count), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(count, T(alpha), T(beta), dy, x, dx); \
}
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
......@@ -10,9 +10,9 @@ namespace {
template <typename T>
void _Relu(const int count, const T alpha, const T* x, T* y) {
ConstEigenVectorArrayMap<T> X(x, count);
EigenVectorArrayMap<T>(y, count) =
ConstEigenVectorArrayMap<T>(x, count).unaryExpr(
[&](T a) { return a > T(0) ? a : alpha * a; });
X.cwiseMax(T(0)) + X.cwiseMin(T(0)) * alpha;
}
template <>
......
......@@ -10,8 +10,8 @@ namespace {
template <typename T>
void _Softmax(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* x,
T* y) {
int row_offset, col_offset, yi;
......@@ -45,8 +45,8 @@ void _Softmax(
template <>
void _Softmax<float16>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float16* x,
float16* y) {
CPU_FP16_NOT_SUPPORTED;
......@@ -55,8 +55,8 @@ void _Softmax<float16>(
template <typename T>
void _SoftmaxGrad(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* dy,
const T* y,
T* dx) {
......@@ -82,8 +82,8 @@ void _SoftmaxGrad(
template <>
void _SoftmaxGrad<float16>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float16* dy,
const float16* y,
float16* dx) {
......@@ -98,25 +98,25 @@ void _SoftmaxGrad<float16>(
template <> \
void Softmax<T, CPUContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_Softmax(outer_dim, axis_dim, inner_dim, x, y); \
_Softmax(outer_dim, inner_dim, axis_dim, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void SoftmaxGrad<T, CPUContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const T* dy, \
const T* y, \
T* dx, \
CPUContext* ctx) { \
_SoftmaxGrad(outer_dim, axis_dim, inner_dim, dy, y, dx); \
_SoftmaxGrad(outer_dim, inner_dim, axis_dim, dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
......
......@@ -185,8 +185,8 @@ __global__ void _SoftmaxGrad<half>(
template <>
void Softmax<float16, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float16* x,
float16* y,
CUDAContext* ctx) {
......@@ -203,8 +203,8 @@ void Softmax<float16, CUDAContext>(
template <>
void SoftmaxGrad<float16, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float16* dy,
const float16* y,
float16* dx,
......@@ -223,8 +223,8 @@ void SoftmaxGrad<float16, CUDAContext>(
template <> \
void Softmax<T, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
......@@ -237,8 +237,8 @@ void SoftmaxGrad<float16, CUDAContext>(
template <> \
void SoftmaxGrad<T, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const T* dy, \
const T* y, \
T* dx, \
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _Swish(const int count, const T* x, T* y) {
ConstEigenVectorArrayMap<T> X(x, count);
EigenVectorArrayMap<T>(y, count) = X / (T(1) + (-X).exp());
}
template <>
void _Swish<float16>(const int count, const float16* x, float16* y) {
CPU_FP16_NOT_SUPPORTED;
}
template <typename T>
void _SwishGrad(const int count, const T* dy, const T* x, const T* y, T* dx) {
ConstEigenVectorArrayMap<T> X(x, count);
ConstEigenVectorArrayMap<T> Y(y, count);
EigenVectorArrayMap<T>(dx, count) = ConstEigenVectorArrayMap<T>(dy, count) *
(Y + (T(1) / (T(1) + (-X).exp())) * (T(1) - Y));
}
template <>
void _SwishGrad<float16>(
const int count,
const float16* dy,
const float16* x,
const float16* y,
float16* dx) {
CPU_FP16_NOT_SUPPORTED;
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Swish<T, CPUContext>( \
const int count, const T* x, T* y, CPUContext* ctx) { \
_Swish(count, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void SwishGrad<T, CPUContext>( \
const int count, \
const T* dy, \
const T* x, \
const T* y, \
T* dx, \
CPUContext* ctx) { \
_SwishGrad(count, dy, x, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
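The swish gradient above is expressed in terms of x and the cached output y = x * sigma(x): differentiating gives sigma(x) + x * sigma(x) * (1 - sigma(x)), which rearranges to y + sigma(x) * (1 - y), the quantity both the Eigen and the CUDA kernels compute:

\frac{\partial\,\mathrm{swish}(x)}{\partial x}
= \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr)
= \mathrm{swish}(x) + \sigma(x)\bigl(1 - \mathrm{swish}(x)\bigr).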
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void _Swish(const int nthreads, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = __ldg(x + i) / (T(1) + exp(-__ldg(x + i)));
#else
y[i] = x[i] / (T(1) + exp(-x[i]));
#endif
}
}
template <>
__global__ void _Swish<half>(const int nthreads, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = __float2half(
__half2float(__ldg(x + i)) / (1.f + exp(-__half2float(__ldg(x + i)))));
#else
y[i] = __float2half(__half2float(x[i]) / (1.f + exp(-__half2float(x[i]))));
#endif
}
}
template <typename T>
__global__ void
_SwishGrad(const int nthreads, const T* dy, const T* x, const T* y, T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
dx[i] =
dy[i] * (__ldg(y + i) + (T(1) - __ldg(y + i)) / (T(1) + exp(-x[i])));
#else
dx[i] = dy[i] * (y[i] + (T(1) - y[i]) / (T(1) + exp(-x[i])));
#endif
}
}
template <>
__global__ void _SwishGrad<half>(
const int nthreads,
const half* dy,
const half* x,
const half* y,
half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
dx[i] = __float2half(
__half2float(dy[i]) *
(__half2float(__ldg(y + i)) +
(1.f - __half2float(__ldg(y + i))) /
(1.f + exp(-__half2float(x[i])))));
#else
dx[i] = __float2half(
__half2float(dy[i]) *
(__half2float(y[i]) +
(1.f - __half2float(y[i])) / (1.f + exp(-__half2float(x[i])))));
#endif
}
} // SwishGrad
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void Swish<float16, CUDAContext>(
const int count,
const float16* x,
float16* y,
CUDAContext* ctx) {
_Swish<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count, reinterpret_cast<const half*>(x), reinterpret_cast<half*>(y));
}
template <>
void SwishGrad<float16, CUDAContext>(
const int count,
const float16* dy,
const float16* x,
const float16* y,
float16* dx,
CUDAContext* ctx) {
_SwishGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(y),
reinterpret_cast<half*>(dx));
} // SwishGrad
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Swish<T, CUDAContext>( \
const int count, const T* x, T* y, CUDAContext* ctx) { \
_Swish<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void SwishGrad<T, CUDAContext>( \
const int count, \
const T* dy, \
const T* x, \
const T* y, \
T* dx, \
CUDAContext* ctx) { \
_SwishGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, dy, x, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
......@@ -31,8 +31,8 @@ void _ChannelAffine(
template <typename T>
void _ChannelAffine(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* x,
const T* w,
const T* b,
......@@ -59,8 +59,8 @@ void _ChannelAffine(
template <>
void ChannelAffine<float16, CPUContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float16* x,
const float16* w,
const float16* b,
......@@ -73,8 +73,8 @@ void ChannelAffine<float16, CPUContext>(
template <> \
void ChannelAffine<T, CPUContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const T* x, \
const T* w, \
const T* b, \
......@@ -83,7 +83,7 @@ void ChannelAffine<float16, CPUContext>(
if (inner_dim == 1) { \
_ChannelAffine(outer_dim, axis_dim, x, w, b, y); \
} else { \
_ChannelAffine(outer_dim, axis_dim, inner_dim, x, w, b, y); \
_ChannelAffine(outer_dim, inner_dim, axis_dim, x, w, b, y); \
} \
}
......@@ -93,6 +93,7 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -12,8 +12,8 @@ namespace {
template <typename T>
__global__ void _ChannelAffine(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* x,
const T* w,
T* y) {
......@@ -29,8 +29,8 @@ __global__ void _ChannelAffine(
template <>
__global__ void _ChannelAffine<half>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const half* x,
const half* w,
half* y) {
......@@ -51,8 +51,8 @@ __global__ void _ChannelAffine<half>(
template <typename T>
__global__ void _ChannelAffine(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* x,
const T* w,
const T* b,
......@@ -70,8 +70,8 @@ __global__ void _ChannelAffine(
template <>
__global__ void _ChannelAffine<half>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const half* x,
const half* w,
const half* b,
......@@ -95,8 +95,8 @@ __global__ void _ChannelAffine<half>(
template <>
__global__ void _ChannelAffine<float>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float* x,
const float* w,
const float* b,
......@@ -114,8 +114,8 @@ __global__ void _ChannelAffine<float>(
template <>
__global__ void _ChannelAffine<double>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const double* x,
const double* w,
const double* b,
......@@ -137,14 +137,14 @@ __global__ void _ChannelAffine<double>(
template <>
void ChannelAffine<float16, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float16* x,
const float16* w,
const float16* b,
float16* y,
CUDAContext* ctx) {
const int nthreads = outer_dim * axis_dim * inner_dim;
const auto nthreads = outer_dim * axis_dim * inner_dim;
if (b != nullptr) {
_ChannelAffine<<<
CUDA_BLOCKS(nthreads),
......@@ -152,8 +152,8 @@ void ChannelAffine<float16, CUDAContext>(
0,
ctx->cuda_stream()>>>(
nthreads,
axis_dim,
inner_dim,
axis_dim,
reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(w),
reinterpret_cast<const half*>(b),
......@@ -165,8 +165,8 @@ void ChannelAffine<float16, CUDAContext>(
0,
ctx->cuda_stream()>>>(
nthreads,
axis_dim,
inner_dim,
axis_dim,
reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(w),
reinterpret_cast<half*>(y));
......@@ -177,26 +177,26 @@ void ChannelAffine<float16, CUDAContext>(
template <> \
void ChannelAffine<T, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const T* x, \
const T* w, \
const T* b, \
T* y, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * axis_dim * inner_dim; \
const auto nthreads = outer_dim * axis_dim * inner_dim; \
if (b != nullptr) { \
_ChannelAffine<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(nthreads, axis_dim, inner_dim, x, w, b, y); \
ctx->cuda_stream()>>>(nthreads, inner_dim, axis_dim, x, w, b, y); \
} else { \
_ChannelAffine<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(nthreads, axis_dim, inner_dim, x, w, y); \
ctx->cuda_stream()>>>(nthreads, inner_dim, axis_dim, x, w, y); \
} \
}
......@@ -206,6 +206,7 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -10,8 +10,8 @@ namespace {
template <typename T>
void _CumSum(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const bool exclusive,
const T* x,
T* y,
......@@ -33,8 +33,8 @@ void _CumSum(
template <>
void _CumSum<float16>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const bool exclusive,
const float16* x,
float16* y,
......@@ -45,8 +45,8 @@ void _CumSum<float16>(
template <typename T>
void _CumSumReverse(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const bool exclusive,
const T* x,
T* y,
......@@ -72,8 +72,8 @@ void _CumSumReverse(
template <>
void _CumSumReverse<float16>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const bool exclusive,
const float16* x,
float16* y,
......@@ -89,17 +89,17 @@ void _CumSumReverse<float16>(
template <> \
void CumSum<T, CPUContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const bool exclusive, \
const bool reverse, \
const T* x, \
T* y, \
CPUContext* ctx) { \
if (reverse) { \
_CumSumReverse(outer_dim, axis_dim, inner_dim, exclusive, x, y, ctx); \
_CumSumReverse(outer_dim, inner_dim, axis_dim, exclusive, x, y, ctx); \
} else { \
_CumSum(outer_dim, axis_dim, inner_dim, exclusive, x, y, ctx); \
_CumSum(outer_dim, inner_dim, axis_dim, exclusive, x, y, ctx); \
} \
}
......@@ -110,7 +110,6 @@ DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -98,8 +98,8 @@ __global__ void _CumSumReverse<half>(
template <>
void CumSum<float16, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const bool exclusive,
const bool reverse,
const float16* x,
......@@ -129,8 +129,8 @@ void CumSum<float16, CUDAContext>(
template <> \
void CumSum<T, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const bool exclusive, \
const bool reverse, \
const T* x, \
......@@ -155,7 +155,6 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -13,16 +13,16 @@ void _IndexSelect(
const int inner_dim,
const int axis_dim,
const int select_dim,
const int64_t* indices,
const int64_t* index,
const T* x,
T* y,
CPUContext* ctx) {
int index;
int pos;
for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < select_dim; ++j) {
index = indices[j];
index = index >= 0 ? index : index + axis_dim;
const T* offset_x = x + (i * axis_dim + index) * inner_dim;
pos = index[j];
pos = pos >= 0 ? pos : pos + axis_dim;
const T* offset_x = x + (i * axis_dim + pos) * inner_dim;
math::Copy(inner_dim, offset_x, y, ctx);
y += inner_dim;
}
......@@ -35,16 +35,16 @@ void _IndexSelectGrad(
const int inner_dim,
const int axis_dim,
const int select_dim,
const int64_t* indices,
const int64_t* index,
const T* dy,
T* dx,
CPUContext* ctx) {
int index;
int pos;
for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < select_dim; ++j) {
index = indices[j];
index = index >= 0 ? index : index + axis_dim;
T* offset_dx = dx + (i * axis_dim + index) * inner_dim;
pos = index[j];
pos = pos >= 0 ? pos : pos + axis_dim;
T* offset_dx = dx + (i * axis_dim + pos) * inner_dim;
math::Add(inner_dim, dy, offset_dx, offset_dx, ctx);
dy += inner_dim;
}
......@@ -62,11 +62,11 @@ void _IndexSelectGrad(
const int inner_dim, \
const int axis_dim, \
const int select_dim, \
const int64_t* indices, \
const int64_t* index, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_##name(outer_dim, inner_dim, axis_dim, select_dim, indices, x, y, ctx); \
_##name(outer_dim, inner_dim, axis_dim, select_dim, index, x, y, ctx); \
}
DEFINE_KERNEL_LAUNCHER(IndexSelect, bool);
......
......@@ -15,19 +15,19 @@ __global__ void _IndexSelect(
const int inner_dim,
const int axis_dim,
const int select_dim,
const int64_t* indices,
const int64_t* index,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int j = yi % inner_dim;
const int i = yi / inner_dim / select_dim;
#if __CUDA_ARCH__ >= 350
int index = __ldg(indices + ((yi / inner_dim) % select_dim));
int pos = __ldg(index + ((yi / inner_dim) % select_dim));
#else
int index = indices[(yi / inner_dim) % select_dim];
int pos = index[(yi / inner_dim) % select_dim];
#endif
index = index >= 0 ? index : index + axis_dim;
y[yi] = x[(i * axis_dim + index) * inner_dim + j];
pos = pos >= 0 ? pos : pos + axis_dim;
y[yi] = x[(i * axis_dim + pos) * inner_dim + j];
}
}
......@@ -37,22 +37,22 @@ __global__ void _IndexSelectGrad(
const int inner_dim,
const int axis_dim,
const int select_dim,
const int64_t* indices,
const int64_t* index,
const T* dy,
T* dx) {
CUDA_1D_KERNEL_LOOP(ti, nthreads) {
const int i = ti / inner_dim;
const int j = ti % inner_dim;
const int c = i * axis_dim * inner_dim + j;
const int x_offset = i * axis_dim * inner_dim + j;
const T* offset_dy = dy + i * select_dim * inner_dim + j;
for (int k = 0; k < select_dim; ++k) {
#if __CUDA_ARCH__ >= 350
int index = __ldg(indices + k);
int pos = __ldg(index + k);
#else
int index = indices[k];
int pos = index[k];
#endif
index = index >= 0 ? index : index + axis_dim;
dx[c + index * inner_dim] += (*offset_dy);
pos = pos >= 0 ? pos : pos + axis_dim;
dx[x_offset + pos * inner_dim] += (*offset_dy);
offset_dy += inner_dim;
}
}
......@@ -64,23 +64,30 @@ __global__ void _IndexSelectGrad<half>(
const int inner_dim,
const int axis_dim,
const int select_dim,
const int64_t* indices,
const int64_t* index,
const half* dy,
half* dx) {
CUDA_1D_KERNEL_LOOP(ti, nthreads) {
#if __CUDA_ARCH__ >= 530
const int i = ti / inner_dim;
const int j = ti % inner_dim;
const int c = i * axis_dim * inner_dim + j;
const int x_offset = i * axis_dim * inner_dim + j;
const half* offset_dy = dy + i * select_dim * inner_dim + j;
for (int k = 0; k < select_dim; ++k) {
int index = __ldg(indices + j);
index = index >= 0 ? index : index + axis_dim;
index = c + index * inner_dim;
dx[index] = __hadd(dx[index], *(offset_dy));
#if __CUDA_ARCH__ >= 350
int pos = __ldg(index + k);
#else
int pos = index[k];
#endif
pos = pos >= 0 ? pos : pos + axis_dim;
pos = x_offset + pos * inner_dim;
#if __CUDA_ARCH__ >= 530
dx[pos] = __hadd(dx[pos], *(offset_dy));
#else
dx[pos] =
__float2half(__half2float(dx[pos]) + __half2float(*(offset_dy)));
#endif
offset_dy += inner_dim;
}
#endif
}
}
......@@ -94,7 +101,7 @@ void IndexSelectGrad<float16, CUDAContext>(
const int inner_dim,
const int axis_dim,
const int select_dim,
const int64_t* indices,
const int64_t* index,
const float16* dy,
float16* dx,
CUDAContext* ctx) {
......@@ -108,7 +115,7 @@ void IndexSelectGrad<float16, CUDAContext>(
inner_dim,
axis_dim,
select_dim,
indices,
index,
reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx));
} // IndexSelectGrad
......@@ -120,7 +127,7 @@ void IndexSelectGrad<float16, CUDAContext>(
const int inner_dim, \
const int axis_dim, \
const int select_dim, \
const int64_t* indices, \
const int64_t* index, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
......@@ -130,7 +137,7 @@ void IndexSelectGrad<float16, CUDAContext>(
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, select_dim, indices, x, y); \
nthreads, inner_dim, axis_dim, select_dim, index, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
......@@ -140,7 +147,7 @@ void IndexSelectGrad<float16, CUDAContext>(
const int inner_dim, \
const int axis_dim, \
const int select_dim, \
const int64_t* indices, \
const int64_t* index, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
......@@ -150,7 +157,7 @@ void IndexSelectGrad<float16, CUDAContext>(
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, select_dim, indices, dy, dx); \
nthreads, inner_dim, axis_dim, select_dim, index, dy, dx); \
}
DEFINE_KERNEL_LAUNCHER(bool);
......
......@@ -10,8 +10,8 @@ namespace {
template <typename T>
void _BroadcastLossGrad(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* dy,
T* dx) {
std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim};
......@@ -52,8 +52,8 @@ void ReduceLossGrad<float16, CPUContext>(
template <>
void BroadcastLossGrad<float16, CPUContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float16* dy,
float16* dx,
CPUContext* ctx) {
......@@ -98,12 +98,12 @@ void BroadcastLossGrad<float16, CPUContext>(
template <> \
void BroadcastLossGrad<T, CPUContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const T* dy, \
T* dx, \
CPUContext* ctx) { \
_BroadcastLossGrad(outer_dim, axis_dim, inner_dim, dy, dx); \
_BroadcastLossGrad(outer_dim, inner_dim, axis_dim, dy, dx); \
}
DEFINE_KERNEL_LAUNCHER(float);
......
......@@ -146,12 +146,12 @@ void ReduceLossGrad<float16, CUDAContext>(
template <>
void BroadcastLossGrad<float16, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float16* dy,
float16* dx,
CUDAContext* ctx) {
auto nthreads = outer_dim * axis_dim * inner_dim;
const auto nthreads = outer_dim * axis_dim * inner_dim;
_BroadcastLossGrad<<<
CUDA_BLOCKS(nthreads),
CUDA_THREADS,
......@@ -214,12 +214,12 @@ void BroadcastLossGrad<float16, CUDAContext>(
template <> \
void BroadcastLossGrad<T, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
auto nthreads = outer_dim * axis_dim * inner_dim; \
const auto nthreads = outer_dim * axis_dim * inner_dim; \
_BroadcastLossGrad<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
......
......@@ -10,10 +10,10 @@ namespace {
template <typename LogitType, typename TargetType>
void _NLLLoss(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* log_prob,
const LogitType* logit,
const TargetType* target,
LogitType* loss,
LogitType* mask) {
......@@ -26,7 +26,7 @@ void _NLLLoss(
loss[i] = mask[i] = LogitType(0);
} else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
loss[i] = -log_prob[k], mask[i] = LogitType(1);
loss[i] = -logit[k], mask[i] = LogitType(1);
}
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
}
......@@ -35,12 +35,12 @@ void _NLLLoss(
template <typename LogitType, typename TargetType>
void _NLLLossGrad(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* log_prob,
const LogitType* logit,
const TargetType* target,
LogitType* dx,
LogitType* dlogit,
LogitType* mask) {
std::array<int, 2> idx = {0, 0};
std::array<int, 2> dims = {outer_dim, inner_dim};
......@@ -51,7 +51,7 @@ void _NLLLossGrad(
mask[i] = LogitType(0);
} else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
dx[k] = LogitType(-1), mask[i] = LogitType(1);
dlogit[k] = LogitType(-1), mask[i] = LogitType(1);
}
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
}
......@@ -65,20 +65,20 @@ void _NLLLossGrad(
template <> \
void name<LogitType, TargetType, CPUContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const int ignore_index, \
const LogitType* log_prob, \
const LogitType* logit, \
const TargetType* target, \
LogitType* loss, \
LogitType* mask, \
CPUContext* ctx) { \
_##name( \
outer_dim, \
axis_dim, \
inner_dim, \
axis_dim, \
ignore_index, \
log_prob, \
logit, \
target, \
loss, \
mask); \
......
......@@ -12,10 +12,10 @@ namespace {
template <typename LogitType, typename TargetType>
__global__ void _NLLLoss(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* log_prob,
const LogitType* logit,
const TargetType* target,
LogitType* loss,
LogitType* mask) {
......@@ -26,7 +26,7 @@ __global__ void _NLLLoss(
if (label == ignore_index) {
loss[yi] = mask[yi] = LogitType(0);
} else {
loss[yi] = -log_prob[(i * axis_dim + label) * inner_dim + j];
loss[yi] = -logit[(i * axis_dim + label) * inner_dim + j];
mask[yi] = LogitType(1);
}
}
......@@ -35,12 +35,12 @@ __global__ void _NLLLoss(
template <typename LogitType, typename TargetType>
__global__ void _NLLLossGrad(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* log_prob,
const LogitType* logit,
const TargetType* target,
LogitType* dx,
LogitType* dlogit,
LogitType* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / inner_dim;
......@@ -49,7 +49,7 @@ __global__ void _NLLLossGrad(
if (label == ignore_index) {
mask[yi] = LogitType(0);
} else {
dx[(i * axis_dim + label) * inner_dim + j] = LogitType(-1);
dlogit[(i * axis_dim + label) * inner_dim + j] = LogitType(-1);
mask[yi] = LogitType(1);
}
}
......@@ -63,21 +63,21 @@ __global__ void _NLLLossGrad(
template <> \
void name<LogitType, TargetType, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const int ignore_index, \
const LogitType* log_prob, \
const LogitType* logit, \
const TargetType* target, \
LogitType* loss, \
LogitType* mask, \
CUDAContext* ctx) { \
auto nthreads = outer_dim * inner_dim; \
const auto nthreads = outer_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
axis_dim, \
inner_dim, \
axis_dim, \
ignore_index, \
log_prob, \
logit, \
target, \
loss, \
mask); \
......
......@@ -10,8 +10,8 @@ namespace {
template <typename LogitType, typename TargetType>
void _SigmoidFocalLoss(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const LogitType pos_alpha,
const LogitType neg_alpha,
const LogitType gamma,
......@@ -55,8 +55,8 @@ void _SigmoidFocalLoss(
template <typename LogitType, typename TargetType>
void _SigmoidFocalLossGrad(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const LogitType pos_alpha,
const LogitType neg_alpha,
const LogitType gamma,
......@@ -108,8 +108,8 @@ void _SigmoidFocalLossGrad(
template <> \
void name<LogitType, TargetType, CPUContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const float pos_alpha, \
const float neg_alpha, \
const float gamma, \
......@@ -121,8 +121,8 @@ void _SigmoidFocalLossGrad(
CPUContext* ctx) { \
_##name( \
outer_dim, \
axis_dim, \
inner_dim, \
axis_dim, \
(LogitType)pos_alpha, \
(LogitType)neg_alpha, \
(LogitType)gamma, \
......
......@@ -12,8 +12,8 @@ namespace {
template <typename LogitType, typename TargetType>
__global__ void _SigmoidFocalLoss(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const LogitType pos_alpha,
const LogitType neg_alpha,
const LogitType gamma,
......@@ -53,8 +53,8 @@ __global__ void _SigmoidFocalLoss(
template <typename LogitType, typename TargetType>
__global__ void _SigmoidFocalLossGrad(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const LogitType pos_alpha,
const LogitType neg_alpha,
const LogitType gamma,
......@@ -102,8 +102,8 @@ __global__ void _SigmoidFocalLossGrad(
template <> \
void name<LogitType, TargetType, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const float pos_alpha, \
const float neg_alpha, \
const float gamma, \
......@@ -113,11 +113,11 @@ __global__ void _SigmoidFocalLossGrad(
LogitType* loss, \
LogitType* mask, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * axis_dim * inner_dim; \
const auto nthreads = outer_dim * axis_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
axis_dim, \
inner_dim, \
axis_dim, \
(LogitType)pos_alpha, \
(LogitType)neg_alpha, \
(LogitType)gamma, \
......
......@@ -10,8 +10,8 @@ namespace {
template <typename LogitType, typename TargetType>
void _SparseSoftmaxCrossEntropy(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* prob,
const TargetType* target,
......@@ -36,8 +36,8 @@ void _SparseSoftmaxCrossEntropy(
template <typename LogitType, typename TargetType>
void _SparseSoftmaxCrossEntropyGrad(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* prob,
const TargetType* target,
......@@ -72,8 +72,8 @@ void _SparseSoftmaxCrossEntropyGrad(
template <> \
void name<LogitType, TargetType, CPUContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const int ignore_index, \
const LogitType* prob, \
const TargetType* target, \
......@@ -82,8 +82,8 @@ void _SparseSoftmaxCrossEntropyGrad(
CPUContext* ctx) { \
_##name( \
outer_dim, \
axis_dim, \
inner_dim, \
axis_dim, \
ignore_index, \
prob, \
target, \
......
......@@ -12,8 +12,8 @@ namespace {
template <typename LogitType, typename TargetType>
__global__ void _SparseSoftmaxCrossEntropy(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* prob,
const TargetType* target,
......@@ -36,8 +36,8 @@ __global__ void _SparseSoftmaxCrossEntropy(
template <typename LogitType, typename TargetType>
__global__ void _SparseSoftmaxCrossEntropyGrad(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* prob,
const TargetType* target,
......@@ -69,19 +69,19 @@ __global__ void _SparseSoftmaxCrossEntropyGrad(
template <> \
void name<LogitType, TargetType, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const int ignore_index, \
const LogitType* prob, \
const TargetType* target, \
LogitType* loss, \
LogitType* mask, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * inner_dim; \
const auto nthreads = outer_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
axis_dim, \
inner_dim, \
axis_dim, \
ignore_index, \
prob, \
target, \
......
......@@ -10,8 +10,8 @@ namespace {
template <typename T>
void _L1Normalize(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const T* x,
......@@ -32,8 +32,8 @@ void _L1Normalize(
template <typename T>
void _L2Normalize(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const T* x,
......@@ -54,8 +54,8 @@ void _L2Normalize(
template <typename T>
void _L1NormalizeGrad(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const T* dy,
......@@ -81,8 +81,8 @@ void _L1NormalizeGrad(
template <typename T>
void _L2NormalizeGrad(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const T* dy,
......@@ -112,8 +112,8 @@ void _L2NormalizeGrad(
template <>
void L1Normalize<float16, CPUContext>(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const float16* x,
......@@ -125,8 +125,8 @@ void L1Normalize<float16, CPUContext>(
template <>
void L2Normalize<float16, CPUContext>(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const float16* x,
......@@ -138,8 +138,8 @@ void L2Normalize<float16, CPUContext>(
template <>
void L1NormalizeGrad<float16, CPUContext>(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const float16* dy,
......@@ -152,8 +152,8 @@ void L1NormalizeGrad<float16, CPUContext>(
template <>
void L2NormalizeGrad<float16, CPUContext>(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const float16* dy,
......@@ -167,14 +167,14 @@ void L2NormalizeGrad<float16, CPUContext>(
template <> \
void name<T, CPUContext>( \
const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \
const int reduce_dim, \
const float scale, \
const float eps, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_##name(outer_dim, reduce_dim, inner_dim, (T)scale, (T)eps, x, y); \
_##name(outer_dim, inner_dim, reduce_dim, (T)scale, (T)eps, x, y); \
}
DEFINE_KERNEL_LAUNCHER(L1Normalize, float);
......@@ -187,15 +187,15 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
template <> \
void name<T, CPUContext>( \
const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \
const int reduce_dim, \
const float scale, \
const float eps, \
const T* dy, \
const T* x, \
T* dx, \
CPUContext* ctx) { \
_##name(outer_dim, reduce_dim, inner_dim, (T)scale, (T)eps, dy, x, dx); \
_##name(outer_dim, inner_dim, reduce_dim, (T)scale, (T)eps, dy, x, dx); \
}
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float);
......
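For reference, these launcher changes only reorder the dimension arguments to (outer_dim, inner_dim, reduce_dim); the memory layout is still assumed to be (outer_dim, reduce_dim, inner_dim). A minimal NumPy sketch of that mapping — the kernel bodies and the exact eps/scale handling are elided above, so those details are assumptions (the 'MEAN' reduction passes scale = 1/reduce_dim):

```python
import numpy as np

def l1_normalize_ref(x, outer_dim, inner_dim, reduce_dim, scale=1.0, eps=1e-12):
    """Hypothetical reference: data stays laid out as (outer, reduce, inner)."""
    v = x.reshape(outer_dim, reduce_dim, inner_dim)
    # Norm over the middle (reduce) axis; eps placement is an assumption.
    norm = scale * np.abs(v).sum(axis=1, keepdims=True)
    return (v / np.maximum(norm, eps)).reshape(x.shape)
```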
......@@ -14,8 +14,8 @@ namespace {
template <typename T>
__global__ void _L1Normalize(
const int nblocks,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const T* x,
......@@ -42,8 +42,8 @@ __global__ void _L1Normalize(
__global__ void _L1Normalize(
const int nblocks,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const half* x,
......@@ -71,8 +71,8 @@ __global__ void _L1Normalize(
template <typename T>
__global__ void _L2Normalize(
const int nblocks,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const T* x,
......@@ -99,8 +99,8 @@ __global__ void _L2Normalize(
__global__ void _L2Normalize(
const int nblocks,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const half* x,
......@@ -128,8 +128,8 @@ __global__ void _L2Normalize(
template <typename T>
__global__ void _L1NormalizeGrad(
const int nblocks,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const T* dy,
......@@ -162,8 +162,8 @@ __global__ void _L1NormalizeGrad(
__global__ void _L1NormalizeGrad(
const int nblocks,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const half* dy,
......@@ -199,8 +199,8 @@ __global__ void _L1NormalizeGrad(
template <typename T>
__global__ void _L2NormalizeGrad(
const int nblocks,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const T* dy,
......@@ -233,8 +233,8 @@ __global__ void _L2NormalizeGrad(
__global__ void _L2NormalizeGrad(
const int nblocks,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const half* dy,
......@@ -275,8 +275,8 @@ __global__ void _L2NormalizeGrad(
template <> \
void name<T, CUDAContext>( \
const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \
const int reduce_dim, \
const float scale, \
const float eps, \
const float16* x, \
......@@ -285,8 +285,8 @@ __global__ void _L2NormalizeGrad(
const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, \
reduce_dim, \
inner_dim, \
reduce_dim, \
scale, \
eps, \
reinterpret_cast<const half*>(x), \
......@@ -301,8 +301,8 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, float16);
template <> \
void name<T, CUDAContext>( \
const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \
const int reduce_dim, \
const float scale, \
const float eps, \
const T* x, \
......@@ -310,7 +310,7 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, float16);
CUDAContext* ctx) { \
const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, reduce_dim, inner_dim, (T)scale, (T)eps, x, y); \
nblocks, inner_dim, reduce_dim, (T)scale, (T)eps, x, y); \
}
DEFINE_KERNEL_LAUNCHER(L1Normalize, float);
......@@ -323,8 +323,8 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
template <> \
void name<T, CUDAContext>( \
const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \
const int reduce_dim, \
const float scale, \
const float eps, \
const float16* dy, \
......@@ -334,8 +334,8 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, \
reduce_dim, \
inner_dim, \
reduce_dim, \
scale, \
eps, \
reinterpret_cast<const half*>(dy), \
......@@ -351,8 +351,8 @@ DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16);
template <> \
void name<T, CUDAContext>( \
const int outer_dim, \
const int reduce_dim, \
const int inner_dim, \
const int reduce_dim, \
const float scale, \
const float eps, \
const T* dy, \
......@@ -361,7 +361,7 @@ DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16);
CUDAContext* ctx) { \
const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, reduce_dim, inner_dim, (T)scale, (T)eps, dy, x, dx); \
nblocks, inner_dim, reduce_dim, (T)scale, (T)eps, dy, x, dx); \
}
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float);
......
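The gradient launchers follow the same argument order. As a sanity check, here is a sketch of the standard backward pass of plain L2 normalization over the reduce axis, ignoring the scale and eps details handled by the elided kernel bodies:

```python
import numpy as np

def l2_normalize_grad_ref(dy, x, outer_dim, inner_dim, reduce_dim):
    """Standard backward of y = x / ||x|| over the reduce axis (scale/eps omitted)."""
    dyv = dy.reshape(outer_dim, reduce_dim, inner_dim)
    xv = x.reshape(outer_dim, reduce_dim, inner_dim)
    norm = np.sqrt((xv * xv).sum(axis=1, keepdims=True))
    y = xv / norm
    # dx = (dy - y * sum(dy * y)) / ||x||
    dx = (dyv - y * (dyv * y).sum(axis=1, keepdims=True)) / norm
    return dx.reshape(x.shape)
```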
......@@ -22,8 +22,8 @@ void _BiasAdd(
template <typename T>
void _BiasAdd(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* x,
const T* b,
T* y) {
......@@ -44,8 +44,8 @@ void _BiasAdd(
template <>
void BiasAdd<float16, CPUContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float16* x,
const float16* b,
float16* y,
......@@ -57,8 +57,8 @@ void BiasAdd<float16, CPUContext>(
template <> \
void BiasAdd<T, CPUContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const T* x, \
const T* b, \
T* y, \
......@@ -66,7 +66,7 @@ void BiasAdd<float16, CPUContext>(
if (inner_dim == 1) { \
_BiasAdd(outer_dim, axis_dim, x, b, y); \
} else { \
_BiasAdd(outer_dim, axis_dim, inner_dim, x, b, y); \
_BiasAdd(outer_dim, inner_dim, axis_dim, x, b, y); \
} \
}
......
......@@ -38,8 +38,8 @@ __global__ void _BiasAdd<half>(
template <typename T>
__global__ void _BiasAdd(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* x,
const T* b,
T* y) {
......@@ -55,8 +55,8 @@ __global__ void _BiasAdd(
template <>
__global__ void _BiasAdd<half>(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const half* x,
const half* b,
half* y) {
......@@ -74,13 +74,13 @@ __global__ void _BiasAdd<half>(
template <>
void BiasAdd<float16, CUDAContext>(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float16* x,
const float16* b,
float16* y,
CUDAContext* ctx) {
const int nthreads = outer_dim * axis_dim * inner_dim;
const auto nthreads = outer_dim * axis_dim * inner_dim;
if (inner_dim == 1) {
_BiasAdd<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
......@@ -91,8 +91,8 @@ void BiasAdd<float16, CUDAContext>(
} else {
_BiasAdd<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
axis_dim,
inner_dim,
axis_dim,
reinterpret_cast<const half*>(x),
reinterpret_cast<const half*>(b),
reinterpret_cast<half*>(y));
......@@ -103,13 +103,13 @@ void BiasAdd<float16, CUDAContext>(
template <> \
void BiasAdd<T, CUDAContext>( \
const int outer_dim, \
const int axis_dim, \
const int inner_dim, \
const int axis_dim, \
const T* x, \
const T* b, \
T* y, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * axis_dim * inner_dim; \
const auto nthreads = outer_dim * axis_dim * inner_dim; \
if (inner_dim == 1) { \
_BiasAdd<<< \
CUDA_BLOCKS(nthreads), \
......@@ -121,7 +121,7 @@ void BiasAdd<float16, CUDAContext>(
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(nthreads, axis_dim, inner_dim, x, b, y); \
ctx->cuda_stream()>>>(nthreads, inner_dim, axis_dim, x, b, y); \
} \
}
......
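The BiasAdd launchers adopt the same (outer_dim, inner_dim, axis_dim) argument order, while the data is still assumed to be laid out as (outer_dim, axis_dim, inner_dim) — e.g. NCHW maps to (N, C, H*W) with inner_dim = H*W, and NHWC to (N*H*W, C, 1). A minimal sketch under that assumption:

```python
import numpy as np

def bias_add_ref(x, b, outer_dim, inner_dim, axis_dim):
    """Hypothetical reference; the bias is broadcast along the axis dimension."""
    v = x.reshape(outer_dim, axis_dim, inner_dim)
    return (v + b.reshape(1, axis_dim, 1)).reshape(x.shape)
```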
#include "dragon/operators/activation/hardsigmoid_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void HardSigmoidOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
kernel::HardSigmoid(
X.count(),
alpha_,
beta_,
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void HardSigmoidOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void HardSigmoidGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
kernel::HardSigmoidGrad(
Y.count(),
alpha_,
dY.template data<T, Context>(),
Y.template data<T, Context>(),
dX->ReshapeLike(Y)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void HardSigmoidGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(HardSigmoid);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(HardSigmoid);
#endif
DEPLOY_CPU_OPERATOR(HardSigmoidGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(HardSigmoidGradient);
#endif
OPERATOR_SCHEMA(HardSigmoid)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(HardSigmoidGradient)
/* Y, dY */
.NumInputs(2)
/* dX */
.NumOutputs(1)
/* dY => dX */
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(HardSigmoid, InplaceGradientMaker);
} // namespace dragon
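A minimal NumPy sketch of what the HardSigmoid / HardSigmoidGradient kernels are expected to compute, following the formula documented for `dragon.nn.hardsigmoid` later in this diff; note the gradient consumes Y rather than X, and the behaviour exactly at the clipping boundaries is an assumption:

```python
import numpy as np

def hardsigmoid_ref(x, alpha=0.2, beta=0.5):
    # y = max(0, min(1, alpha * x + beta))
    return np.clip(alpha * x + beta, 0.0, 1.0)

def hardsigmoid_grad_ref(dy, y, alpha=0.2):
    # dX = dY * alpha inside the linear region (0 < y < 1), 0 elsewhere.
    return dy * np.where((y > 0.0) & (y < 1.0), alpha, 0.0)
```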
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_HARDSIGMOID_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_HARDSIGMOID_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class HardSigmoidOp : public Operator<Context> {
public:
HardSigmoidOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)),
beta_(OP_SINGLE_ARG(float, "beta", 0.5f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
float alpha_, beta_;
};
template <class Context>
class HardSigmoidGradientOp : public Operator<Context> {
public:
HardSigmoidGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
float alpha_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_HARDSIGMOID_OP_H_
#include "dragon/operators/activation/hardswish_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void HardSwishOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
kernel::HardSwish(
X.count(),
alpha_,
beta_,
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void HardSwishOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void HardSwishGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &dY = Input(1), *dX = Output(0);
kernel::HardSwishGrad(
X.count(),
alpha_,
beta_,
dY.template data<T, Context>(),
X.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void HardSwishGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(HardSwish);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(HardSwish);
#endif
DEPLOY_CPU_OPERATOR(HardSwishGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(HardSwishGradient);
#endif
OPERATOR_SCHEMA(HardSwish)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(HardSwishGradient)
/* X, dY */
.NumInputs(2)
/* dX */
.NumOutputs(1);
REGISTER_GRADIENT(HardSwish, GenericGradientMaker);
} // namespace dragon
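Likewise for HardSwish / HardSwishGradient; the gradient consumes X, and the piecewise derivative below (2*alpha*x + beta in the interior) is the assumed behaviour of the elided kernels:

```python
import numpy as np

def hardswish_ref(x, alpha=0.2, beta=0.5):
    # y = x * max(0, min(1, alpha * x + beta))
    return x * np.clip(alpha * x + beta, 0.0, 1.0)

def hardswish_grad_ref(dy, x, alpha=0.2, beta=0.5):
    s = alpha * x + beta
    # s <= 0 -> 0, s >= 1 -> 1, otherwise d/dx (alpha*x^2 + beta*x) = 2*alpha*x + beta
    g = np.where(s <= 0.0, 0.0, np.where(s >= 1.0, 1.0, 2.0 * alpha * x + beta))
    return dy * g
```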
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_HARDSWISH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_HARDSWISH_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class HardSwishOp : public Operator<Context> {
public:
HardSwishOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)),
beta_(OP_SINGLE_ARG(float, "beta", 0.5f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
float alpha_, beta_;
};
template <class Context>
class HardSwishGradientOp : public Operator<Context> {
public:
HardSwishGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)),
beta_(OP_SINGLE_ARG(float, "beta", 0.5f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
float alpha_, beta_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_HARDSWISH_OP_H_
......@@ -12,8 +12,8 @@ void SoftmaxOp<Context>::DoRunWithType() {
CANONICALIZE_AXIS_WITH_TENSOR(X);
kernel::Softmax(
X.count(0, axis),
X.dim(axis),
X.count(axis + 1),
X.dim(axis),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
......@@ -31,8 +31,8 @@ void SoftmaxGradientOp<Context>::DoRunWithType() {
CANONICALIZE_AXIS_WITH_TENSOR(Y);
kernel::SoftmaxGrad(
Y.count(0, axis),
Y.dim(axis),
Y.count(axis + 1),
Y.dim(axis),
dY.template data<T, Context>(),
Y.template data<T, Context>(),
dX->ReshapeLike(Y)->template mutable_data<T, Context>(),
......
#include "dragon/operators/activation/swish_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void SwishOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
kernel::Swish(
X.count(),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void SwishOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void SwishGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &Y = Input(1);
auto &dY = Input(2), *dX = Output(0);
kernel::SwishGrad(
X.count(),
dY.template data<T, Context>(),
X.template data<T, Context>(),
Y.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void SwishGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Swish);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Swish);
#endif
DEPLOY_CPU_OPERATOR(SwishGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(SwishGradient);
#endif
OPERATOR_SCHEMA(Swish)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(SwishGradient)
/* X, Y, dY */
.NumInputs(3)
/* dX */
.NumOutputs(1);
namespace {
class GradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
vector<OperatorDef> MakeDef() override {
return SingleDef(
def.type() + "Gradient",
"",
vector<string>({I(0), O(0), GO(0)}),
vector<string>({GI(0)}));
}
};
} // namespace
REGISTER_GRADIENT(Swish, GradientMaker);
} // namespace dragon
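The Swish gradient maker wires X, Y and dY into SwishGradient, so the backward pass can reuse the saved output. A sketch of the assumed kernel semantics:

```python
import numpy as np

def swish_ref(x):
    # y = x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swish_grad_ref(dy, x, y):
    # With s = sigmoid(x): dy/dx = s + x * s * (1 - s) = y + s * (1 - y)
    s = 1.0 / (1.0 + np.exp(-x))
    return dy * (y + s * (1.0 - y))
```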
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_SWISH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SWISH_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class SwishOp : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(SwishOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
};
template <class Context>
class SwishGradientOp : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(SwishGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_SWISH_OP_H_
......@@ -39,8 +39,8 @@ void ChannelAffineOp<Context>::DoRunWithType() {
kernel::ChannelAffine(
X.count(0, axis),
X.count(axis, axis + num_axes),
X.count(axis + num_axes),
X.count(axis, axis + num_axes),
X.template data<T, Context>(),
W.template data<T, Context>(),
InputSize() <= 2 ? nullptr : Input(2).template data<T, Context>(),
......@@ -121,8 +121,8 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
Output(0)->ReshapeLike(Input(-1));
kernel::ChannelAffine(
X.count(0, axis),
X.count(axis, axis + num_axes),
X.count(axis + num_axes),
X.count(axis, axis + num_axes),
dY.template data<T, Context>(),
W.template data<T, Context>(),
(const T*)nullptr,
......
......@@ -12,8 +12,8 @@ void CumSumOp<Context>::DoRunWithType() {
kernel::CumSum(
X.count(0, axis),
X.dim(axis),
X.count(axis + 1),
X.dim(axis),
exclusive_,
reverse_,
X.template data<T, Context>(),
......@@ -34,8 +34,8 @@ void CumSumGradientOp<Context>::DoRunWithType() {
kernel::CumSum(
dY.count(0, axis),
dY.dim(axis),
dY.count(axis + 1),
dY.dim(axis),
exclusive_,
!reverse_,
dY.template data<T, Context>(),
......
......@@ -26,7 +26,7 @@ void MultinomialOp<Context>::DoRunWithType() {
CPUContext cpu_ctx;
auto* prob = Buffer("prob")->template mutable_data<T, CPUContext>();
kernel::Softmax(
X.count(0, axis), X.dim(axis), X.count(axis + 1), x, prob, &cpu_ctx);
X.count(0, axis), X.count(axis + 1), X.dim(axis), x, prob, &cpu_ctx);
x = prob;
}
......
......@@ -27,8 +27,8 @@ void NLLLossOp<Context>::DoRunWithType() {
kernel::NLLLoss(
outer_dim,
X.dim(axis),
inner_dim,
X.dim(axis),
ignore_index_,
X.template data<LogitType, Context>(),
Input(1).template data<TargetType, Context>(),
......@@ -109,8 +109,8 @@ void NLLLossGradientOp<Context>::DoRunWithType() {
kernel::NLLLossGrad(
outer_dim,
dX->dim(axis),
inner_dim,
dX->dim(axis),
ignore_index_,
X.template data<LogitType, Context>(),
Input(1).template data<TargetType, Context>(),
......@@ -120,7 +120,7 @@ void NLLLossGradientOp<Context>::DoRunWithType() {
if (reduction_ == "NONE") {
kernel::BroadcastLossGrad(
outer_dim, dX->dim(axis), inner_dim, dy, dx, ctx());
outer_dim, inner_dim, dX->dim(axis), dy, dx, ctx());
} else {
int64_t normalizer = 1;
if (reduction_ == "VALID") {
......
......@@ -26,8 +26,8 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() {
kernel::SigmoidFocalLoss(
outer_dim,
X.dim(axis),
inner_dim,
X.dim(axis),
pos_alpha_,
neg_alpha_,
gamma_,
......@@ -107,8 +107,8 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() {
kernel::SigmoidFocalLossGrad(
outer_dim,
dX->dim(axis),
inner_dim,
dX->dim(axis),
pos_alpha_,
neg_alpha_,
gamma_,
......
......@@ -24,8 +24,8 @@ void SoftmaxCrossEntropyOp<Context>::DoRunWithType() {
kernel::Softmax(
outer_dim,
X.dim(axis),
inner_dim,
X.dim(axis),
X.template data<T, Context>(),
prob,
ctx());
......@@ -90,7 +90,7 @@ void SoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
if (reduction_ == "NONE") {
kernel::BroadcastLossGrad(
outer_dim, dX->dim(axis), inner_dim, dy, dx, ctx());
outer_dim, inner_dim, dX->dim(axis), dy, dx, ctx());
} else {
int64_t normalizer = 1;
if (reduction_ == "BATCH_MEAN") {
......
......@@ -29,16 +29,16 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
kernel::Softmax(
outer_dim,
X.dim(axis),
inner_dim,
X.dim(axis),
X.template data<LogitType, Context>(),
prob,
ctx());
kernel::SparseSoftmaxCrossEntropy(
outer_dim,
X.dim(axis),
inner_dim,
X.dim(axis),
ignore_index_,
prob,
Input(1).template data<TargetType, Context>(),
......@@ -120,8 +120,8 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
kernel::SparseSoftmaxCrossEntropyGrad(
outer_dim,
dX->dim(axis),
inner_dim,
dX->dim(axis),
ignore_index_,
prob,
Input(1).template data<TargetType, Context>(),
......@@ -131,7 +131,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
if (reduction_ == "NONE") {
kernel::BroadcastLossGrad(
outer_dim, dX->dim(axis), inner_dim, dy, dx, ctx());
outer_dim, inner_dim, dX->dim(axis), dy, dx, ctx());
} else {
int64_t normalizer = 1;
if (reduction_ == "VALID") {
......
......@@ -59,8 +59,8 @@ void FullyConnectedOp<Context>::DoRunWithType() {
TENSOR_FILL(Input(2), vec64_t({N}));
kernel::BiasAdd(
M,
N,
1,
N,
Y->template data<T, Context>(),
Input(2).template data<T, Context>(),
Y->template mutable_data<T, Context>(),
......
......@@ -56,9 +56,9 @@ void BatchNormOp<Context>::TrainingImpl() {
// Compute affine transformation
if (data_format() == "NCHW") {
kernel::ChannelAffine(N_, C_, S_, x, scale, bias, y, ctx());
kernel::ChannelAffine(N_, S_, C_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") {
kernel::ChannelAffine(N_ * S_, C_, 1, x, scale, bias, y, ctx());
kernel::ChannelAffine(N_ * S_, 1, C_, x, scale, bias, y, ctx());
}
}
......@@ -91,9 +91,9 @@ void BatchNormOp<Context>::InferenceImpl() {
// Compute affine transformation
if (data_format() == "NCHW") {
kernel::ChannelAffine(N_, C_, S_, x, scale, bias, y, ctx());
kernel::ChannelAffine(N_, S_, C_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") {
kernel::ChannelAffine(N_ * S_, C_, 1, x, scale, bias, y, ctx());
kernel::ChannelAffine(N_ * S_, 1, C_, x, scale, bias, y, ctx());
}
}
......
......@@ -89,9 +89,9 @@ void SyncBatchNormOp<Context>::TrainingImpl() {
// Compute affine transformation
if (data_format() == "NCHW") {
kernel::ChannelAffine(N_, C_, S_, x, scale, bias, y, ctx());
kernel::ChannelAffine(N_, S_, C_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") {
kernel::ChannelAffine(N_ * S_, C_, 1, x, scale, bias, y, ctx());
kernel::ChannelAffine(N_ * S_, 1, C_, x, scale, bias, y, ctx());
}
}
......
......@@ -28,8 +28,8 @@ void LpNormalizeOp<Context>::DoRunWithType() {
if (p_ == 1) {
kernel::L1Normalize(
X.count(0, axis),
reduce_dim,
X.count(axis + num_axes),
reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
epsilon_,
X.template data<T, Context>(),
......@@ -38,8 +38,8 @@ void LpNormalizeOp<Context>::DoRunWithType() {
} else if (p_ == 2) {
kernel::L2Normalize(
X.count(0, axis),
reduce_dim,
X.count(axis + num_axes),
reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
epsilon_,
X.template data<T, Context>(),
......@@ -65,8 +65,8 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
if (p_ == 1) {
kernel::L1NormalizeGrad(
X.count(0, axis),
reduce_dim,
X.count(axis + num_axes),
reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
epsilon_,
dY.template data<T, Context>(),
......@@ -76,8 +76,8 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
} else if (p_ == 2) {
kernel::L2NormalizeGrad(
X.count(0, axis),
reduce_dim,
X.count(axis + num_axes),
reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
epsilon_,
dY.template data<T, Context>(),
......
......@@ -23,8 +23,8 @@ void BiasAddOp<Context>::DoRunWithType() {
TENSOR_FILL(B, vec64_t({C}));
kernel::BiasAdd(
N,
C,
S,
C,
X.template data<T, Context>(),
B.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
......
......@@ -118,10 +118,10 @@ template <typename T>
void ConvOpBase<Context>::Pb(const T* bias, T* y) {
if (data_format() == "NCHW") {
kernel::BiasAdd(
Input(0).dim(0), out_channels_, out_dim_, y, bias, y, ctx());
Input(0).dim(0), out_dim_, out_channels_, y, bias, y, ctx());
} else if (data_format() == "NHWC") {
kernel::BiasAdd(
Input(0).dim(0) * out_dim_, out_channels_, 1, y, bias, y, ctx());
Input(0).dim(0) * out_dim_, 1, out_channels_, y, bias, y, ctx());
}
}
......
......@@ -23,6 +23,8 @@ from dragon.core.ops.activation_ops import dropout
from dragon.core.ops.activation_ops import drop_block2d
from dragon.core.ops.activation_ops import drop_path
from dragon.core.ops.activation_ops import elu
from dragon.core.ops.activation_ops import hardsigmoid
from dragon.core.ops.activation_ops import hardswish
from dragon.core.ops.activation_ops import leaky_relu
from dragon.core.ops.activation_ops import log_softmax
from dragon.core.ops.activation_ops import prelu
......@@ -30,6 +32,7 @@ from dragon.core.ops.activation_ops import relu
from dragon.core.ops.activation_ops import relu6
from dragon.core.ops.activation_ops import selu
from dragon.core.ops.activation_ops import softmax
from dragon.core.ops.activation_ops import swish
from dragon.core.ops.math_ops import fully_connected
from dragon.core.ops.normalization_ops import batch_norm
from dragon.core.ops.normalization_ops import group_norm
......
......@@ -20,7 +20,7 @@ from dragon.core.util import tls
def device(device_type, device_index=0):
"""Context-manager to nest the the device spec.
"""Context-manager to nest the device spec.
Examples:
......
......@@ -223,6 +223,96 @@ def elu(inputs, alpha=1., **kwargs):
@OpSchema.num_inputs(1)
def hardsigmoid(inputs, alpha=0.2, beta=0.5, **kwargs):
r"""Apply the hard sigmoid function.
The **HardSigmoid** function is defined as:
.. math:: \text{HardSigmoid}(x) = \max(0, \min(1, \alpha * x + \beta))
Examples:
```python
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(dragon.nn.hardsigmoid(x, inplace=False))
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
alpha : float, optional, default=0.2
The value to :math:`\alpha`.
beta : float, optional, default=0.5
The value to :math:`\beta`.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
args['alpha'] = float(alpha)
args['beta'] = float(beta)
inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.HardSigmoid
if context.executing_eagerly():
return op_lib \
.instantiate(
alpha=args['alpha'],
beta=args['beta'],
).apply([inputs], inplace=inplace)
else:
return op_lib.blend(**args)
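For a quick check, the docstring example evaluates as follows with the defaults (a usage sketch, assuming `import dragon` exposes the modules as in the examples):

```python
import dragon

x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
# max(0, min(1, 0.2 * x + 0.5)) -> approximately [0.0, 0.3, 0.5, 0.7, 1.0]
print(dragon.nn.hardsigmoid(x))
```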
@OpSchema.num_inputs(1)
def hardswish(inputs, alpha=0.2, beta=0.5, **kwargs):
r"""Apply the hard swish function.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
The **HardSwish** function is defined as:
.. math:: \text{HardSwish}(x) = x \cdot \max(0, \min(1, \alpha * x + \beta))
Examples:
```python
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(dragon.nn.hardswish(x))
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
alpha : float, optional, default=0.2
The value to :math:`\alpha`.
beta : float, optional, default=0.5
The value to :math:`\beta`.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
args['alpha'] = float(alpha)
args['beta'] = float(beta)
op_lib = activation_ops_lib.HardSwish
if context.executing_eagerly():
return op_lib \
.instantiate(
alpha=args['alpha'],
beta=args['beta'],
).apply([inputs])
else:
return op_lib.blend(**args)
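Expected values for the docstring example, plus the MobileNetV3 variant: the defaults follow the HardSigmoid convention (alpha=0.2, beta=0.5), while the paper's x * relu6(x + 3) / 6 corresponds to alpha = 1/6 (a usage sketch, assuming `import dragon`):

```python
import dragon

x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
# x * max(0, min(1, 0.2 * x + 0.5)) -> approximately [0.0, -0.3, 0.0, 0.7, 2.5]
print(dragon.nn.hardswish(x))
# MobileNetV3-style hard swish: x * relu6(x + 3) / 6
print(dragon.nn.hardswish(x, alpha=1.0 / 6.0, beta=0.5))
```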
@OpSchema.num_inputs(1)
def leaky_relu(inputs, alpha=0.2, **kwargs):
r"""Apply the leaky rectified linear unit.
......@@ -269,7 +359,7 @@ def leaky_relu(inputs, alpha=0.2, **kwargs):
@OpSchema.num_inputs(1)
def log_softmax(inputs, axis=-1, **kwargs):
r"""Apply the composite of logarithm and softmax.
r"""Compute the composite of logarithm and softmax.
The **LogSoftmax** function is defined as:
......@@ -492,7 +582,7 @@ def selu(inputs, alpha=1.67326, gamma=1.0507, **kwargs):
@OpSchema.num_inputs(1)
def sigmoid(inputs, **kwargs):
r"""Apply the sigmoid function.
r"""Compute the sigmoid result of input.
The **Sigmoid** function is defined as:
......@@ -529,7 +619,7 @@ def sigmoid(inputs, **kwargs):
@OpSchema.num_inputs(1)
def softmax(inputs, axis=-1, **kwargs):
r"""Apply the softmax function.
r"""Compute the softmax result.
The **Softmax** function is defined as:
......@@ -602,3 +692,40 @@ def tanh(inputs, **kwargs):
.apply([inputs], inplace=inplace)
else:
return op_lib.blend('Tanh', **args)
@OpSchema.num_inputs(1)
def swish(inputs, **kwargs):
r"""Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
The **Swish** function is defined as:
.. math:: \text{Swish}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Examples:
```python
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(dragon.nn.swish(x))
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
op_lib = activation_ops_lib.Activation
if context.executing_eagerly():
return op_lib \
.instantiate(op_type='Swish') \
.apply([inputs])
else:
return op_lib.blend('Swish', **args)
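And the docstring example for swish, with approximate values of x * sigmoid(x) (a sketch, assuming `import dragon`):

```python
import dragon

x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
# x * sigmoid(x) -> approximately [-0.19, -0.27, 0.0, 0.73, 2.31]
print(dragon.nn.swish(x))
```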
......@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Activation(Operator):
"""Base activation operator."""
def __init__(self, key, dev, **kwargs):
super(Activation, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
......@@ -31,6 +33,8 @@ class Activation(Operator):
class Dropout(Activation):
"""Dropout operator."""
def __init__(self, key, dev, **kwargs):
super(Dropout, self).__init__(key, dev, **kwargs)
self.prob = kwargs.get('prob', 0.5)
......@@ -47,6 +51,8 @@ class Dropout(Activation):
class DropBlock2d(Activation):
"""DropBlock2d operator."""
def __init__(self, key, dev, **kwargs):
super(DropBlock2d, self).__init__(key, dev, **kwargs)
self.block_size = kwargs.get('block_size', 7)
......@@ -69,6 +75,8 @@ class DropBlock2d(Activation):
class DropPath(Activation):
"""DropPath operator."""
def __init__(self, key, dev, **kwargs):
super(DropPath, self).__init__(key, dev, **kwargs)
self.prob = kwargs.get('prob', 0.2)
......@@ -85,6 +93,8 @@ class DropPath(Activation):
class Elu(Activation):
"""Elu operator."""
def __init__(self, key, dev, **kwargs):
super(Elu, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.)
......@@ -96,7 +106,45 @@ class Elu(Activation):
}
class HardSigmoid(Activation):
"""HardSigmoid operator."""
def __init__(self, key, dev, **kwargs):
super(HardSigmoid, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 0.2)
self.beta = kwargs.get('beta', 0.5)
def attributes(self):
return {
'op_type': 'HardSigmoid',
'arguments': {
'alpha': float(self.alpha),
'beta': float(self.beta),
},
}
class HardSwish(Activation):
"""HardSwish operator."""
def __init__(self, key, dev, **kwargs):
super(HardSwish, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 0.2)
self.beta = kwargs.get('beta', 0.5)
def attributes(self):
return {
'op_type': 'HardSwish',
'arguments': {
'alpha': float(self.alpha),
'beta': float(self.beta),
},
}
class PRelu(Operator):
"""PRelu operator."""
def __init__(self, key, dev, **kwargs):
super(PRelu, self).__init__(key, dev, **kwargs)
self.data_format = kwargs.get('data_format', 'NCHW')
......@@ -112,6 +160,8 @@ class PRelu(Operator):
class Relu(Activation):
"""Relu operator."""
def __init__(self, key, dev, **kwargs):
super(Relu, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 0.)
......@@ -124,6 +174,8 @@ class Relu(Activation):
class Relu6(Activation):
"""Relu6 operator."""
def __init__(self, key, dev, **kwargs):
super(Relu6, self).__init__(key, dev, **kwargs)
......@@ -135,6 +187,8 @@ class Relu6(Activation):
class Selu(Activation):
"""Selu operator."""
def __init__(self, key, dev, **kwargs):
super(Selu, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.67326)
......@@ -151,6 +205,8 @@ class Selu(Activation):
class Softmax(Activation):
"""Softmax operator."""
def __init__(self, key, dev, **kwargs):
super(Softmax, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
......
......@@ -19,6 +19,8 @@ from dragon.core.framework.ops import Operator
class ArgReduce(Operator):
"""ArgReduce operator."""
def __init__(self, key, dev, **kwargs):
super(ArgReduce, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', 'ArgMax')
......@@ -39,6 +41,8 @@ class ArgReduce(Operator):
class Cast(Operator):
"""Cast operator."""
def __init__(self, key, dev, **kwargs):
super(Cast, self).__init__(key, dev, **kwargs)
self.dtype = kwargs.get('dtype', 'float32')
......@@ -58,6 +62,8 @@ class Cast(Operator):
class ChannelAffine(Operator):
"""ChannelAffine operator."""
def __init__(self, key, dev, **kwargs):
super(ChannelAffine, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
......@@ -78,6 +84,8 @@ class ChannelAffine(Operator):
class ChannelNormalize(Operator):
"""ChannelNormalize operator."""
def __init__(self, key, dev, **kwargs):
super(ChannelNormalize, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
......@@ -115,6 +123,8 @@ class ChannelNormalize(Operator):
class ChannelShuffle(Operator):
"""ChannelShuffle operator."""
def __init__(self, key, dev, **kwargs):
super(ChannelShuffle, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -134,6 +144,8 @@ class ChannelShuffle(Operator):
class Concat(Operator):
"""Concat operator."""
def __init__(self, key, dev, **kwargs):
super(Concat, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -149,6 +161,8 @@ class Concat(Operator):
class Cumulative(Operator):
"""Cumulative operator."""
def __init__(self, key, dev, **kwargs):
super(Cumulative, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -171,6 +185,8 @@ class Cumulative(Operator):
class Expand(Operator):
"""Expand operator."""
def __init__(self, key, dev, **kwargs):
super(Expand, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -200,6 +216,8 @@ class Expand(Operator):
class ExpandDims(Operator):
"""ExpandDims operator."""
def __init__(self, key, dev, **kwargs):
super(ExpandDims, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', [0])
......@@ -218,6 +236,8 @@ class ExpandDims(Operator):
class Flatten(Operator):
"""Flatten operator."""
def __init__(self, key, dev, **kwargs):
super(Flatten, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -240,6 +260,8 @@ class Flatten(Operator):
class IndexSelect(Operator):
"""IndexSelect operator."""
def __init__(self, key, dev, **kwargs):
super(IndexSelect, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -259,6 +281,8 @@ class IndexSelect(Operator):
class LinSpace(Operator):
"""LinSpace operator."""
def __init__(self, key, dev, **kwargs):
super(LinSpace, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -309,6 +333,8 @@ class LinSpace(Operator):
class MaskedSelect(Operator):
"""MaskedSelect operator."""
def __init__(self, key, dev, **kwargs):
super(MaskedSelect, self).__init__(key, dev, **kwargs)
......@@ -320,6 +346,8 @@ class MaskedSelect(Operator):
class Moments(Operator):
"""Moments operator."""
def __init__(self, key, dev, **kwargs):
super(Moments, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', None)
......@@ -339,6 +367,8 @@ class Moments(Operator):
class Multinomial(Operator):
"""Multinomial operator."""
def __init__(self, key, dev, **kwargs):
super(Multinomial, self).__init__(key, dev, **kwargs)
self.epsilon = kwargs.get('epsilon', 0.)
......@@ -360,6 +390,8 @@ class Multinomial(Operator):
class NonZero(Operator):
"""NonZero operator."""
def __init__(self, key, dev, **kwargs):
super(NonZero, self).__init__(key, dev, **kwargs)
......@@ -371,6 +403,8 @@ class NonZero(Operator):
class OneHot(Operator):
"""OneHot operator."""
def __init__(self, key, dev, **kwargs):
super(OneHot, self).__init__(key, dev, **kwargs)
self.depth = kwargs.get('depth', 1)
......@@ -392,6 +426,8 @@ class OneHot(Operator):
class Pad(Operator):
"""Pad operator."""
def __init__(self, key, dev, **kwargs):
super(Pad, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -425,6 +461,8 @@ class Pad(Operator):
class Permutation(Operator):
"""Permutation operator."""
def __init__(self, key, dev, **kwargs):
super(Permutation, self).__init__(key, dev, **kwargs)
self.dtype = kwargs.get('dtype', 'int64')
......@@ -453,6 +491,8 @@ class Permutation(Operator):
class Range(Operator):
"""Range operator."""
def __init__(self, key, dev, **kwargs):
super(Range, self).__init__(key, dev, **kwargs)
self.num_args = kwargs.get('num_args', 3)
......@@ -487,6 +527,8 @@ class Range(Operator):
class Reduce(Operator):
"""Reduce operator."""
def __init__(self, key, dev, **kwargs):
super(Reduce, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', None)
......@@ -507,6 +549,8 @@ class Reduce(Operator):
class Repeat(Operator):
"""Repeat operator."""
def __init__(self, key, dev, **kwargs):
super(Repeat, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 2147483647)
......@@ -526,6 +570,8 @@ class Repeat(Operator):
class Reshape(Operator):
"""Reshape operator."""
def __init__(self, key, dev, **kwargs):
super(Reshape, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -555,6 +601,8 @@ class Reshape(Operator):
class Slice(Operator):
"""Slice operator."""
def __init__(self, key, dev, **kwargs):
super(Slice, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -590,6 +638,8 @@ class Slice(Operator):
class Shape(Operator):
"""Shape operator."""
def __init__(self, key, dev, **kwargs):
super(Shape, self).__init__(key, dev, **kwargs)
self._device = device_spec.DeviceSpec()
......@@ -602,6 +652,8 @@ class Shape(Operator):
class Sort(Operator):
"""Sort operator."""
def __init__(self, key, dev, **kwargs):
super(Sort, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
......@@ -621,6 +673,8 @@ class Sort(Operator):
class Split(Operator):
"""Split operator."""
def __init__(self, key, dev, **kwargs):
super(Split, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -643,6 +697,8 @@ class Split(Operator):
class Squeeze(Operator):
"""Squeeze operator."""
def __init__(self, key, dev, **kwargs):
super(Squeeze, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', None)
......@@ -674,6 +730,8 @@ class Stack(Operator):
class Tile(Operator):
"""Tile operator."""
def __init__(self, key, dev, **kwargs):
super(Tile, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -703,6 +761,8 @@ class Tile(Operator):
class Transpose(Operator):
"""Transpose operator."""
def __init__(self, key, dev, **kwargs):
super(Transpose, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -732,6 +792,8 @@ class Transpose(Operator):
class TopK(Operator):
"""TopK operator."""
def __init__(self, key, dev, **kwargs):
super(TopK, self).__init__(key, dev, **kwargs)
self.k = kwargs.get('k', 1)
......@@ -755,6 +817,8 @@ class TopK(Operator):
class Unique(Operator):
"""Unique operator."""
def __init__(self, key, dev, **kwargs):
super(Unique, self).__init__(key, dev, **kwargs)
self.return_inverse = kwargs.get('return_inverse', False)
......@@ -776,6 +840,8 @@ class Unique(Operator):
class Where(Operator):
"""Where operator."""
def __init__(self, key, dev, **kwargs):
super(Where, self).__init__(key, dev, **kwargs)
......
......@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Assign(Operator):
"""Assign operator."""
def __init__(self, key, dev, **kwargs):
super(Assign, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -54,6 +56,8 @@ class Assign(Operator):
class Copy(Operator):
"""Copy operator."""
def __init__(self, key, dev, **kwargs):
super(Copy, self).__init__(key, dev, **kwargs)
......@@ -66,6 +70,8 @@ class Copy(Operator):
class MaskedAssign(Operator):
"""MaskedAssign operator."""
def __init__(self, key, dev, **kwargs):
super(MaskedAssign, self).__init__(key, dev, **kwargs)
......
......@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Collective(Operator):
"""Collective operator."""
def __init__(self, key, dev, **kwargs):
super(Collective, self).__init__(key, dev, **kwargs)
self.root = kwargs.get('root', 0)
......
......@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Initializer(Operator):
"""Initializer operator."""
def __init__(self, key, dev, **kwargs):
super(Initializer, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -43,6 +45,8 @@ class Initializer(Operator):
class Eye(Initializer):
"""Eye operator."""
def __init__(self, key, dev, **kwargs):
super(Eye, self).__init__(key, dev, **kwargs)
self.k = kwargs.get('k', 0)
......@@ -79,6 +83,8 @@ class Fill(Initializer):
class GlorotNormal(Initializer):
"""GlorotNormal operator."""
def __init__(self, key, dev, **kwargs):
super(GlorotNormal, self).__init__(key, dev, **kwargs)
self.scale = kwargs.get('scale', 2.)
......@@ -99,6 +105,8 @@ class GlorotNormal(Initializer):
class GlorotUniform(Initializer):
"""GlorotUniform operator."""
def __init__(self, key, dev, **kwargs):
super(GlorotUniform, self).__init__(key, dev, **kwargs)
self.scale = kwargs.get('scale', 3.)
......@@ -119,6 +127,8 @@ class GlorotUniform(Initializer):
class RandomNormal(Initializer):
"""RandomNormal operator."""
def __init__(self, key, dev, **kwargs):
super(RandomNormal, self).__init__(key, dev, **kwargs)
self.mean = kwargs.get('mean', 0.)
......@@ -139,6 +149,8 @@ class RandomNormal(Initializer):
class RandomUniform(Initializer):
"""RandomUniform operator."""
def __init__(self, key, dev, **kwargs):
super(RandomUniform, self).__init__(key, dev, **kwargs)
self.low = kwargs.get('low', 0.)
......@@ -159,6 +171,8 @@ class RandomUniform(Initializer):
class TruncatedNormal(Initializer):
"""TruncatedNormal operator."""
def __init__(self, key, dev, **kwargs):
super(TruncatedNormal, self).__init__(key, dev, **kwargs)
self.mean = kwargs.get('mean', 0.)
......
......@@ -17,9 +17,11 @@ from __future__ import print_function
from dragon.core.framework.ops import Operator
class _Loss(Operator):
class Loss(Operator):
"""Loss operator."""
def __init__(self, key, dev, **kwargs):
super(_Loss, self).__init__(key, dev, **kwargs)
super(Loss, self).__init__(key, dev, **kwargs)
self.reduction = kwargs.get('reduction', 'MEAN')
def attributes(self):
......@@ -34,17 +36,23 @@ class _Loss(Operator):
return self.dispatch(inputs, [self.alloc()])
class L1Loss(_Loss):
class L1Loss(Loss):
"""L1Loss operator."""
def __init__(self, key, dev, **kwargs):
super(L1Loss, self).__init__(key, dev, **kwargs)
class L2Loss(_Loss):
class L2Loss(Loss):
"""L2Loss operator."""
def __init__(self, key, dev, **kwargs):
super(L2Loss, self).__init__(key, dev, **kwargs)
class NLLLoss(_Loss):
class NLLLoss(Loss):
"""NLLLoss operator."""
def __init__(self, key, dev, **kwargs):
super(NLLLoss, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
......@@ -61,12 +69,16 @@ class NLLLoss(_Loss):
}
class SigmoidCrossEntropy(_Loss):
class SigmoidCrossEntropy(Loss):
"""SigmoidCrossEntropy operator."""
def __init__(self, key, dev, **kwargs):
super(SigmoidCrossEntropy, self).__init__(key, dev, **kwargs)
class SmoothL1Loss(_Loss):
class SmoothL1Loss(Loss):
"""SmoothL1Loss operator."""
def __init__(self, key, dev, **kwargs):
super(SmoothL1Loss, self).__init__(key, dev, **kwargs)
self.beta = kwargs.get('beta', 1.)
......@@ -81,7 +93,9 @@ class SmoothL1Loss(_Loss):
}
class SoftmaxCrossEntropy(_Loss):
class SoftmaxCrossEntropy(Loss):
"""SoftmaxCrossEntropy operator."""
def __init__(self, key, dev, **kwargs):
super(SoftmaxCrossEntropy, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
......@@ -96,7 +110,9 @@ class SoftmaxCrossEntropy(_Loss):
}
class SparseSoftmaxCrossEntropy(_Loss):
class SparseSoftmaxCrossEntropy(Loss):
"""SparseSoftmaxCrossEntropy operator."""
def __init__(self, key, dev, **kwargs):
super(SparseSoftmaxCrossEntropy, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
......@@ -113,7 +129,9 @@ class SparseSoftmaxCrossEntropy(_Loss):
}
class SigmoidFocalLoss(_Loss):
class SigmoidFocalLoss(Loss):
"""SigmoidFocalLoss operator."""
def __init__(self, key, dev, **kwargs):
super(SigmoidFocalLoss, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
......
......@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Axpby(Operator):
"""Axpby operator."""
def __init__(self, key, dev, **kwargs):
super(Axpby, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.)
......@@ -40,6 +42,8 @@ class Axpby(Operator):
class BinaryOp(Operator):
"""Binary operator."""
def __init__(self, key, dev, **kwargs):
super(BinaryOp, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
......@@ -52,6 +56,8 @@ class BinaryOp(Operator):
class Clip(Operator):
"""Clip operator."""
def __init__(self, key, dev, **kwargs):
super(Clip, self).__init__(key, dev, **kwargs)
self.low = kwargs.get('low', None)
......@@ -75,6 +81,8 @@ class Clip(Operator):
class FullyConnected(Operator):
"""FullyConnected operator."""
def __init__(self, key, dev, **kwargs):
super(FullyConnected, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
......@@ -94,6 +102,8 @@ class FullyConnected(Operator):
class MatMul(Operator):
"""MatMul operator."""
def __init__(self, key, dev, **kwargs):
super(MatMul, self).__init__(key, dev, **kwargs)
self.transpose_a = kwargs.get('transpose_a', False)
......@@ -113,6 +123,8 @@ class MatMul(Operator):
class UnaryOp(Operator):
"""Unary operator."""
def __init__(self, key, dev, **kwargs):
super(UnaryOp, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
......
......@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class Metric(Operator):
"""Metric operator."""
def __init__(self, key, dev, **kwargs):
super(Metric, self).__init__(key, dev, **kwargs)
self.reduction = kwargs.get('reduction', 'MEAN')
......@@ -27,6 +29,8 @@ class Metric(Operator):
class Accuracy(Metric):
"""Accuracy operator."""
def __init__(self, key, dev, **kwargs):
super(Accuracy, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
......
......@@ -36,19 +36,11 @@ def batch_norm(
The normalization is defined as:
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The running average of statistics are calculated as:
.. math::
x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}}
Note that the number of inputs should be **5**, i.e.,
this operators is implemented into the fused version.
However, you can still fix the ``gamma`` and ``beta``,
by disabling the their gradients directly.
.. math:: x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}}
Parameters
----------
......@@ -91,18 +83,11 @@ def group_norm(inputs, axis=-1, group=32, epsilon=1e-5, **kwargs):
The normalization is defined as:
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
It turns out to be **InstanceNorm**, if ``group`` is **0**,
Note that it turns out to be **InstanceNorm**, if ``group`` is **0**,
or **LayerNorm**, if ``group`` is **1**.
Note that the number of inputs should be **3**, i.e.,
this operators is implemented into the fused version.
However, you can still fix the ``gamma`` and ``beta``,
by disabling the their gradients directly.
Parameters
----------
inputs : Sequence[dragon.Tensor]
......@@ -141,14 +126,34 @@ def instance_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
The normalization is defined as:
.. math::
\text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: \text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Parameters
----------
inputs : Sequence[dragon.Tensor]
The tensor ``x``, ``gamma`` and ``beta``.
axis : int, optional, default=-1
The channel axis.
epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`.
Note that the number of inputs should be **3**, i.e.,
this operators is implemented into the fused version.
Returns
-------
dragon.Tensor
The output tensor.
"""
return group_norm(inputs, axis=axis, group=0, epsilon=epsilon, **kwargs)
However, you can still fix the **gamma** and **beta**,
by disabling the their gradients directly.
@OpSchema.num_inputs(3)
def layer_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
r"""Apply the layer normalization.
`[Ba et.al, 2016] <https://arxiv.org/abs/1607.06450>`_
The normalization is defined as:
.. math:: \text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Parameters
----------
......@@ -165,7 +170,7 @@ def instance_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
The output tensor.
"""
return group_norm(inputs, axis=axis, group=0, epsilon=epsilon, **kwargs)
return group_norm(inputs, axis=axis, group=1, epsilon=epsilon, **kwargs)
@OpSchema.num_inputs(1)
......@@ -238,40 +243,6 @@ def lp_normalize(inputs, axis=None, p=2, epsilon=1e-12, reduction='sum', **kwarg
return op_lib.blend(**args)
@OpSchema.num_inputs(3)
def layer_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
r"""Apply the layer normalization.
`[Ba et.al, 2016] <https://arxiv.org/abs/1607.06450>`_
The normalization is defined as:
.. math::
\text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Note that the number of inputs should be *3*, i.e.,
this operators is implemented into the fused version.
However, you can still fix the *gamma* and *beta*,
by disabling the their gradients directly.
Parameters
----------
inputs : Sequence[dragon.Tensor]
The tensor ``x``, ``gamma`` and ``beta``.
axis : int, optional, default=-1
The channel axis.
epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`.
Returns
-------
dragon.Tensor
The output tensor.
"""
return group_norm(inputs, axis=axis, group=1, epsilon=epsilon, **kwargs)
@OpSchema.num_inputs(1)
def local_response_norm(
inputs,
......@@ -347,19 +318,11 @@ def sync_batch_norm(
The normalization is defined as:
.. math::
\text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: \text{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The running average of statistics are calculated as:
.. math::
x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}}
Note that the number of inputs should be **5**, i.e.,
this operators is implemented into the fused version.
However, you can still fix the ``gamma`` and ``beta``,
by disabling the their gradients directly.
.. math:: x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}}
Parameters
----------
......
......@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class BatchNorm(Operator):
"""BatchNorm operator."""
def __init__(self, key, dev, **kwargs):
super(BatchNorm, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
......@@ -43,6 +45,8 @@ class BatchNorm(Operator):
class GroupNorm(Operator):
"""GroupNorm operator."""
def __init__(self, key, dev, **kwargs):
super(GroupNorm, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
......@@ -64,6 +68,8 @@ class GroupNorm(Operator):
class LpNormalize(Operator):
"""LpNormalize operator."""
def __init__(self, key, dev, **kwargs):
super(LpNormalize, self).__init__(key, dev, **kwargs)
self.p = kwargs.get('p', 2)
......@@ -89,6 +95,8 @@ class LpNormalize(Operator):
class LocalResponseNorm(Operator):
"""LocalResponseNorm operator."""
def __init__(self, key, dev, **kwargs):
super(LocalResponseNorm, self).__init__(key, dev, **kwargs)
self.size = kwargs.get('size', 5)
......@@ -114,6 +122,8 @@ class LocalResponseNorm(Operator):
class SyncBatchNorm(BatchNorm):
"""SyncBatchNorm operator."""
def __init__(self, key, dev, **kwargs):
super(SyncBatchNorm, self).__init__(key, dev, **kwargs)
self.process_group = kwargs.get('process_group', None)
......
......@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class LSTMCell(Operator):
"""LSTMCell operator."""
def __init__(self, key, dev, **kwargs):
super(LSTMCell, self).__init__(key, dev, **kwargs)
......@@ -30,6 +32,8 @@ class LSTMCell(Operator):
class Recurrent(Operator):
"""Recurrent operator."""
def __init__(self, key, dev, **kwargs):
super(Recurrent, self).__init__(key, dev, **kwargs)
self.mode = kwargs.get('mode', 'rnn_tanh')
......@@ -58,6 +62,8 @@ class Recurrent(Operator):
class RNNParamSet(Operator):
"""RNNParamSet operator."""
def __init__(self, key, dev, **kwargs):
super(RNNParamSet, self).__init__(key, dev, **kwargs)
self.param_type = kwargs.get('param_type', 'matrix')
......
......@@ -18,6 +18,8 @@ from dragon.core.framework.ops import Operator
class ParamUpdate(Operator):
"""ParamUpdate operator."""
def __init__(self, key, dev, **kwargs):
super(ParamUpdate, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
......
......@@ -17,9 +17,11 @@ from __future__ import print_function
from dragon.core.framework.ops import Operator
class _ConvNd(Operator):
class ConvNd(Operator):
"""ConvNd operator."""
def __init__(self, key, dev, **kwargs):
super(_ConvNd, self).__init__(key, dev, **kwargs)
super(ConvNd, self).__init__(key, dev, **kwargs)
self.num_output = kwargs.get('dim_out', 1)
self.kernel_shape = kwargs.get('kernel_shape', 1)
self.strides = kwargs.get('strides', 1)
......@@ -46,9 +48,11 @@ class _ConvNd(Operator):
return self.dispatch(inputs, [self.alloc()])
class _PoolNd(Operator):
class PoolNd(Operator):
"""PoolNd operator."""
def __init__(self, key, dev, **kwargs):
super(_PoolNd, self).__init__(key, dev, **kwargs)
super(PoolNd, self).__init__(key, dev, **kwargs)
self.kernel_shape = kwargs.get('kernel_shape', 1)
self.strides = kwargs.get('strides', 1)
self.pads = kwargs.get('pads', 0)
......@@ -78,6 +82,8 @@ class _PoolNd(Operator):
class BiasAdd(Operator):
"""BiasAdd operator."""
def __init__(self, key, dev, **kwargs):
super(BiasAdd, self).__init__(key, dev, **kwargs)
self.data_format = kwargs.get('data_format', 'NCHW')
......@@ -93,12 +99,16 @@ class BiasAdd(Operator):
return self.dispatch(inputs, outputs)
class Conv2d(_ConvNd):
class Conv2d(ConvNd):
"""Conv2d operator."""
def __init__(self, key, dev, **kwargs):
super(Conv2d, self).__init__(key, dev, **kwargs)
class ConvTranspose2d(_ConvNd):
class ConvTranspose2d(ConvNd):
"""ConvTranspose2d operator."""
def __init__(self, key, dev, **kwargs):
super(ConvTranspose2d, self).__init__(key, dev, **kwargs)
self.output_padding = kwargs.get('output_padding', None)
......@@ -121,6 +131,8 @@ class ConvTranspose2d(_ConvNd):
class DepthToSpace(Operator):
"""DepthToSpace operator."""
def __init__(self, key, dev, **kwargs):
super(DepthToSpace, self).__init__(key, dev, **kwargs)
self.block_size = kwargs.get('block_size', '2')
......@@ -139,17 +151,23 @@ class DepthToSpace(Operator):
return self.dispatch(inputs, [self.alloc()])
class DepthwiseConv2d(_ConvNd):
class DepthwiseConv2d(ConvNd):
"""DepthwiseConv2d operator."""
def __init__(self, key, dev, **kwargs):
super(DepthwiseConv2d, self).__init__(key, dev, **kwargs)
class Pool2d(_PoolNd):
class Pool2d(PoolNd):
"""Pool2d operator."""
def __init__(self, key, dev, **kwargs):
super(Pool2d, self).__init__(key, dev, **kwargs)
class Resize(Operator):
"""Resize operator."""
def __init__(self, key, dev, **kwargs):
super(Resize, self).__init__(key, dev, **kwargs)
self.num_sizes = kwargs.get('num_sizes', 0)
......@@ -193,6 +211,8 @@ class Resize(Operator):
class RoiAlign(Operator):
"""RoiAlign operator."""
def __init__(self, key, dev, **kwargs):
super(RoiAlign, self).__init__(key, dev, **kwargs)
self.pooled_h = kwargs.get('pooled_h', 0)
......@@ -216,6 +236,8 @@ class RoiAlign(Operator):
class RoiPool(Operator):
"""RoiPool operator."""
def __init__(self, key, dev, **kwargs):
super(RoiPool, self).__init__(key, dev, **kwargs)
self.pooled_h = kwargs.get('pooled_h', 7)
......@@ -237,6 +259,8 @@ class RoiPool(Operator):
class SpaceToDepth(Operator):
"""SpaceToDepth operator."""
def __init__(self, key, dev, **kwargs):
super(SpaceToDepth, self).__init__(key, dev, **kwargs)
self.block_size = kwargs.get('block_size', '2')
......
......@@ -30,6 +30,20 @@ def dropout_exporter(op_def, shape_dict, ws):
return node, const_tensors
@exporter.register('HardSigmoid')
def hardsigmoid_exporter(op_def, shape_dict, ws):
node, const_tensors = exporter.translate(**locals())
alpha, beta = 0.2, 0.5
for arg in op_def.arg:
if arg.name == 'alpha':
alpha = arg.f
elif arg.name == 'beta':
beta = arg.f
helper.add_attribute(node, 'alpha', alpha)
helper.add_attribute(node, 'beta', beta)
return node, const_tensors
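The exporter falls back to alpha=0.2 and beta=0.5 when the attributes are absent, matching the operator defaults. The same attribute-defaulting pattern could be factored into a small helper; a hypothetical sketch, not part of this commit:

```python
def read_float_args(op_def, defaults):
    """Collect float arguments from an op_def, keeping defaults for missing ones."""
    values = dict(defaults)
    for arg in op_def.arg:
        if arg.name in values:
            values[arg.name] = arg.f
    return values

# e.g. read_float_args(op_def, {'alpha': 0.2, 'beta': 0.5})
```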
@exporter.register('PRelu')
def prelu_exporter(op_def, shape_dict, ws):
node, const_tensors = exporter.translate(**locals())
......
......@@ -84,6 +84,47 @@ void EluGrad(
T* dx,
Context* ctx);
/* activation.hardsigmoid */
template <typename T, class Context>
void HardSigmoid(
const int count,
const float alpha,
const float beta,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
void HardSigmoidGrad(
const int count,
const float alpha,
const T* dy,
const T* y,
T* dx,
Context* ctx);
/* activation.hardswish */
template <typename T, class Context>
void HardSwish(
const int count,
const float alpha,
const float beta,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
void HardSwishGrad(
const int count,
const float alpha,
const float beta,
const T* dy,
const T* x,
T* dx,
Context* ctx);
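The kernels declared above reduce to simple elementwise math. As a point of reference only (the real implementations are specialized C++/CUDA kernels; these helper names are hypothetical), a NumPy sketch consistent with the unit tests later in this commit:

```python
import numpy as np

def hardsigmoid_ref(x, alpha, beta):
    return np.clip(alpha * x + beta, 0.0, 1.0)

def hardsigmoid_grad_ref(dy, y, alpha):
    # Gradient passes through only where the output lies strictly inside (0, 1).
    return dy * alpha * ((y > 0) & (y < 1))

def hardswish_ref(x, alpha, beta):
    return x * np.clip(alpha * x + beta, 0.0, 1.0)

def hardswish_grad_ref(dy, x, alpha, beta):
    bound = beta / alpha
    dydx = np.where(x < bound, 2.0 * alpha * x + beta, 1.0)  # interior slope
    dydx = np.where(x < -bound, 0.0, dydx)                   # saturated region
    return dy * dydx
```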
/* activation.prelu */
template <typename T, class Context>
......@@ -185,8 +226,8 @@ void SigmoidGrad(const int count, const T* dy, const T* y, T* dx, Context* ctx);
template <typename T, class Context>
void Softmax(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* x,
T* y,
Context* ctx);
......@@ -194,13 +235,27 @@ void Softmax(
template <typename T, class Context>
void SoftmaxGrad(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* dy,
const T* y,
T* dx,
Context* ctx);
/* activation.swish */
template <typename T, class Context>
void Swish(const int count, const T* x, T* y, Context* ctx);
template <typename T, class Context>
void SwishGrad(
const int count,
const T* dy,
const T* x,
const T* y,
T* dx,
Context* ctx);
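Likewise for Swish, an illustrative NumPy sketch of the declared forward/backward pair (not the actual kernels; names are hypothetical):

```python
import numpy as np

def swish_ref(x):
    return x / (1.0 + np.exp(-x))    # x * sigmoid(x)

def swish_grad_ref(dy, x, y):
    s = 1.0 / (1.0 + np.exp(-x))     # sigmoid(x)
    return dy * (s + y * (1.0 - s))  # d/dx [x * sigmoid(x)]
```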
/* activation.tanh */
template <typename T, class Context>
......@@ -236,8 +291,8 @@ void ArgMin(
template <typename T, class Context>
void ChannelAffine(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* x,
const T* w,
const T* b,
......@@ -275,8 +330,8 @@ void ChannelShuffle(
template <typename T, class Context>
void CumSum(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const bool exclusive,
const bool reverse,
const T* x,
......@@ -296,7 +351,7 @@ void IndexSelect(
const int inner_dim,
const int axis_dim,
const int select_dim,
const int64_t* indices,
const int64_t* index,
const T* x,
T* y,
Context* ctx);
......@@ -529,7 +584,7 @@ void TopSelect(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int topk,
const int select_dim,
const int largest,
const T* x,
T* value,
......@@ -585,8 +640,8 @@ void ReduceLossGrad(
template <typename T, class Context>
void BroadcastLossGrad(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* dy,
T* dx,
Context* ctx);
......@@ -596,10 +651,10 @@ void BroadcastLossGrad(
template <typename LogitType, typename TargetType, class Context>
void NLLLoss(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* log_prob,
const LogitType* logit,
const TargetType* target,
LogitType* loss,
LogitType* mask,
......@@ -608,12 +663,12 @@ void NLLLoss(
template <typename LogitType, typename TargetType, class Context>
void NLLLossGrad(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* log_prob,
const LogitType* logit,
const TargetType* target,
LogitType* dx,
LogitType* dlogit,
LogitType* mask,
Context* ctx);
......@@ -642,12 +697,12 @@ void SigmoidCrossEntropyGrad(
template <typename LogitType, typename TargetType, class Context>
void SigmoidFocalLoss(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float pos_alpha,
const float neg_alpha,
const float gamma,
const int neg_id,
const int negative_index,
const LogitType* logit,
const TargetType* target,
LogitType* loss,
......@@ -657,8 +712,8 @@ void SigmoidFocalLoss(
template <typename LogitType, typename TargetType, class Context>
void SigmoidFocalLossGrad(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const float pos_alpha,
const float neg_alpha,
const float gamma,
......@@ -693,8 +748,8 @@ template <typename T, class Context>
void SoftmaxCrossEntropy(
const int count,
const T* prob,
const T* targets,
T* losses,
const T* target,
T* loss,
Context* ctx);
/* loss.sparse_softmax_cross_entropy */
......@@ -702,8 +757,8 @@ void SoftmaxCrossEntropy(
template <typename LogitType, typename TargetType, class Context>
void SparseSoftmaxCrossEntropy(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* prob,
const TargetType* target,
......@@ -714,8 +769,8 @@ void SparseSoftmaxCrossEntropy(
template <typename LogitType, typename TargetType, class Context>
void SparseSoftmaxCrossEntropyGrad(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const int ignore_index,
const LogitType* prob,
const TargetType* target,
......@@ -907,8 +962,8 @@ void GroupNormBackward(
template <typename T, class Context>
void L1Normalize(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const T* x,
......@@ -918,8 +973,8 @@ void L1Normalize(
template <typename T, class Context>
void L1NormalizeGrad(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const T* dy,
......@@ -930,8 +985,8 @@ void L1NormalizeGrad(
template <typename T, class Context>
void L2Normalize(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const T* x,
......@@ -941,8 +996,8 @@ void L2Normalize(
template <typename T, class Context>
void L2NormalizeGrad(
const int outer_dim,
const int reduce_dim,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const T* dy,
......@@ -1030,8 +1085,8 @@ void SGDUpdate(
template <typename T, class Context>
void BiasAdd(
const int outer_dim,
const int axis_dim,
const int inner_dim,
const int axis_dim,
const T* x,
const T* b,
T* y,
......
......@@ -16,11 +16,13 @@ from __future__ import print_function as _print_function
from dragon.vm.tensorflow.core.keras.activations import elu
from dragon.vm.tensorflow.core.keras.activations import get
from dragon.vm.tensorflow.core.keras.activations import exponential
from dragon.vm.tensorflow.core.keras.activations import hard_sigmoid
from dragon.vm.tensorflow.core.keras.activations import linear
from dragon.vm.tensorflow.core.keras.activations import relu
from dragon.vm.tensorflow.core.keras.activations import selu
from dragon.vm.tensorflow.core.keras.activations import sigmoid
from dragon.vm.tensorflow.core.keras.activations import softmax
from dragon.vm.tensorflow.core.keras.activations import swish
from dragon.vm.tensorflow.core.keras.activations import tanh
__all__ = [_s for _s in dir() if not _s.startswith('_')]
......@@ -41,5 +41,6 @@ from dragon.vm.tensorflow.core.ops.nn import sigmoid_cross_entropy_with_logits
from dragon.vm.tensorflow.core.ops.nn import softmax
from dragon.vm.tensorflow.core.ops.nn import softmax_cross_entropy_with_logits
from dragon.vm.tensorflow.core.ops.nn import sparse_softmax_cross_entropy_with_logits
from dragon.vm.tensorflow.core.ops.nn import swish
__all__ = [_s for _s in dir() if not _s.startswith('_')]
......@@ -78,7 +78,7 @@ def name_scope(name):
def device(device_name):
"""Context-manager to nest the the device spec.
"""Context-manager to nest the device spec.
Examples:
......
......@@ -13,6 +13,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.ops import activation_ops as _activation_ops
from dragon.core.util import six
from dragon.vm.tensorflow.core.keras.utils import generic_utils
from dragon.vm.tensorflow.core.ops import math_ops
......@@ -83,6 +84,34 @@ def exponential(x):
return math_ops.exp(x)
def hard_sigmoid(x, **kwargs):
r"""Apply the hard sigmoid function to input.
The **HardSigmoid** function is defined as:
.. math:: \text{HardSigmoid}(x) = \max(0, \min(1, 0.2 * x + 0.5))
Examples:
```python
x = tf.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(tf.keras.activations.hard_sigmoid(x, inplace=False))
```
Parameters
----------
x : dragon.Tensor
The input tensor.
Returns
-------
dragon.Tensor
The output tensor.
"""
return _activation_ops.hardsigmoid(x, **kwargs)
def linear(x):
r"""Apply the linear activation to input.
......@@ -240,6 +269,35 @@ def softmax(x, axis=-1, **kwargs):
return nn.softmax(x, axis=axis, **kwargs)
def swish(x):
r"""Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
The **Swish** function is defined as:
.. math:: \text{Swish}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Examples:
```python
x = tf.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(tf.keras.activations.swish(x))
```
Parameters
----------
x : dragon.Tensor
The input tensor.
Returns
-------
dragon.Tensor
The output tensor.
"""
return nn.swish(x)
def tanh(x, **kwargs):
r"""Apply the tanh function to input.
......
......@@ -14,6 +14,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.ops import activation_ops
from dragon.core.ops import array_ops
from dragon.core.ops import normalization_ops
......@@ -182,3 +183,32 @@ def moments(x, axes=None, keepdims=False, name=None):
"""
return array_ops.moments(x, axis=axes, keep_dims=keepdims, name=name)
def swish(features):
r"""Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
The **Swish** function is defined as:
.. math:: \text{Swish}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Examples:
```python
x = tf.constant([-2.5, -1.0, 0.0, 1.0, 2.5])
print(tf.nn.swish(x))
```
Parameters
----------
features : dragon.Tensor
The input tensor.
Returns
-------
dragon.Tensor
The output tensor.
"""
return activation_ops.swish(features)
......@@ -188,6 +188,49 @@ class TestActivationOps(OpTestCase):
with dragon.device('cuda'), self.cudnn_ws.as_default():
self.test_elu()
def test_hardsigmoid(self):
alpha, beta = 0.2, 0.5
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], 'float32')
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.nn.hardsigmoid(x, alpha=alpha, beta=beta)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
result = np.clip(alpha * data + beta, 0, 1)
self.assertEqual([y, dx],
[result, (result > 0) * (result < 1) * data * alpha])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_hardsigmoid_cuda(self):
with dragon.device('cuda'):
self.test_hardsigmoid()
def test_hardswish(self):
alpha, beta = 0.2, 0.5
bound = beta / alpha
alpha2x = alpha * 2.
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], 'float32')
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.nn.hardswish(x, alpha=alpha, beta=beta)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
result = data * np.clip(alpha * data + beta, 0, 1)
result2 = data.copy()
inds = np.where(data < bound)[0]
result2[inds] = data[inds] * (data[inds] * alpha2x + beta)
result2[np.where(data < -bound)[0]] = 0
self.assertEqual([y, dx], [result, result2])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_hardswish_cuda(self):
with dragon.device('cuda'):
self.test_hardswish()
def test_leaky_relu(self):
alpha = 0.2
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
......@@ -370,6 +413,24 @@ class TestActivationOps(OpTestCase):
with dragon.device('cuda'), self.cudnn_ws.as_default():
self.test_softmax()
def test_swish(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], 'float32')
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.nn.swish(x)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
result = data * (1. / (1. + np.exp(-data)))
result2 = data * (result + (1. / (1. + np.exp(-data))) * (1. - result))
self.assertEqual([y, dx], [result, result2])
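The expected gradient above relies on the identity ``s + y * (1 - s) == y + s * (1 - y)`` for ``s = sigmoid(x)`` and ``y = x * s`` (both expand to ``s + y - s * y``); a quick, purely illustrative NumPy check:

```python
import numpy as np

x = np.linspace(-3., 3., 7, dtype='float32')
s = 1. / (1. + np.exp(-x))  # sigmoid(x)
y = x * s                   # swish(x)
assert np.allclose(s + y * (1. - s), y + s * (1. - y))
```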
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_swish_cuda(self):
with dragon.device('cuda'):
self.test_swish()
def test_tanh(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
......
......@@ -354,6 +354,24 @@ class TestModules(OpTestCase):
m.reset_parameters()
_ = repr(m)
def test_hardsigmoid(self):
alpha, beta = 1.0 / 6.0, 0.5
data = np.array([-3., -2., -1., 0., 1., 2., 3], 'float32')
x = new_tensor(data)
m = torch.nn.Hardsigmoid(inplace=True)
y, _ = m(x), repr(m)
result = np.clip(alpha * data + beta, 0, 1)
self.assertEqual(y, result)
def test_hardswish(self):
alpha, beta = 1.0 / 6.0, 0.5
data = np.array([-3., -2., -1., 0., 1., 2., 3], 'float32')
x = new_tensor(data)
m = torch.nn.Hardswish()
y, _ = m(x), repr(m)
result = data * np.clip(alpha * data + beta, 0, 1)
self.assertEqual(y, result)
def test_leaky_relu(self):
alpha = 0.2
data = np.array([-1., 0., 1.], 'float32')
......@@ -553,6 +571,14 @@ class TestModules(OpTestCase):
y, _ = m(x), repr(m)
self.assertEqual(y, data)
def test_swish(self):
data = np.array([-3., -2., -1., 0., 1., 2., 3], 'float32')
x = new_tensor(data)
m = torch.nn.Swish()
y, _ = m(x), repr(m)
result = data * (1. / (1. + np.exp(-data)))
self.assertEqual(y, result)
def test_tanh(self):
data = np.array([0.2, 0.4, 0.6, 0.8, 1.], 'float32')
x = new_tensor(data)
......
......@@ -210,6 +210,16 @@ class TestTensorOps(OpTestCase):
data.fill(value)
self.assertEqual(x, data)
def test_full(self):
entries = [((2, 3), 1), ((2, 3), 1.)]
for shape, value in entries:
data = np.zeros(shape)
x = torch.full((1,), 0).new_full(shape, value)
data.fill(value)
self.assertEqual(x, data)
self.assertEqual(torch.empty(1).new_ones(shape), np.ones(shape))
self.assertEqual(torch.empty(1).new_zeros(shape), np.zeros(shape))
def test_flatten(self):
data = arange((1, 2, 3))
x = new_tensor(data)
......@@ -377,6 +387,8 @@ class TestTensorOps(OpTestCase):
data = np.array([-1., 0., 1.], 'float32')
x = new_tensor(data)
self.assertEqual(-x, -data)
x.neg_()
self.assertEqual(x, -data)
def test_non_zero(self):
data = arange((2, 3))
......
......@@ -53,6 +53,7 @@ class TestTensor(unittest.TestCase):
self.assertEqual(torch.Tensor([0]).dim(), 1)
self.assertEqual(float(torch.Tensor(1).one_()), 1.)
self.assertEqual(torch.empty(2, 3).ndimension(), 2)
self.assertEqual(torch.empty(3).new_empty(2, 3).ndimension(), 2)
self.assertEqual(repr(torch.tensor(1)), '1')
self.assertNotEqual(a.__hash__(), b.__hash__())
self.assertNotEqual(a.__repr__(), b.__repr__())
......
......@@ -82,6 +82,8 @@ from dragon.vm.torch.core.ops.array.functional import unsqueeze
from dragon.vm.torch.core.ops.array.functional import where
from dragon.vm.torch.core.ops.init.functional import arange
from dragon.vm.torch.core.ops.init.functional import eye
from dragon.vm.torch.core.ops.init.functional import full
from dragon.vm.torch.core.ops.init.functional import full_like
from dragon.vm.torch.core.ops.init.functional import linspace
from dragon.vm.torch.core.ops.init.functional import ones
from dragon.vm.torch.core.ops.init.functional import ones_like
......
......@@ -21,6 +21,8 @@ from dragon.vm.torch._api.nn import init
# Classes
from dragon.vm.torch.core.nn.modules.activation import ELU
from dragon.vm.torch.core.nn.modules.activation import GumbelSoftmax
from dragon.vm.torch.core.nn.modules.activation import Hardsigmoid
from dragon.vm.torch.core.nn.modules.activation import Hardswish
from dragon.vm.torch.core.nn.modules.activation import LeakyReLU
from dragon.vm.torch.core.nn.modules.activation import LogSoftmax
from dragon.vm.torch.core.nn.modules.activation import PReLU
......@@ -29,6 +31,7 @@ from dragon.vm.torch.core.nn.modules.activation import ReLU6
from dragon.vm.torch.core.nn.modules.activation import SELU
from dragon.vm.torch.core.nn.modules.activation import Sigmoid
from dragon.vm.torch.core.nn.modules.activation import Softmax
from dragon.vm.torch.core.nn.modules.activation import Swish
from dragon.vm.torch.core.nn.modules.activation import Tanh
from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm1d
from dragon.vm.torch.core.nn.modules.batchnorm import BatchNorm2d
......
......@@ -27,6 +27,8 @@ from dragon.vm.torch.core.nn.functional import drop_path
from dragon.vm.torch.core.nn.functional import dropout
from dragon.vm.torch.core.nn.functional import elu
from dragon.vm.torch.core.nn.functional import group_norm
from dragon.vm.torch.core.nn.functional import hardsigmoid
from dragon.vm.torch.core.nn.functional import hardswish
from dragon.vm.torch.core.nn.functional import kl_div
from dragon.vm.torch.core.nn.functional import l1_loss
from dragon.vm.torch.core.nn.functional import leaky_relu
......@@ -47,6 +49,7 @@ from dragon.vm.torch.core.nn.functional import sigmoid
from dragon.vm.torch.core.nn.functional import sigmoid_focal_loss
from dragon.vm.torch.core.nn.functional import smooth_l1_loss
from dragon.vm.torch.core.nn.functional import softmax
from dragon.vm.torch.core.nn.functional import swish
from dragon.vm.torch.core.nn.functional import sync_batch_norm
from dragon.vm.torch.core.nn.functional import tanh
from dragon.vm.torch.core.nn.functional import upsample
......
......@@ -85,13 +85,11 @@ def batch_norm(
The normalization is defined as:
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The moving average of the stats is calculated as:
.. math::
x_{moving} \leftarrow (1 - momentum) * x_{moving} + momentum * x_{stat}
.. math:: x_{moving} \leftarrow (1 - momentum) * x_{moving} + momentum * x_{stat}
Parameters
----------
......@@ -617,8 +615,7 @@ def group_norm(input, weight, bias, groups=32, eps=1e-5):
The normalization is defined as:
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Parameters
----------
......@@ -648,6 +645,73 @@ def group_norm(input, weight, bias, groups=32, eps=1e-5):
.apply(input, weight, bias)
def hardsigmoid(input, inplace=False):
r"""Apply the hard sigmoid function to input.
The **HardSigmoid** function is defined as:
.. math::
\text{Hardsigmoid}(x) = \begin{cases}
0 & \text{if~} x \le -3, \\
1 & \text{if~} x \ge +3, \\
x / 6 + 1 / 2 & \text{otherwise}
\end{cases}
See Also
--------
`torch.nn.Hardsigmoid(...)`_
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
inplace : bool, optional, default=False
Whether to do the operation in-place.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return _functions.HardSigmoid \
.instantiate(input.device, alpha=1. / 6., beta=0.5) \
.apply(input, inplace=inplace)
def hardswish(input):
r"""Apply the hard swish function to input.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
The **HardSwish** function is defined as:
.. math::
\text{HardSwish}(x) = \begin{cases}
0 & \text{if~} x \le -3, \\
x & \text{if~} x \ge +3, \\
x \cdot (x + 3) / 6 & \text{otherwise}
\end{cases}
See Also
--------
`torch.nn.Hardswish(...)`_
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return _functions.HardSwish \
.instantiate(input.device, alpha=1. / 6., beta=0.5) \
.apply(input)
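A brief usage sketch for the two functions above, assuming the ``dragon.vm`` import path exposed by this repository; the numeric comments follow the default ``alpha=1/6``, ``beta=1/2``:

```python
from dragon.vm import torch

F = torch.nn.functional
x = torch.tensor([-4., -1., 0., 1., 4.])
y1 = F.hardsigmoid(x)  # approximately [0, 1/3, 1/2, 2/3, 1]
y2 = F.hardswish(x)    # approximately [0, -1/3, 0, 2/3, 4]
```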
def interpolate(
input,
size=None,
......@@ -1516,6 +1580,32 @@ def softmax(input, dim, inplace=False):
.apply(input, inplace=inplace)
def swish(input):
r"""Apply the swish function to input.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
The **Swish** function is defined as:
.. math:: \text{Swish}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.Swish(...)`_
"""
return _activation(input, False, 'Swish')
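And a matching usage sketch for ``swish`` (same assumed import path); elementwise it equals ``x * sigmoid(x)``:

```python
from dragon.vm import torch

F = torch.nn.functional
x = torch.tensor([-2., 0., 2.])
y = F.swish(x)  # approximately [-0.238, 0., 1.762]
```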
def sync_batch_norm(
input,
running_mean,
......@@ -1532,8 +1622,7 @@ def sync_batch_norm(
The normalization is defined as:
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The moving average of the stats is calculated as:
......
......@@ -19,6 +19,8 @@ from dragon.vm.torch.core.autograd import function
class _Activation(function.Function):
"""Base activation function class."""
def __init__(self, key, dev, **kwargs):
super(_Activation, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
......@@ -32,6 +34,8 @@ class _Activation(function.Function):
class _ConvNd(function.Function):
"""Base convolution function class."""
def __init__(self, key, dev, **kwargs):
super(_ConvNd, self).__init__(key, dev, **kwargs)
self.num_output = kwargs.get('out_channels', 1)
......@@ -62,6 +66,8 @@ class _ConvNd(function.Function):
class _Loss(function.Function):
"""Base loss function class."""
def __init__(self, key, dev, **kwargs):
super(_Loss, self).__init__(key, dev, **kwargs)
self.reduction = kwargs.get('reduction', 'mean').upper()
......@@ -71,6 +77,8 @@ class _Loss(function.Function):
class _PoolNd(function.Function):
"""Base pooling function class."""
def __init__(self, key, dev, **kwargs):
super(_PoolNd, self).__init__(key, dev, **kwargs)
self.kernel_shape = kwargs.get('kernel_shape', 1)
......@@ -99,6 +107,8 @@ class _PoolNd(function.Function):
class BatchNorm(function.Function):
"""BatchNorm function."""
def __init__(self, key, dev, **kwargs):
super(BatchNorm, self).__init__(key, dev, **kwargs)
self.momentum = kwargs.get('momentum', 0.1)
......@@ -122,16 +132,22 @@ class BatchNorm(function.Function):
class Conv2d(_ConvNd):
"""Conv2d function."""
def __init__(self, key, dev, **kwargs):
super(Conv2d, self).__init__(key, dev, **kwargs)
class ConvTranspose2d(_ConvNd):
"""ConvTranspose2d function."""
def __init__(self, key, dev, **kwargs):
super(ConvTranspose2d, self).__init__(key, dev, **kwargs)
class CTCLoss(_Loss):
"""CTCLoss function."""
def __init__(self, key, dev, **kwargs):
super(CTCLoss, self).__init__(key, dev, **kwargs)
self.padding_mask = kwargs.get('padding_mask', -1)
......@@ -147,11 +163,15 @@ class CTCLoss(_Loss):
class DepthwiseConv2d(_ConvNd):
"""DepthwiseConv2d function."""
def __init__(self, key, dev, **kwargs):
super(DepthwiseConv2d, self).__init__(key, dev, **kwargs)
class DropBlock2d(_Activation):
"""DropBlock2d function."""
def __init__(self, key, dev, **kwargs):
super(DropBlock2d, self).__init__(key, dev, **kwargs)
self.block_size = kwargs.get('block_size', 7)
......@@ -173,6 +193,8 @@ class DropBlock2d(_Activation):
class Dropout(_Activation):
"""Dropout function."""
def __init__(self, key, dev, **kwargs):
super(Dropout, self).__init__(key, dev, **kwargs)
self.p = kwargs.get('p', 0.5)
......@@ -182,6 +204,8 @@ class Dropout(_Activation):
class DropPath(_Activation):
"""DropPath function."""
def __init__(self, key, dev, **kwargs):
super(DropPath, self).__init__(key, dev, **kwargs)
self.p = kwargs.get('p', 0.2)
......@@ -198,6 +222,8 @@ class DropPath(_Activation):
class Elu(_Activation):
"""ELU function."""
def __init__(self, key, dev, **kwargs):
super(Elu, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.)
......@@ -211,70 +237,67 @@ class Elu(_Activation):
}
class Linear(function.Function):
class GroupNorm(function.Function):
"""GroupNorm function."""
def __init__(self, key, dev, **kwargs):
super(Linear, self).__init__(key, dev, **kwargs)
super(GroupNorm, self).__init__(key, dev, **kwargs)
self.group = kwargs.get('group', 32)
self.epsilon = kwargs.get('epsilon', 1e-5)
def attributes(self):
return {
'op_type': 'FullyConnected',
'op_type': 'GroupNorm',
'arguments': {
'axis': -1,
'transW': True,
},
'axis': 1,
'group': self.group,
'epsilon': self.epsilon,
}
}
def forward(self, input, weight, bias=None, out=None):
inputs = [input, weight] + ([bias] if bias else [])
outputs = [out] if out else [self.alloc()]
return self.dispatch(inputs, outputs)
def forward(self, input, weight, bias):
return self.dispatch([input, weight, bias], [self.alloc()])
class LocalResponseNorm(function.Function):
class HardSigmoid(_Activation):
"""HardSigmoid function."""
def __init__(self, key, dev, **kwargs):
super(LocalResponseNorm, self).__init__(key, dev, **kwargs)
self.size = kwargs.get('size', 5)
self.alpha = kwargs.get('alpha', 0.0001)
self.beta = kwargs.get('beta', 0.75)
self.bias = kwargs.get('bias', 1.)
super(HardSigmoid, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 0.2)
self.beta = kwargs.get('beta', 0.5)
def attributes(self):
return {
'op_type': 'LRN',
'op_type': 'HardSigmoid',
'arguments': {
'size': self.size,
'alpha': self.alpha,
'beta': self.beta,
'bias': self.bias,
'data_format': 'NCHW',
}
'alpha': float(self.alpha),
'beta': float(self.beta),
},
}
def forward(self, input):
return self.dispatch([input], [self.alloc()])
class HardSwish(_Activation):
"""HardSwish function."""
class GroupNorm(function.Function):
def __init__(self, key, dev, **kwargs):
super(GroupNorm, self).__init__(key, dev, **kwargs)
self.group = kwargs.get('group', 32)
self.epsilon = kwargs.get('epsilon', 1e-5)
super(HardSwish, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 0.2)
self.beta = kwargs.get('beta', 0.5)
def attributes(self):
return {
'op_type': 'GroupNorm',
'op_type': 'HardSwish',
'arguments': {
'axis': 1,
'group': self.group,
'epsilon': self.epsilon,
}
'alpha': float(self.alpha),
'beta': float(self.beta),
},
}
def forward(self, input, weight, bias):
return self.dispatch([input, weight, bias], [self.alloc()])
class L1Loss(_Loss):
"""L1Loss function."""
def __init__(self, key, dev, **kwargs):
super(L1Loss, self).__init__(key, dev, **kwargs)
......@@ -289,6 +312,8 @@ class L1Loss(_Loss):
class L2Loss(_Loss):
"""L2Loss function."""
def __init__(self, key, dev, **kwargs):
super(L2Loss, self).__init__(key, dev, **kwargs)
......@@ -302,7 +327,56 @@ class L2Loss(_Loss):
}
class Linear(function.Function):
"""Linear function."""
def __init__(self, key, dev, **kwargs):
super(Linear, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'FullyConnected',
'arguments': {
'axis': -1,
'transW': True,
},
}
def forward(self, input, weight, bias=None, out=None):
inputs = [input, weight] + ([bias] if bias else [])
outputs = [out] if out else [self.alloc()]
return self.dispatch(inputs, outputs)
class LocalResponseNorm(function.Function):
"""LocalResponseNorm function."""
def __init__(self, key, dev, **kwargs):
super(LocalResponseNorm, self).__init__(key, dev, **kwargs)
self.size = kwargs.get('size', 5)
self.alpha = kwargs.get('alpha', 0.0001)
self.beta = kwargs.get('beta', 0.75)
self.bias = kwargs.get('bias', 1.)
def attributes(self):
return {
'op_type': 'LRN',
'arguments': {
'size': self.size,
'alpha': self.alpha,
'beta': self.beta,
'bias': self.bias,
'data_format': 'NCHW',
}
}
def forward(self, input):
return self.dispatch([input], [self.alloc()])
class LpNormalize(function.Function):
"""LpNormalize function."""
def __init__(self, key, dev, **kwargs):
super(LpNormalize, self).__init__(key, dev, **kwargs)
self.p = kwargs.get('p', 2)
......@@ -326,6 +400,8 @@ class LpNormalize(function.Function):
class LSTMCell(function.Function):
"""LSTMCell function."""
def __init__(self, key, dev, **kwargs):
super(LSTMCell, self).__init__(key, dev, **kwargs)
......@@ -338,6 +414,8 @@ class LSTMCell(function.Function):
class NLLLoss(_Loss):
"""NLLLoss function."""
def __init__(self, key, dev, **kwargs):
super(NLLLoss, self).__init__(key, dev, **kwargs)
self.ignore_index = kwargs.get('ignore_index', None)
......@@ -354,6 +432,8 @@ class NLLLoss(_Loss):
class Pad(function.Function):
"""Pad function."""
def __init__(self, key, dev, **kwargs):
super(Pad, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -390,11 +470,15 @@ class Pad(function.Function):
class Pool2d(_PoolNd):
"""Pool2d function."""
def __init__(self, key, dev, **kwargs):
super(Pool2d, self).__init__(key, dev, **kwargs)
class PRelu(function.Function):
"""PRelu function."""
def __init__(self, key, dev, **kwargs):
super(PRelu, self).__init__(key, dev, **kwargs)
......@@ -411,6 +495,8 @@ class PRelu(function.Function):
class Recurrent(function.Function):
"""Recurrent function."""
def __init__(self, key, dev, **kwargs):
super(Recurrent, self).__init__(key, dev, **kwargs)
self.mode = kwargs.get('mode', 'rnn_tanh')
......@@ -447,6 +533,8 @@ class Recurrent(function.Function):
class Relu(_Activation):
"""Relu function."""
def __init__(self, key, dev, **kwargs):
super(Relu, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 0.)
......@@ -461,6 +549,8 @@ class Relu(_Activation):
class Relu6(_Activation):
"""Relu6 function."""
def __init__(self, key, dev, **kwargs):
super(Relu6, self).__init__(key, dev, **kwargs)
......@@ -474,6 +564,8 @@ class Relu6(_Activation):
class Resize(function.Function):
"""Resize function."""
def __init__(self, key, dev, **kwargs):
super(Resize, self).__init__(key, dev, **kwargs)
self.num_sizes = kwargs.get('num_sizes', 0)
......@@ -518,6 +610,8 @@ class Resize(function.Function):
class RNNParamSet(function.Function):
"""RNNParamSet function."""
def __init__(self, key, dev, **kwargs):
super(RNNParamSet, self).__init__(key, dev, **kwargs)
self.param_type = kwargs.get('param_type', 'matrix')
......@@ -549,6 +643,8 @@ class RNNParamSet(function.Function):
class SigmoidCrossEntropy(_Loss):
"""SigmoidCrossEntropy function."""
def __init__(self, key, dev, **kwargs):
super(SigmoidCrossEntropy, self).__init__(key, dev, **kwargs)
......@@ -562,6 +658,8 @@ class SigmoidCrossEntropy(_Loss):
class SigmoidFocalLoss(_Loss):
"""SigmoidFocalLoss function."""
def __init__(self, key, dev, **kwargs):
super(SigmoidFocalLoss, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 0.25)
......@@ -582,6 +680,8 @@ class SigmoidFocalLoss(_Loss):
class SmoothL1Loss(_Loss):
"""SmoothL1Loss function."""
def __init__(self, key, dev, **kwargs):
super(SmoothL1Loss, self).__init__(key, dev, **kwargs)
self.beta = kwargs.get('beta', 1.)
......@@ -597,6 +697,8 @@ class SmoothL1Loss(_Loss):
class Softmax(_Activation):
"""Softmax function."""
def __init__(self, key, dev, **kwargs):
super(Softmax, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
......@@ -611,6 +713,8 @@ class Softmax(_Activation):
class SparseSoftmaxCrossEntropy(_Loss):
"""SparseSoftmaxCrossEntropy function."""
def __init__(self, key, dev, **kwargs):
super(SparseSoftmaxCrossEntropy, self).__init__(key, dev, **kwargs)
self.ignore_index = kwargs.get('ignore_index', None)
......@@ -627,6 +731,8 @@ class SparseSoftmaxCrossEntropy(_Loss):
class SyncBatchNorm(BatchNorm):
"""SyncBatchNorm function."""
def __init__(self, key, dev, **kwargs):
super(SyncBatchNorm, self).__init__(key, dev, **kwargs)
self.process_group = kwargs.get('process_group', None)
......
......@@ -124,6 +124,87 @@ class GumbelSoftmax(Module):
return F.softmax(scores, self.dim, self.inplace)
class Hardsigmoid(Module):
r"""Apply the hard sigmoid function.
The **HardSigmoid** function is defined as:
.. math::
\text{Hardsigmoid}(x) = \begin{cases}
0 & \text{if~} x \le -3, \\
1 & \text{if~} x \ge +3, \\
x / 6 + 1 / 2 & \text{otherwise}
\end{cases}
Examples:
```python
m = torch.nn.Hardsigmoid()
x = torch.randn(2, 3)
y = m(x)
```
See Also
--------
`torch.nn.functional.hardsigmoid(...)`_
"""
def __init__(self, inplace=False):
"""Create a ``Hardsigmoid`` module.
Parameters
----------
inplace : bool, optional, default=False
Whether to do the operation in-place.
"""
super(Hardsigmoid, self).__init__()
self.inplace = inplace
def extra_repr(self):
inplace_str = 'inplace' if self.inplace else ''
return inplace_str
def forward(self, input):
return F.hardsigmoid(input, self.inplace)
class Hardswish(Module):
r"""Apply the hard swish function.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
The **HardSwish** function is defined as:
.. math::
\text{HardSwish}(x) = \begin{cases}
0 & \text{if~} x \le -3, \\
x & \text{if~} x \ge +3, \\
x \cdot (x + 3) / 6 & \text{otherwise}
\end{cases}
Examples:
```python
m = torch.nn.Hardswish()
x = torch.randn(2, 3)
y = m(x)
```
See Also
--------
`torch.nn.functional.hardswish(...)`_
"""
def __init__(self):
"""Create a ``Hardswish`` module."""
super(Hardswish, self).__init__()
def forward(self, input):
return F.hardswish(input)
class LeakyReLU(Module):
r"""Apply the leaky rectified linear unit.
......@@ -494,6 +575,36 @@ class Softmax(Module):
return F.softmax(input, self.dim, self.inplace)
class Swish(Module):
r"""Apply the swish function.
`[Ramachandran et.al, 2017] <https://arxiv.org/abs/1710.05941>`_.
The **Swish** function is defined as:
.. math:: \text{Swish}(x) = x \cdot \frac{1}{1 + \exp(-x)}
Examples:
```python
m = torch.nn.Swish()
x = torch.randn(2, 3)
y = m(x)
```
See Also
--------
`torch.nn.functional.swish(...)`_
"""
def __init__(self):
"""Create a ``Swish`` module."""
super(Swish, self).__init__()
def forward(self, input):
return F.swish(input)
class Tanh(Module):
r"""Apply the tanh function.
......
......@@ -89,13 +89,11 @@ class BatchNorm1d(_BatchNorm):
The normalization is defined as:
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The running average of the statistics is calculated as:
.. math::
x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
.. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
See Also
--------
......@@ -140,13 +138,11 @@ class BatchNorm2d(_BatchNorm):
The normalization is defined as:
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The running average of the statistics is calculated as:
.. math::
x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
.. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
See Also
--------
......@@ -191,13 +187,11 @@ class BatchNorm3d(_BatchNorm):
The normalization is defined as:
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The running average of the statistics is calculated as:
.. math::
x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
.. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
See Also
--------
......@@ -242,13 +236,11 @@ class SyncBatchNorm(_BatchNorm):
The normalization is defined as:
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
.. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The running average of the statistics is calculated as:
.. math::
x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
.. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}}
Additionally, specify ``process_group`` to perform synchronization.
......
......@@ -297,7 +297,7 @@ class KLDivLoss(_Loss):
The flag indicating whether ``target`` is passed in log space.
"""
super(KDivLoss, self).__init__(size_average, reduce, reduction)
super(KLDivLoss, self).__init__(size_average, reduce, reduction)
self.log_target = log_target
def forward(self, input, target):
......
......@@ -18,6 +18,8 @@ from dragon.vm.torch.core.autograd import function
class ArgReduce(function.Function):
"""ArgReduce function."""
def __init__(self, key, dev, **kwargs):
super(ArgReduce, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', 'ArgMax')
......@@ -38,6 +40,8 @@ class ArgReduce(function.Function):
class Assign(function.Function):
"""Assign function."""
def __init__(self, key, dev, **kwargs):
super(Assign, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -76,6 +80,8 @@ class Assign(function.Function):
class Cast(function.Function):
"""Cast function."""
def __init__(self, key, dev, **kwargs):
super(Cast, self).__init__(key, dev, **kwargs)
self.dtype = kwargs.get('dtype', 'float32')
......@@ -95,6 +101,8 @@ class Cast(function.Function):
class ChannelAffine(function.Function):
"""ChannelAffine function."""
def __init__(self, key, dev, **kwargs):
super(ChannelAffine, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
......@@ -115,6 +123,8 @@ class ChannelAffine(function.Function):
class ChannelNormalize(function.Function):
"""ChannelNormalize function."""
def __init__(self, key, dev, **kwargs):
super(ChannelNormalize, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
......@@ -152,6 +162,8 @@ class ChannelNormalize(function.Function):
class ChannelShuffle(function.Function):
"""ChannelShuffle function."""
def __init__(self, key, dev, **kwargs):
super(ChannelShuffle, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -171,6 +183,8 @@ class ChannelShuffle(function.Function):
class Concat(function.Function):
"""Concat function."""
def __init__(self, key, dev, **kwargs):
super(Concat, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -186,6 +200,8 @@ class Concat(function.Function):
class Cumulative(function.Function):
"""Cumulative function."""
def __init__(self, key, dev, **kwargs):
super(Cumulative, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -208,6 +224,8 @@ class Cumulative(function.Function):
class Expand(function.Function):
"""Expand function."""
def __init__(self, key, dev, **kwargs):
super(Expand, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -237,6 +255,8 @@ class Expand(function.Function):
class Flatten(function.Function):
"""Flatten function."""
def __init__(self, key, dev, **kwargs):
super(Flatten, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -256,6 +276,8 @@ class Flatten(function.Function):
class IndexSelect(function.Function):
"""IndexSelect function."""
def __init__(self, key, dev, **kwargs):
super(IndexSelect, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -275,6 +297,8 @@ class IndexSelect(function.Function):
class MaskedAssign(function.Function):
"""MaskedAssign function."""
def __init__(self, key, dev, **kwargs):
super(MaskedAssign, self).__init__(key, dev, **kwargs)
......@@ -286,6 +310,8 @@ class MaskedAssign(function.Function):
class MaskedSelect(function.Function):
"""MaskedSelect function."""
def __init__(self, key, dev, **kwargs):
super(MaskedSelect, self).__init__(key, dev, **kwargs)
......@@ -297,6 +323,8 @@ class MaskedSelect(function.Function):
class Multinomial(function.Function):
"""Multinomial function."""
def __init__(self, key, dev, **kwargs):
super(Multinomial, self).__init__(key, dev, **kwargs)
self.epsilon = kwargs.get('epsilon', 0.)
......@@ -317,6 +345,8 @@ class Multinomial(function.Function):
class NonZero(function.Function):
"""NonZero function."""
def __init__(self, key, dev, **kwargs):
super(NonZero, self).__init__(key, dev, **kwargs)
......@@ -328,6 +358,8 @@ class NonZero(function.Function):
class OneHot(function.Function):
"""OneHot function."""
def __init__(self, key, dev, **kwargs):
super(OneHot, self).__init__(key, dev, **kwargs)
self.depth = kwargs.get('depth', 1)
......@@ -345,6 +377,8 @@ class OneHot(function.Function):
class Reduce(function.Function):
"""Reduce function."""
def __init__(self, key, dev, **kwargs):
super(Reduce, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', None)
......@@ -365,6 +399,8 @@ class Reduce(function.Function):
class Reshape(function.Function):
"""Reshape function."""
def __init__(self, key, dev, **kwargs):
super(Reshape, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -394,6 +430,8 @@ class Reshape(function.Function):
class Slice(function.Function):
"""Slice function."""
def __init__(self, key, dev, **kwargs):
super(Slice, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -429,6 +467,8 @@ class Slice(function.Function):
class Sort(function.Function):
"""Sort function."""
def __init__(self, key, dev, **kwargs):
super(Sort, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
......@@ -449,6 +489,8 @@ class Sort(function.Function):
class Split(function.Function):
"""Split function."""
def __init__(self, key, dev, **kwargs):
super(Split, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -469,6 +511,8 @@ class Split(function.Function):
class Stack(function.Function):
"""Stack function."""
def __init__(self, key, dev, **kwargs):
super(Stack, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -486,6 +530,8 @@ class Stack(function.Function):
class Squeeze(function.Function):
"""Squeeze function."""
def __init__(self, key, dev, **kwargs):
super(Squeeze, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', None)
......@@ -503,6 +549,8 @@ class Squeeze(function.Function):
class Tile(function.Function):
"""Tile function."""
def __init__(self, key, dev, **kwargs):
super(Tile, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -534,6 +582,8 @@ class Tile(function.Function):
class Transpose(function.Function):
"""Transpose function."""
def __init__(self, key, dev, **kwargs):
super(Transpose, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -563,6 +613,8 @@ class Transpose(function.Function):
class TopK(function.Function):
"""TopK function."""
def __init__(self, key, dev, **kwargs):
super(TopK, self).__init__(key, dev, **kwargs)
self.k = kwargs.get('k', 1)
......@@ -587,6 +639,8 @@ class TopK(function.Function):
class Unique(function.Function):
"""Unique function."""
def __init__(self, key, dev, **kwargs):
super(Unique, self).__init__(key, dev, **kwargs)
self.return_inverse = kwargs.get('return_inverse', False)
......@@ -608,6 +662,8 @@ class Unique(function.Function):
class UnSqueeze(function.Function):
"""UnSqueeze function."""
def __init__(self, key, dev, **kwargs):
super(UnSqueeze, self).__init__(key, dev, **kwargs)
self.axes = kwargs.get('axes', None)
......@@ -625,6 +681,8 @@ class UnSqueeze(function.Function):
class Where(function.Function):
"""Where function."""
def __init__(self, key, dev, **kwargs):
super(Where, self).__init__(key, dev, **kwargs)
......
......@@ -18,6 +18,8 @@ from dragon.vm.torch.core.autograd import function
class Collective(function.Function):
"""Collective function."""
def __init__(self, key, dev, **kwargs):
super(Collective, self).__init__(key, dev, **kwargs)
self.root = kwargs.get('root', 0)
......
......@@ -18,6 +18,8 @@ from dragon.vm.torch.core.autograd import function
class _Initializer(function.Function):
"""Base initializer function."""
def __init__(self, key, dev, **kwargs):
super(_Initializer, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -39,6 +41,8 @@ class _Initializer(function.Function):
class Eye(_Initializer):
"""Eye function."""
def __init__(self, key, dev, **kwargs):
super(Eye, self).__init__(key, dev, **kwargs)
self.k = kwargs.get('k', 0)
......@@ -57,6 +61,8 @@ class Eye(_Initializer):
class Fill(_Initializer):
"""Fill function."""
def __init__(self, key, dev, **kwargs):
super(Fill, self).__init__(key, dev, **kwargs)
self.value = kwargs.get('value', 0.)
......@@ -75,6 +81,8 @@ class Fill(_Initializer):
class LinSpace(function.Function):
"""LinSpace function."""
def __init__(self, key, dev, **kwargs):
super(LinSpace, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
......@@ -123,6 +131,8 @@ class LinSpace(function.Function):
class Permutation(function.Function):
"""Permutation function."""
def __init__(self, key, dev, **kwargs):
super(Permutation, self).__init__(key, dev, **kwargs)
self.dtype = kwargs.get('dtype', 'int64')
......@@ -149,6 +159,8 @@ class Permutation(function.Function):
class RandomNormal(_Initializer):
"""RandomNormal function."""
def __init__(self, key, dev, **kwargs):
super(RandomNormal, self).__init__(key, dev, **kwargs)
self.mean = kwargs.get('mean', 0.)
......@@ -169,6 +181,8 @@ class RandomNormal(_Initializer):
class RandomUniform(_Initializer):
"""RandomUniform function."""
def __init__(self, key, dev, **kwargs):
super(RandomUniform, self).__init__(key, dev, **kwargs)
self.low = kwargs.get('low', 0.)
......@@ -189,6 +203,8 @@ class RandomUniform(_Initializer):
class Range(function.Function):
"""Range function."""
def __init__(self, key, dev, **kwargs):
super(Range, self).__init__(key, dev, **kwargs)
self.num_args = kwargs.get('num_args', 3)
......
......@@ -135,6 +135,7 @@ def eye(
def fill(out, shape, value):
"""Fill a tensor with a scalar."""
return _functions.Fill \
.instantiate(
out.device,
......@@ -144,10 +145,93 @@ def fill(out, shape, value):
).apply(out, shape)
def fill_like(out, shape_like, value):
def fill_like(out, other, value):
"""Fill a tensor with a scalar as the other."""
return _functions.Fill \
.instantiate(out.device, value=float(value), dtype=out.dtype) \
.apply(out, [], shape_like)
.apply(out, [], other)
def full(
size,
fill_value,
out=None,
dtype='int64',
device=None,
requires_grad=False,
):
"""Return a tensor filled with a scalar.
Examples:
```python
print(torch.full((1, 2), 1)) # [[1, 1]]
```
Parameters
----------
size : Sequence[int]
The output shape.
fill_value : number
The scalar to fill.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
dtype : str, optional, default='int64'
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
out = out if out else utils.new_leaf(size, locals())
return fill(out, size, fill_value)
def full_like(
input,
fill_value,
out=None,
dtype='int64',
device=None,
requires_grad=False,
):
"""Return a tensor filled with a scalar with size as input.
Examples:
```python
print(torch.full_like(torch.zeros(1, 2), 1)) # [[1, 1]]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The tensor for indicating shape.
fill_value : number
The scalar to fill.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
dtype : str, optional, default='int64'
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
out = utils.new_leaf(input.shape, locals())
return fill_like(out, input, fill_value)
def linspace(
......
......@@ -18,6 +18,8 @@ from dragon.vm.torch.core.autograd import function
class Axpby(function.Function):
"""Axpby function."""
def __init__(self, key, dev, **kwargs):
super(Axpby, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.)
......@@ -37,6 +39,8 @@ class Axpby(function.Function):
class BinaryFunc(function.Function):
"""Binary function."""
def __init__(self, key, dev, **kwargs):
super(BinaryFunc, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
......@@ -49,6 +53,8 @@ class BinaryFunc(function.Function):
class Clip(function.Function):
"""Clip function."""
def __init__(self, key, dev, **kwargs):
super(Clip, self).__init__(key, dev, **kwargs)
self.min = kwargs.get('min', None)
......@@ -72,6 +78,8 @@ class Clip(function.Function):
class UnaryFunc(function.Function):
"""Unary function."""
def __init__(self, key, dev, **kwargs):
super(UnaryFunc, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
......@@ -84,6 +92,8 @@ class UnaryFunc(function.Function):
class MatMul(function.Function):
"""MatMul function."""
def __init__(self, key, dev, **kwargs):
super(MatMul, self).__init__(key, dev, **kwargs)
self.transpose_a = kwargs.get('transpose_a', False)
......
......@@ -719,6 +719,50 @@ def floor_(self):
return math_funcs.floor(self, self)
def new_full(
self,
size,
fill_value,
dtype=None,
device=None,
requires_grad=False,
):
"""Return a tensor filled with a scalar.
Refer to this tensor if ``dtype`` and ``device`` are not provided.
Parameters
----------
size : Sequence[int]
The size of output tensor.
fill_value : number
The scalar to fill.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.full(...)`_
"""
return init_funcs.full(
size,
fill_value,
dtype=self.dtype if dtype is None else dtype,
device=self.device if device is None else device,
requires_grad=requires_grad,
)
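A brief usage sketch of ``new_full`` (import path assumed); the omitted ``dtype`` and ``device`` fall back to the source tensor:

```python
from dragon.vm import torch

base = torch.zeros(2, 3)      # float32 on the default device
x = base.new_full((2, 3), 7)  # filled with 7, inherits dtype/device from ``base``
print(x.dtype)                # float32
```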
def ge(self, other):
r"""Compute the element-wise greater-equal comparison.
......@@ -1208,6 +1252,24 @@ def neg(self):
return math_funcs.neg(self)
def neg_(self):
r"""Compute the element-wise negative.
.. math:: \text{self} = -\text{self}
Returns
-------
dragon.vm.torch.Tensor
The self.
See Also
--------
`torch.neg(...)`_
"""
return math_funcs.neg(self, self)
def nonzero(self):
r"""Return the index of non-zero elements.
......@@ -1993,6 +2055,8 @@ Tensor.multinomial = multinomial
Tensor.narrow = narrow
Tensor.ne = ne
Tensor.neg = neg
Tensor.neg_ = neg_
Tensor.new_full = new_full
Tensor.nonzero = nonzero
Tensor.normal_ = normal_
Tensor.permute = permute
......
......@@ -18,6 +18,8 @@ from dragon.vm.torch.core.autograd import function
class ParamUpdate(function.Function):
"""ParamUpdate function."""
def __init__(self, key, dev, **kwargs):
super(ParamUpdate, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
......@@ -41,6 +43,8 @@ class ParamUpdate(function.Function):
class GradAccumulate(function.Function):
"""GradAccumulate function."""
def __init__(self, key, dev, **kwargs):
super(GradAccumulate, self).__init__(key, dev, **kwargs)
self.momentum = kwargs.get('momentum', 1)
......@@ -48,7 +52,10 @@ class GradAccumulate(function.Function):
def attributes(self):
return {
'op_type': 'Axpby',
'arguments': {'alpha': 1., 'beta': float(self.momentum)},
'arguments': {
'alpha': 1.0,
'beta': float(self.momentum),
},
}
def forward(self, grad):
......
......@@ -21,11 +21,11 @@ from dragon.vm.torch.core import cpp
from dragon.vm.torch.core.tensor import Tensor
def new_leaf(sizes, kwargs):
def new_leaf(size, kwargs):
"""Return a leaf tensor from optional kwargs."""
device = kwargs.get('device', cpp.device())
return Tensor(
*sizes,
*size,
dtype=kwargs.get('dtype', 'float32'),
device=cpp.device() if device is None else device,
requires_grad=kwargs.get('requires_grad', False)
......
......@@ -19,6 +19,7 @@ from dragon.core.framework import config
from dragon.core.framework import context
from dragon.core.framework import proto_util
from dragon.core.framework import workspace
from dragon.core.util import nest
from dragon.core.util import six
from dragon.core.util import string
from dragon.vm.torch.core import cpp
......@@ -1349,6 +1350,177 @@ class Tensor(object):
"""
def neg_(self):
r"""Compute the element-wise negative.
.. math:: \text{self} = -\text{self}
Returns
-------
dragon.vm.torch.Tensor
The self.
See Also
--------
`torch.neg(...)`_
"""
def new_empty(
self,
*size,
dtype=None,
device=None,
requires_grad=False,
):
"""Return a tensor filled with uninitialized data.
Refer to this tensor if ``dtype`` and ``device`` are not provided.
Parameters
----------
size : int...
The size of output tensor.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.empty(...)`_
"""
return empty(
*nest.flatten(size),
dtype=self.dtype if dtype is None else dtype,
device=self.device if device is None else device,
requires_grad=requires_grad,
)
def new_full(
self,
size,
fill_value,
dtype=None,
device=None,
requires_grad=False,
):
"""Return a tensor filled with a scalar.
Refer to this tensor if ``dtype`` and ``device`` are not provided.
Parameters
----------
size : Sequence[int]
The size of output tensor.
fill_value : number
The scalar to fill.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.full(...)`_
"""
def new_ones(
self,
*size,
dtype=None,
device=None,
requires_grad=False,
):
"""Return a tensor filled with with ones.
Refer this tensor if ``dtype`` and ``device`` not provided.
Parameters
----------
size : int...
The size of output tensor.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.ones(...)`_
"""
return self.new_full(
nest.flatten(size),
fill_value=1,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
def new_zeros(
self,
*size,
dtype=None,
device=None,
requires_grad=False,
):
"""Return a tensor filled with with zeros.
Refer this tensor if ``dtype`` and ``device`` not provided.
Parameters
----------
size : int...
The size of output tensor.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.zeros(...)`_
"""
return self.new_full(
nest.flatten(size),
fill_value=0,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
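A short sketch of how the ``new_*`` helpers above inherit ``dtype`` and ``device`` from the source tensor when those arguments are omitted (illustrative; import path assumed):

```python
from dragon.vm import torch

ref = torch.empty(2, 3)           # float32 by default
a = ref.new_empty(4, 5)           # uninitialized, same dtype/device as ``ref``
b = ref.new_ones(4)               # filled with ones
c = ref.new_zeros(4)              # filled with zeros
print(a.dtype, b.dtype, c.dtype)  # float32 float32 float32
```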
def nonzero(self):
r"""Return the index of non-zero elements.
......@@ -2530,12 +2702,12 @@ class LongTensor(object):
return Tensor(*args, **kwargs)
def empty(*sizes, dtype=None, device=None, requires_grad=False):
def empty(*size, dtype=None, device=None, requires_grad=False):
"""Return a tensor filled with uninitialized data.
Parameters
----------
sizes : int...
size : int...
The size of output tensor.
dtype : str, optional
The optional data type.
......@@ -2551,7 +2723,7 @@ def empty(*sizes, dtype=None, device=None, requires_grad=False):
"""
return Tensor(
*sizes,
*size,
dtype=dtype if dtype else 'float32',
device=cpp.device() if device is None else device,
requires_grad=requires_grad,
......