Commit b7e959e9 by Ting PAN

Remove HardSwish arguments

Summary:
This commit fixes alpha and beta to 1/6 and 0.5 for HardSwish,
matching the ONNX scheme, and removes the corresponding operator arguments.
1 parent 094c8c32
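For reference, the fixed behavior is the ONNX HardSwish: y = x * max(0, min(1, x / 6 + 0.5)). A minimal NumPy sketch of that formula (reference math only, not the Dragon kernels):

```python
import numpy as np

def hardswish_reference(x):
    """ONNX-style HardSwish with alpha = 1/6 and beta = 0.5 hard-coded."""
    x = np.asarray(x, dtype=np.float32)
    return x * np.clip(x / 6.0 + 0.5, 0.0, 1.0)

print(hardswish_reference([-3.0, -1.0, 0.0, 1.0, 3.0]))
# approximately [0., -0.3333, 0., 0.6667, 3.]
```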
...@@ -85,6 +85,7 @@ Name Supported Reference ...@@ -85,6 +85,7 @@ Name Supported Reference
`GlobalMaxPool`_ |v| :func:`dragon.nn.pool` `GlobalMaxPool`_ |v| :func:`dragon.nn.pool`
`Greater`_ |v| :func:`dragon.math.greater` `Greater`_ |v| :func:`dragon.math.greater`
`HardSigmoid`_ |v| :func:`dragon.nn.hardsigmoid` `HardSigmoid`_ |v| :func:`dragon.nn.hardsigmoid`
`HardSwish`_ |v| :func:`dragon.nn.hardswish`
`Hardmax`_ `Hardmax`_
`Identity`_ |v| :func:`dragon.identity` `Identity`_ |v| :func:`dragon.identity`
`If`_ `If`_
...@@ -251,6 +252,7 @@ Name Supported Reference ...@@ -251,6 +252,7 @@ Name Supported Reference
.. _GlobalMaxPool: https://github.com/onnx/onnx/blob/master/docs/Operators.md#GlobalMaxPool .. _GlobalMaxPool: https://github.com/onnx/onnx/blob/master/docs/Operators.md#GlobalMaxPool
.. _Greater: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Greater .. _Greater: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Greater
.. _HardSigmoid: https://github.com/onnx/onnx/blob/master/docs/Operators.md#HardSigmoid .. _HardSigmoid: https://github.com/onnx/onnx/blob/master/docs/Operators.md#HardSigmoid
.. _HardSwish: https://github.com/onnx/onnx/blob/master/docs/Operators.md#HardSwish
.. _Hardmax: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Hardmax .. _Hardmax: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Hardmax
.. _Identity: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Identity .. _Identity: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Identity
.. _If: https://github.com/onnx/onnx/blob/master/docs/Operators.md#If .. _If: https://github.com/onnx/onnx/blob/master/docs/Operators.md#If
......
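The exporter table only needs a bare entry because ONNX HardSwish (added in opset 14) carries no alpha/beta attributes; 1/6 and 0.5 are part of the operator definition itself. A hedged sketch with the onnx Python package, assuming it is installed:

```python
import onnx
from onnx import TensorProto, helper

# HardSwish takes no attributes in ONNX; alpha = 1/6 and beta = 0.5 are implied.
node = helper.make_node('HardSwish', inputs=['x'], outputs=['y'])
graph = helper.make_graph(
    [node], 'hardswish_demo',
    [helper.make_tensor_value_info('x', TensorProto.FLOAT, [5])],
    [helper.make_tensor_value_info('y', TensorProto.FLOAT, [5])])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 14)])
onnx.checker.check_model(model)  # validates that the bare node is well-formed
print(model.graph.node[0])
```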
...@@ -51,6 +51,8 @@ void _HardSigmoidGrad<float16>( ...@@ -51,6 +51,8 @@ void _HardSigmoidGrad<float16>(
} // namespace } // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void HardSigmoid<T, CPUContext>( \ void HardSigmoid<T, CPUContext>( \
......
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -9,52 +10,30 @@ namespace kernels { ...@@ -9,52 +10,30 @@ namespace kernels {
namespace { namespace {
template <typename T> template <typename T, typename AccT>
__global__ void __global__ void
_HardSigmoid(const int N, const T alpha, const T beta, const T* x, T* y) { _HardSigmoid(const int N, const AccT alpha, const AccT beta, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, N) { CUDA_1D_KERNEL_LOOP(i, N) {
y[i] = max(T(0), min(T(1), fma(x[i], alpha, beta))); const AccT s_val = fma(convert::To<AccT>(x[i]), alpha, beta);
y[i] = convert::To<T>(max(AccT(0), min(AccT(1), s_val)));
} }
} }
__global__ void _HardSigmoid( template <typename T, typename AccT>
const int N,
const float alpha,
const float beta,
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(i, N) {
y[i] =
__float2half(max(0.f, min(1.f, fma(__half2float(x[i]), alpha, beta))));
}
}
template <typename T>
__global__ void _HardSigmoidGrad( __global__ void _HardSigmoidGrad(
const int N, const int N,
const float alpha, const AccT alpha,
const T* dy, const T* dy,
const T* y, const T* y,
T* dx) { T* dx) {
CUDA_1D_KERNEL_LOOP(i, N) { CUDA_1D_KERNEL_LOOP(i, N) {
dx[i] = (y[i] > T(0) && y[i] < T(1)) ? dy[i] * alpha : T(0); const AccT val = convert::To<AccT>(y[i]);
dx[i] = convert::To<T>(
(val > AccT(0) && val < AccT(1)) ? convert::To<AccT>(dy[i]) * alpha
: AccT(0));
} }
} }
template <>
__global__ void _HardSigmoidGrad<half>(
const int N,
const float alpha,
const half* dy,
const half* y,
half* dx) {
CUDA_1D_KERNEL_LOOP(i, N) {
const float val = __half2float(y[i]);
dx[i] = __half2float(
(val > 0.f && val < 1.f) ? __half2float(dy[i]) * alpha : 0.f);
}
} // HardSigmoidGrad
} // namespace } // namespace
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
......
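The HardSigmoid CUDA kernels are rewritten to compute in an accumulator type (AccT) instead of keeping a separate half specialization, but the math is unchanged: y = clip(alpha * x + beta, 0, 1), and the gradient passes dy * alpha only where the output is strictly between 0 and 1. A NumPy mirror of that logic, for reference (alpha and beta keep their operator defaults of 0.2 and 0.5):

```python
import numpy as np

def hardsigmoid(x, alpha=0.2, beta=0.5):
    return np.clip(alpha * x + beta, 0.0, 1.0)

def hardsigmoid_grad(dy, y, alpha=0.2):
    # Gradient is alpha on the linear segment (0 < y < 1) and zero at the clamps,
    # mirroring the (val > 0 && val < 1) test in the kernel.
    return np.where((y > 0.0) & (y < 1.0), dy * alpha, 0.0)

x = np.array([-4.0, -1.0, 0.0, 1.0, 4.0], dtype=np.float32)
y = hardsigmoid(x)
print(y, hardsigmoid_grad(np.ones_like(x), y))
```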
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_eigen.h" #include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
...@@ -9,43 +8,30 @@ namespace kernels { ...@@ -9,43 +8,30 @@ namespace kernels {
namespace { namespace {
template <typename T> template <typename T>
void _HardSwish(const int N, const T alpha, const T beta, const T* x, T* y) { void _HardSwish(const int N, const T* x, T* y) {
const T kAlpha = 0.1666666666666667;
ConstEigenVectorArrayMap<T> X(x, N); ConstEigenVectorArrayMap<T> X(x, N);
EigenVectorArrayMap<T>(y, N) = EigenVectorArrayMap<T>(y, N) =
X * ((X * alpha + beta).cwiseMin(T(1)).cwiseMax(T(0))); X * ((X * kAlpha + T(0.5)).cwiseMin(T(1)).cwiseMax(T(0)));
} }
template <> template <>
void _HardSwish<float16>( void _HardSwish<float16>(const int N, const float16* x, float16* y) {
const int N,
const float16 alpha,
const float16 beta,
const float16* x,
float16* y) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
template <typename T> template <typename T>
void _HardSwishGrad( void _HardSwishGrad(const int N, const T* dy, const T* x, T* dx) {
const int N, const T kAlpha2 = 0.3333333333333333;
const T alpha,
const T beta,
const T* dy,
const T* x,
T* dx) {
const auto bound = beta / alpha;
const auto alpha2x = alpha * T(2);
EigenVectorArrayMap<T>(dx, N) = ConstEigenVectorArrayMap<T>(dy, N) * EigenVectorArrayMap<T>(dx, N) = ConstEigenVectorArrayMap<T>(dy, N) *
ConstEigenVectorArrayMap<T>(x, N).unaryExpr([&](T a) { ConstEigenVectorArrayMap<T>(x, N).unaryExpr([&](T a) {
return (a < -bound) ? T(0) : (a < bound ? a * alpha2x + beta : T(1)); return a < T(-3) ? T(0) : (a < T(3) ? a * kAlpha2 + T(0.5) : T(1));
}); });
} }
template <> template <>
void _HardSwishGrad<float16>( void _HardSwishGrad<float16>(
const int N, const int N,
const float16 alpha,
const float16 beta,
const float16* dy, const float16* dy,
const float16* x, const float16* x,
float16* dx) { float16* dx) {
...@@ -54,29 +40,20 @@ void _HardSwishGrad<float16>( ...@@ -54,29 +40,20 @@ void _HardSwishGrad<float16>(
} // namespace } // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void HardSwish<T, CPUContext>( \ void HardSwish<T, CPUContext>( \
const int N, \ const int N, const T* x, T* y, CPUContext* ctx) { \
const float alpha, \ _HardSwish(N, x, y); \
const float beta, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_HardSwish(N, convert::To<T>(alpha), convert::To<T>(beta), x, y); \
} }
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \ #define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void HardSwishGrad<T, CPUContext>( \ void HardSwishGrad<T, CPUContext>( \
const int N, \ const int N, const T* dy, const T* x, T* dx, CPUContext* ctx) { \
const float alpha, \ _HardSwishGrad(N, dy, x, dx); \
const float beta, \
const T* dy, \
const T* x, \
T* dx, \
CPUContext* ctx) { \
_HardSwishGrad(N, convert::To<T>(alpha), convert::To<T>(beta), dy, x, dx); \
} }
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
......
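With alpha fixed at 1/6 and beta at 0.5, the backward constants follow directly: the derivative of x * clip(x / 6 + 0.5, 0, 1) is 0 for x < -3, x / 3 + 0.5 on [-3, 3) (hence kAlpha2 = 1/3, i.e. 2 * alpha), and 1 for x >= 3. A NumPy sketch mirroring the CPU kernels:

```python
import numpy as np

K_ALPHA = 1.0 / 6.0   # hard-coded alpha
K_ALPHA2 = 1.0 / 3.0  # 2 * alpha, the kAlpha2 constant in the gradient kernel

def hardswish(x):
    return x * np.clip(x * K_ALPHA + 0.5, 0.0, 1.0)

def hardswish_grad(dy, x):
    # Piecewise local gradient, matching the a < -3 / a < 3 branches in _HardSwishGrad.
    local = np.where(x < -3.0, 0.0, np.where(x < 3.0, x * K_ALPHA2 + 0.5, 1.0))
    return dy * local

x = np.linspace(-4.0, 4.0, 9).astype(np.float32)
print(hardswish(x))
print(hardswish_grad(np.ones_like(x), x))
```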
...@@ -10,67 +10,27 @@ namespace kernels { ...@@ -10,67 +10,27 @@ namespace kernels {
namespace { namespace {
#define LDG(x, i) __ldg(x + i) template <typename T, typename AccT>
#define LDG2(x, i) __half2float(__ldg(x + i)) __global__ void _HardSwish(const int N, const T* x, T* y) {
template <typename T>
__global__ void
_HardSwish(const int N, const T alpha, const T beta, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, N) {
y[i] = LDG(x, i) * max(T(0), min(T(1), fma(LDG(x, i), alpha, beta)));
}
}
__global__ void _HardSwish(
const int N,
const float alpha,
const float beta,
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(i, N) { CUDA_1D_KERNEL_LOOP(i, N) {
y[i] = __float2half( const AccT val = convert::To<AccT>(x[i]);
LDG2(x, i) * max(0.f, min(1.f, fma(LDG2(x, i), alpha, beta)))); const AccT s_val = fma(val, AccT(0.1666666666666667), AccT(0.5));
y[i] = convert::To<T>(val * max(AccT(0), min(AccT(1), s_val)));
} }
} }
template <typename T> template <typename T, typename AccT>
__global__ void _HardSwishGrad( __global__ void _HardSwishGrad(const int N, const T* dy, const T* x, T* dx) {
const int N,
const T alpha,
const T beta,
const T* dy,
const T* x,
T* dx) {
const T bound = beta / alpha;
const T alpha2x = alpha * T(2);
CUDA_1D_KERNEL_LOOP(i, N) { CUDA_1D_KERNEL_LOOP(i, N) {
dx[i] = (LDG(x, i) < -bound) const AccT val = convert::To<AccT>(x[i]);
? T(0) dx[i] = (val < AccT(-3))
: (LDG(x, i) < bound) ? dy[i] * fma(LDG(x, i), alpha2x, beta) : dy[i]; ? convert::To<T>(AccT(0))
} : (val < AccT(3)) ? convert::To<T>(
} convert::To<AccT>(dy[i]) *
fma(val, AccT(0.3333333333333333), AccT(0.5)))
__global__ void _HardSwishGrad(
const int N,
const float alpha,
const float beta,
const half* dy,
const half* x,
half* dx) {
const float bound = beta / alpha;
const float alpha2x = alpha * 2.f;
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, N) {
const float val = __half2float(x[i]);
dx[i] = (val < -bound) ? kZero
: (val < bound)
? __float2half(__half2float(dy[i]) * fma(val, alpha2x, beta))
: dy[i]; : dy[i];
} }
} // HardSwishGrad }
#undef LDG
#undef LDG2
} // namespace } // namespace
...@@ -79,16 +39,10 @@ __global__ void _HardSwishGrad( ...@@ -79,16 +39,10 @@ __global__ void _HardSwishGrad(
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void HardSwish<T, CUDAContext>( \ void HardSwish<T, CUDAContext>( \
const int N, \ const int N, const T* x, T* y, CUDAContext* ctx) { \
const float alpha, \ _HardSwish<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
const float beta, \ <<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
const T* x, \
T* y, \
CUDAContext* ctx) { \
_HardSwish<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \ N, \
convert::To<math::AccmulatorType<T>::type>(alpha), \
convert::To<math::AccmulatorType<T>::type>(beta), \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \ reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \ reinterpret_cast<math::ScalarType<T>::type*>(y)); \
} }
...@@ -96,17 +50,10 @@ __global__ void _HardSwishGrad( ...@@ -96,17 +50,10 @@ __global__ void _HardSwishGrad(
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \ #define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void HardSwishGrad<T, CUDAContext>( \ void HardSwishGrad<T, CUDAContext>( \
const int N, \ const int N, const T* dy, const T* x, T* dx, CUDAContext* ctx) { \
const float alpha, \ _HardSwishGrad<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
const float beta, \ <<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
const T* dy, \
const T* x, \
T* dx, \
CUDAContext* ctx) { \
_HardSwishGrad<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \ N, \
convert::To<math::AccmulatorType<T>::type>(alpha), \
convert::To<math::AccmulatorType<T>::type>(beta), \
reinterpret_cast<const math::ScalarType<T>::type*>(dy), \ reinterpret_cast<const math::ScalarType<T>::type*>(dy), \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \ reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<math::ScalarType<T>::type*>(dx)); \ reinterpret_cast<math::ScalarType<T>::type*>(dx)); \
......
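The CUDA launchers now instantiate the kernels on both a storage type and an accumulator type, so float16 data is promoted to float for the arithmetic and rounded back once on store. A small NumPy illustration of the two rounding paths (illustrative only; whether individual elements differ depends on the values):

```python
import numpy as np

x = np.random.default_rng(0).uniform(-3.0, 3.0, 8).astype(np.float16)

# Every intermediate result is rounded to float16.
fp16_path = x * np.float16(1.0 / 6.0) + np.float16(0.5)
# The AccT-style path: compute in float32, round to float16 once at the end.
fp32_path = (x.astype(np.float32) / 6.0 + 0.5).astype(np.float16)

print(fp16_path)
print(fp32_path)  # may disagree with the line above in the last bit for some elements
```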
...@@ -8,52 +8,49 @@ namespace kernels { ...@@ -8,52 +8,49 @@ namespace kernels {
namespace { namespace {
template <typename T> template <typename T>
void _Swish(const int N, const T* x, T* y) { void _Silu(const int N, const T* x, T* y) {
ConstEigenVectorArrayMap<T> X(x, N); ConstEigenVectorArrayMap<T> X(x, N);
EigenVectorArrayMap<T>(y, N) = X / (T(1) + (-X).exp()); EigenVectorArrayMap<T>(y, N) = X / (T(1) + (-X).exp());
} }
template <> template <>
void _Swish<float16>(const int N, const float16* x, float16* y) { void _Silu<float16>(const int N, const float16* x, float16* y) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
template <typename T> template <typename T>
void _SwishGrad(const int N, const T* dy, const T* x, const T* y, T* dx) { void _SiluGrad(const int N, const T* dy, const T* x, T* dx) {
ConstEigenVectorArrayMap<T> X(x, N); ConstEigenVectorArrayMap<T> X(x, N);
ConstEigenVectorArrayMap<T> Y(y, N); ConstEigenVectorArrayMap<T> dY(dy, N);
EigenVectorArrayMap<T>(dx, N) = ConstEigenVectorArrayMap<T>(dy, N) * EigenVectorArrayMap<T> dX(dx, N);
(Y + (T(1) / (T(1) + (-X).exp())) * (T(1) - Y)); dX = T(1) / (T(1) + (-X).exp());
dX = dY * dX * (X + T(1) - X * dX);
} }
template <> template <>
void _SwishGrad<float16>( void _SiluGrad<float16>(
const int N, const int N,
const float16* dy, const float16* dy,
const float16* x, const float16* x,
const float16* y,
float16* dx) { float16* dx) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
} // namespace } // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void Swish<T, CPUContext>(const int N, const T* x, T* y, CPUContext* ctx) { \ void Silu<T, CPUContext>(const int N, const T* x, T* y, CPUContext* ctx) { \
_Swish(N, x, y); \ _Silu(N, x, y); \
} }
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \ #define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void SwishGrad<T, CPUContext>( \ void SiluGrad<T, CPUContext>( \
const int N, \ const int N, const T* dy, const T* x, T* dx, CPUContext* ctx) { \
const T* dy, \ _SiluGrad(N, dy, x, dx); \
const T* x, \
const T* y, \
T* dx, \
CPUContext* ctx) { \
_SwishGrad(N, dy, x, y, dx); \
} }
DEFINE_KERNEL_LAUNCHER(float16); DEFINE_KERNEL_LAUNCHER(float16);
......
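The new SiluGrad derives the gradient from the input alone: with s = sigmoid(x), d/dx [x * s] = s * (x + 1 - x * s), which is the expression in the kernel. A quick NumPy finite-difference check of that closed form:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu_grad(x):
    s = sigmoid(x)
    # Closed form used by _SiluGrad: d/dx [x * sigmoid(x)] = s * (x + 1 - x * s).
    return s * (x + 1.0 - x * s)

x = np.linspace(-4.0, 4.0, 9)
eps = 1e-6
numeric = ((x + eps) * sigmoid(x + eps) - (x - eps) * sigmoid(x - eps)) / (2.0 * eps)
print(np.max(np.abs(numeric - silu_grad(x))))  # tiny (around 1e-10 or less)
```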
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T, typename AccT>
__global__ void _Silu(const int N, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, N) {
const AccT val = convert::To<AccT>(x[i]);
y[i] = convert::To<T>(val / (AccT(1) + exp(-val)));
}
}
template <typename T, typename AccT>
__global__ void _SiluGrad(const int N, const T* dy, const T* x, T* dx) {
CUDA_1D_KERNEL_LOOP(i, N) {
const AccT val = convert::To<AccT>(x[i]);
const AccT s_val = AccT(1) / (AccT(1) + exp(-val));
dx[i] = convert::To<T>(
convert::To<AccT>(dy[i]) * s_val * (val + AccT(1) - val * s_val));
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Silu<T, CUDAContext>(const int N, const T* x, T* y, CUDAContext* ctx) { \
_Silu<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void SiluGrad<T, CUDAContext>( \
const int N, const T* dy, const T* x, T* dx, CUDAContext* ctx) { \
_SiluGrad<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
reinterpret_cast<const math::ScalarType<T>::type*>(dy), \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<math::ScalarType<T>::type*>(dx)); \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
#define LDG(x, i) __ldg(x + i)
#define LDG2(x, i) __half2float(__ldg(x + i))
template <typename T>
__global__ void _Swish(const int N, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, N) {
y[i] = LDG(x, i) / (T(1) + exp(-LDG(x, i)));
}
}
template <>
__global__ void _Swish<half>(const int N, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, N) {
y[i] = __float2half(LDG2(x, i) / (1.f + exp(-LDG2(x, i))));
}
}
template <typename T>
__global__ void
_SwishGrad(const int N, const T* dy, const T* x, const T* y, T* dx) {
CUDA_1D_KERNEL_LOOP(i, N) {
dx[i] = dy[i] * (LDG(y, i) + (T(1) - LDG(y, i)) / (T(1) + exp(-x[i])));
}
}
template <>
__global__ void _SwishGrad<half>(
const int N,
const half* dy,
const half* x,
const half* y,
half* dx) {
CUDA_1D_KERNEL_LOOP(i, N) {
dx[i] = __float2half(
__half2float(dy[i]) *
(LDG2(y, i) + (1.f - LDG2(y, i)) / (1.f + exp(-__half2float(x[i])))));
}
} // SwishGrad
#undef LDG
#undef LDG2
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Swish<T, CUDAContext>( \
const int N, const T* x, T* y, CUDAContext* ctx) { \
_Swish<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void SwishGrad<T, CUDAContext>( \
const int N, \
const T* dy, \
const T* x, \
const T* y, \
T* dx, \
CUDAContext* ctx) { \
_SwishGrad<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
reinterpret_cast<const math::ScalarType<T>::type*>(dy), \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
reinterpret_cast<const math::ScalarType<T>::type*>(y), \
reinterpret_cast<math::ScalarType<T>::type*>(dx)); \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
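The removed Swish kernels read back the saved output y and evaluated dy * (y + sigmoid(x) * (1 - y)); the Silu replacement needs only x because, with y = x * s, the two factors are algebraically identical: y + s * (1 - y) = s * (x + 1 - x * s). A short NumPy check of that equivalence:

```python
import numpy as np

x = np.linspace(-5.0, 5.0, 21)
s = 1.0 / (1.0 + np.exp(-x))
y = x * s  # the forward output the old kernel re-read

old_form = y + s * (1.0 - y)      # gradient factor in the removed Swish kernels
new_form = s * (x + 1.0 - x * s)  # gradient factor in the new Silu kernels

print(np.max(np.abs(old_form - new_form)))  # zero up to floating-point rounding
```

Dropping y from the gradient inputs is what lets SiluGradient run with two inputs (X, dY) and a generic gradient maker instead of three inputs, as the operator changes below show.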
...@@ -17,11 +17,6 @@ void HardSigmoidOp<Context>::DoRunWithType() { ...@@ -17,11 +17,6 @@ void HardSigmoidOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void HardSigmoidOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T> template <typename T>
void HardSigmoidGradientOp<Context>::DoRunWithType() { void HardSigmoidGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0); auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
...@@ -34,11 +29,6 @@ void HardSigmoidGradientOp<Context>::DoRunWithType() { ...@@ -34,11 +29,6 @@ void HardSigmoidGradientOp<Context>::DoRunWithType() {
ctx()); ctx());
} }
template <class Context>
void HardSigmoidGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(HardSigmoid); DEPLOY_CPU_OPERATOR(HardSigmoid);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(HardSigmoid); DEPLOY_CUDA_OPERATOR(HardSigmoid);
......
...@@ -26,7 +26,9 @@ class HardSigmoidOp : public Operator<Context> { ...@@ -26,7 +26,9 @@ class HardSigmoidOp : public Operator<Context> {
beta_(OP_SINGLE_ARG(float, "beta", 0.5f)) {} beta_(OP_SINGLE_ARG(float, "beta", 0.5f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
...@@ -43,7 +45,9 @@ class HardSigmoidGradientOp : public Operator<Context> { ...@@ -43,7 +45,9 @@ class HardSigmoidGradientOp : public Operator<Context> {
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)) {} alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
......
...@@ -9,37 +9,23 @@ void HardSwishOp<Context>::DoRunWithType() { ...@@ -9,37 +9,23 @@ void HardSwishOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
kernels::HardSwish( kernels::HardSwish(
X.count(), X.count(),
alpha_,
beta_,
X.template data<T, Context>(), X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(), Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
template <class Context> template <class Context>
void HardSwishOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T> template <typename T>
void HardSwishGradientOp<Context>::DoRunWithType() { void HardSwishGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &dY = Input(1), *dX = Output(0); auto &X = Input(0), &dY = Input(1), *dX = Output(0);
kernels::HardSwishGrad( kernels::HardSwishGrad(
X.count(), X.count(),
alpha_,
beta_,
dY.template data<T, Context>(), dY.template data<T, Context>(),
X.template data<T, Context>(), X.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(), dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
template <class Context>
void HardSwishGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(HardSwish); DEPLOY_CPU_OPERATOR(HardSwish);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(HardSwish); DEPLOY_CUDA_OPERATOR(HardSwish);
......
...@@ -20,37 +20,29 @@ namespace dragon { ...@@ -20,37 +20,29 @@ namespace dragon {
template <class Context> template <class Context>
class HardSwishOp : public Operator<Context> { class HardSwishOp : public Operator<Context> {
public: public:
HardSwishOp(const OperatorDef& def, Workspace* ws) SIMPLE_CTOR_DTOR(HardSwishOp);
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)),
beta_(OP_SINGLE_ARG(float, "beta", 0.5f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected:
float alpha_, beta_;
}; };
template <class Context> template <class Context>
class HardSwishGradientOp : public Operator<Context> { class HardSwishGradientOp : public Operator<Context> {
public: public:
HardSwishGradientOp(const OperatorDef& def, Workspace* ws) SIMPLE_CTOR_DTOR(HardSwishGradientOp);
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 0.2f)),
beta_(OP_SINGLE_ARG(float, "beta", 0.5f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected:
float alpha_, beta_;
}; };
} // namespace dragon } // namespace dragon
......
#include "dragon/operators/activation/swish_op.h" #include "dragon/operators/activation/silu_op.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void SwishOp<Context>::DoRunWithType() { void SiluOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
kernels::Swish( kernels::Silu(
X.count(), X.count(),
X.template data<T, Context>(), X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(), Y->ReshapeLike(X)->template mutable_data<T, Context>(),
...@@ -15,67 +15,39 @@ void SwishOp<Context>::DoRunWithType() { ...@@ -15,67 +15,39 @@ void SwishOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void SwishOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T> template <typename T>
void SwishGradientOp<Context>::DoRunWithType() { void SiluGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &Y = Input(1); auto &X = Input(0), &dY = Input(1), *dX = Output(0);
auto &dY = Input(2), *dX = Output(0); kernels::SiluGrad(
kernels::SwishGrad(
X.count(), X.count(),
dY.template data<T, Context>(), dY.template data<T, Context>(),
X.template data<T, Context>(), X.template data<T, Context>(),
Y.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(), dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
template <class Context> DEPLOY_CPU_OPERATOR(Silu);
void SwishGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Swish);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Swish); DEPLOY_CUDA_OPERATOR(Silu);
#endif #endif
DEPLOY_CPU_OPERATOR(SwishGradient); DEPLOY_CPU_OPERATOR(SiluGradient);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(SwishGradient); DEPLOY_CUDA_OPERATOR(SiluGradient);
#endif #endif
OPERATOR_SCHEMA(Swish) OPERATOR_SCHEMA(Silu)
/* X */ /* X */
.NumInputs(1) .NumInputs(1)
/* Y */ /* Y */
.NumOutputs(1); .NumOutputs(1);
OPERATOR_SCHEMA(SwishGradient) OPERATOR_SCHEMA(SiluGradient)
/* X, Y, dY */ /* X, dY */
.NumInputs(3) .NumInputs(2)
/* dX */ /* dX */
.NumOutputs(1); .NumOutputs(1);
namespace { REGISTER_GRADIENT(Silu, GenericGradientMaker);
class GradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
void CreateGradientDefs() override {
AddGradientDef(
def().type() + "Gradient",
"",
vector<string>({I(0), O(0), GO(0)}),
vector<string>({GI(0)}));
}
};
} // namespace
REGISTER_GRADIENT(Swish, GradientMaker);
} // namespace dragon } // namespace dragon
...@@ -10,32 +10,36 @@ ...@@ -10,32 +10,36 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_OPERATORS_ACTIVATION_SWISH_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_SILU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SWISH_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_SILU_OP_H_
#include "dragon/core/operator.h" #include "dragon/core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SwishOp : public Operator<Context> { class SiluOp : public Operator<Context> {
public: public:
SIMPLE_CTOR_DTOR(SwishOp); SIMPLE_CTOR_DTOR(SiluOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
}; };
template <class Context> template <class Context>
class SwishGradientOp : public Operator<Context> { class SiluGradientOp : public Operator<Context> {
public: public:
SIMPLE_CTOR_DTOR(SwishGradientOp); SIMPLE_CTOR_DTOR(SiluGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
...@@ -43,4 +47,4 @@ class SwishGradientOp : public Operator<Context> { ...@@ -43,4 +47,4 @@ class SwishGradientOp : public Operator<Context> {
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_SWISH_OP_H_ #endif // DRAGON_OPERATORS_ACTIVATION_SILU_OP_H_
...@@ -79,11 +79,6 @@ void PadOp<Context>::DoRunWithType() { ...@@ -79,11 +79,6 @@ void PadOp<Context>::DoRunWithType() {
} }
template <class Context> template <class Context>
void PadOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <class Context>
template <typename T> template <typename T>
void PadGradientOp<Context>::DoRunWithType() { void PadGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0); auto &dY = Input(0), *dX = Output(0);
...@@ -119,11 +114,6 @@ void PadGradientOp<Context>::DoRunWithType() { ...@@ -119,11 +114,6 @@ void PadGradientOp<Context>::DoRunWithType() {
} }
} }
template <class Context>
void PadGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Pad); DEPLOY_CPU_OPERATOR(Pad);
#ifdef USE_CUDA #ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Pad); DEPLOY_CUDA_OPERATOR(Pad);
......
...@@ -28,7 +28,9 @@ class PadOp final : public Operator<Context> { ...@@ -28,7 +28,9 @@ class PadOp final : public Operator<Context> {
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
...@@ -44,26 +46,18 @@ class PadGradientOp final : public Operator<Context> { ...@@ -44,26 +46,18 @@ class PadGradientOp final : public Operator<Context> {
public: public:
PadGradientOp(const OperatorDef& def, Workspace* ws) PadGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
pad_l_(OP_REPEATED_ARG(int64_t, "pad_l")), mode_(OP_SINGLE_ARG(string, "mode", "CONSTANT")) {}
pad_r_(OP_REPEATED_ARG(int64_t, "pad_r")),
mode_(OP_SINGLE_ARG(string, "mode", "CONSTANT")) {
if (pad_r_.empty()) {
pad_r_ = pad_l_;
} else {
CHECK_EQ(pad_l_.size(), pad_r_.size())
<< "\nThe <pad_l> and <pad_r> should have the same length.";
}
}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected: protected:
string mode_; string mode_;
vec64_t pad_l_, pad_r_;
}; };
DEFINE_OP_REPEATED_ARG(int64_t, PadOp, pads); DEFINE_OP_REPEATED_ARG(int64_t, PadOp, pads);
......
...@@ -280,7 +280,7 @@ def group_norm_args(**kwargs): ...@@ -280,7 +280,7 @@ def group_norm_args(**kwargs):
} }
@register(['HardSigmoid', 'HardSwish']) @register('HardSigmoid')
def hard_sigmoid_args(**kwargs): def hard_sigmoid_args(**kwargs):
return { return {
'alpha': kwargs.get('alpha', 0.2), 'alpha': kwargs.get('alpha', 0.2),
......
...@@ -249,7 +249,7 @@ def hardsigmoid(inputs, alpha=0.2, beta=0.5, inplace=False, **kwargs): ...@@ -249,7 +249,7 @@ def hardsigmoid(inputs, alpha=0.2, beta=0.5, inplace=False, **kwargs):
Examples: Examples:
```python ```python
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5]) x = dragon.constant([-2.5, -1., 0., 1., 2.5])
print(dragon.nn.hardsigmoid(x)) print(dragon.nn.hardsigmoid(x))
``` ```
...@@ -279,18 +279,18 @@ def hardsigmoid(inputs, alpha=0.2, beta=0.5, inplace=False, **kwargs): ...@@ -279,18 +279,18 @@ def hardsigmoid(inputs, alpha=0.2, beta=0.5, inplace=False, **kwargs):
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def hardswish(inputs, alpha=0.2, beta=0.5, **kwargs): def hardswish(inputs, **kwargs):
r"""Apply the hard swish function. r"""Apply the hard swish function.
`[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_. `[Howard et.al, 2019] <https://arxiv.org/abs/1905.02244>`_.
The **HardSwish** function is defined as: The **HardSwish** function is defined as:
.. math:: \text{HardSwish}(x) = x \cdot \max(0, \min(1, \alpha * x + \beta)) .. math:: \text{HardSwish}(x) = x \cdot \max(0, \min(1, \frac{x}{6} + 0.5))
Examples: Examples:
```python ```python
x = dragon.constant([-2.5, -1.0, 0.0, 1.0, 2.5]) x = dragon.constant([-3., -1., 0., 1., 3.])
print(dragon.nn.hardswish(x)) print(dragon.nn.hardswish(x))
``` ```
...@@ -298,10 +298,6 @@ def hardswish(inputs, alpha=0.2, beta=0.5, **kwargs): ...@@ -298,10 +298,6 @@ def hardswish(inputs, alpha=0.2, beta=0.5, **kwargs):
---------- ----------
inputs : dragon.Tensor inputs : dragon.Tensor
The input tensor. The input tensor.
alpha : float, optional, default=0.2
The value to :math:`\alpha`.
beta : float, optional, default=0.5
The value to :math:`\beta`.
Returns Returns
------- -------
...@@ -309,10 +305,9 @@ def hardswish(inputs, alpha=0.2, beta=0.5, **kwargs): ...@@ -309,10 +305,9 @@ def hardswish(inputs, alpha=0.2, beta=0.5, **kwargs):
The output tensor. The output tensor.
""" """
alpha, beta = float(alpha), float(beta)
if context.executing_eagerly(): if context.executing_eagerly():
return OpLib.execute('HardSwish', inputs, alpha=alpha, beta=beta) return OpLib.execute('HardSwish', inputs)
return OpLib.add('HardSwish', inputs, alpha=alpha, beta=beta, **kwargs) return OpLib.add('HardSwish', inputs, **kwargs)
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
...@@ -622,8 +617,8 @@ def silu(inputs, **kwargs): ...@@ -622,8 +617,8 @@ def silu(inputs, **kwargs):
""" """
if context.executing_eagerly(): if context.executing_eagerly():
return OpLib.execute('Swish', inputs) return OpLib.execute('Silu', inputs)
return OpLib.add('Swish', inputs, **kwargs) return OpLib.add('Silu', inputs, **kwargs)
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
......
...@@ -115,23 +115,10 @@ void HardSigmoidGrad( ...@@ -115,23 +115,10 @@ void HardSigmoidGrad(
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void HardSwish( void HardSwish(const int N, const T* x, T* y, Context* ctx);
const int N,
const float alpha,
const float beta,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void HardSwishGrad( void HardSwishGrad(const int N, const T* dy, const T* x, T* dx, Context* ctx);
const int N,
const float alpha,
const float beta,
const T* dy,
const T* x,
T* dx,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void PRelu( void PRelu(
...@@ -217,6 +204,12 @@ template <typename T, class Context> ...@@ -217,6 +204,12 @@ template <typename T, class Context>
void SigmoidGrad(const int N, const T* dy, const T* y, T* dx, Context* ctx); void SigmoidGrad(const int N, const T* dy, const T* y, T* dx, Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Silu(const int N, const T* x, T* y, Context* ctx);
template <typename T, class Context>
void SiluGrad(const int N, const T* dy, const T* x, T* dx, Context* ctx);
template <typename T, class Context>
void Softmax( void Softmax(
const int N, const int N,
const int S, const int S,
...@@ -236,18 +229,6 @@ void SoftmaxGrad( ...@@ -236,18 +229,6 @@ void SoftmaxGrad(
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Swish(const int N, const T* x, T* y, Context* ctx);
template <typename T, class Context>
void SwishGrad(
const int N,
const T* dy,
const T* x,
const T* y,
T* dx,
Context* ctx);
template <typename T, class Context>
void Tanh(const int N, const T* x, T* y, Context* ctx); void Tanh(const int N, const T* x, T* y, Context* ctx);
template <typename T, class Context> template <typename T, class Context>
......
...@@ -989,7 +989,7 @@ def gelu(features, approximate=False, name=None): ...@@ -989,7 +989,7 @@ def gelu(features, approximate=False, name=None):
The **GELU** function is defined as: The **GELU** function is defined as:
.. math:: \text{GELU}(x) = 0.5x(1 + \tanh[\sqrt{2/\pi}(x + 0.044715x^{3})]) .. math:: \text{GELU}(x) = x\cdot\frac{1}{2}[1 + \text{erf}(x / \sqrt{2})]
Examples: Examples:
......
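The GELU docstring switches from the tanh approximation to the exact erf form. The two agree closely but are not identical; a standard-library comparison of both formulas:

```python
import math

def gelu_exact(x):
    # x * 1/2 * (1 + erf(x / sqrt(2))), the form now shown in the docstring.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))), the old approximation.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

for v in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(v, gelu_exact(v), gelu_tanh(v))  # the two stay within roughly 1e-3 of each other
```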
...@@ -218,7 +218,7 @@ class TestActivationOps(OpTestCase): ...@@ -218,7 +218,7 @@ class TestActivationOps(OpTestCase):
alpha, beta = 0.2, 0.5 alpha, beta = 0.2, 0.5
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution): with execution_context().mode(execution):
data = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], 'float32') data = np.array([-3., -2., -1., 0., 1., 2., 3.], 'float32')
x = new_tensor(data) x = new_tensor(data)
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch(x) tape.watch(x)
...@@ -234,21 +234,20 @@ class TestActivationOps(OpTestCase): ...@@ -234,21 +234,20 @@ class TestActivationOps(OpTestCase):
self.test_hardsigmoid() self.test_hardsigmoid()
def test_hardswish(self): def test_hardswish(self):
alpha, beta = 0.2, 0.5 alpha, beta = 1. / 6., 0.5
bound = beta / alpha bound = beta / alpha
alpha2x = alpha * 2.
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution): with execution_context().mode(execution):
data = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], 'float32') data = np.array([-3., -2., -1., 0., 1., 2., 3.], 'float32')
x = new_tensor(data) x = new_tensor(data)
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch(x) tape.watch(x)
y = dragon.nn.hardswish(x, alpha=alpha, beta=beta) y = dragon.nn.hardswish(x)
dx = tape.gradient(y, [x], output_gradients=[x])[0] dx = tape.gradient(y, [x], output_gradients=[x])[0]
result = data * np.clip(alpha * data + beta, 0, 1) result = data * np.clip(alpha * data + beta, 0, 1)
result2 = data.copy() result2 = data.copy()
inds = np.where(data < bound)[0] inds = np.where(data < bound)[0]
result2[inds] = data[inds] * (data[inds] * alpha2x + beta) result2[inds] = data[inds] * (data[inds] * alpha * 2 + beta)
result2[np.where(data < -bound)[0]] = 0 result2[np.where(data < -bound)[0]] = 0
self.assertEqual([y, dx], [result, result2]) self.assertEqual([y, dx], [result, result2])
...@@ -417,15 +416,15 @@ class TestActivationOps(OpTestCase): ...@@ -417,15 +416,15 @@ class TestActivationOps(OpTestCase):
def test_silu(self): def test_silu(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution): with execution_context().mode(execution):
data = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], 'float32') data = np.array([-3., -2., -1., 0., 1., 2., 3.], 'float32')
x = new_tensor(data) x = new_tensor(data)
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch(x) tape.watch(x)
y = dragon.nn.silu(x) y = dragon.nn.silu(x)
dx = tape.gradient(y, [x], output_gradients=[x])[0] dx = tape.gradient(y, [x], output_gradients=[x])[0]
result = data * (1. / (1. + np.exp(-data))) result = data * (1. / (1. + np.exp(-data)))
result2 = data * (result + (1. / (1. + np.exp(-data))) * (1. - result)) grad = data * (result + (1. / (1. + np.exp(-data))) * (1. - result))
self.assertEqual([y, dx], [result, result2]) self.assertEqual([y, dx], [result, grad])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable') @unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_silu_cuda(self): def test_silu_cuda(self):
......
...@@ -528,7 +528,7 @@ class TestModules(OpTestCase): ...@@ -528,7 +528,7 @@ class TestModules(OpTestCase):
self.assertEqual(y, result) self.assertEqual(y, result)
def test_hardswish(self): def test_hardswish(self):
alpha, beta = 1.0 / 6.0, 0.5 alpha, beta = 1. / 6., 0.5
data = np.array([-3., -2., -1., 0., 1., 2., 3], 'float32') data = np.array([-3., -2., -1., 0., 1., 2., 3], 'float32')
x = new_tensor(data) x = new_tensor(data)
m = torch.nn.Hardswish() m = torch.nn.Hardswish()
......
...@@ -1006,8 +1006,7 @@ def hardswish(input): ...@@ -1006,8 +1006,7 @@ def hardswish(input):
`torch.nn.Hardswish(...)`_ `torch.nn.Hardswish(...)`_
""" """
return FunctionLib.apply( return FunctionLib.apply('HardSwish', input.device, [input])
'HardSwish', input.device, [input], alpha=1. / 6., beta=0.5)
def interpolate( def interpolate(
...@@ -1992,7 +1991,7 @@ def silu(input): ...@@ -1992,7 +1991,7 @@ def silu(input):
`torch.nn.SiLU(...)`_ `torch.nn.SiLU(...)`_
""" """
return FunctionLib.apply('Swish', input.device, [input]) return FunctionLib.apply('Silu', input.device, [input])
def smooth_l1_loss( def smooth_l1_loss(
......
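After this change the dragon.vm.torch wrappers dispatch to the attribute-free HardSwish and the renamed Silu operator. As an outside sanity check (upstream PyTorch, if available, not the Dragon code path), PyTorch's functional ops bake in the same constants:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])

# PyTorch's hardswish is also x * clip(x / 6 + 0.5, 0, 1) with no tunable alpha/beta.
print(F.hardswish(x))
print(x * torch.clamp(x / 6.0 + 0.5, 0.0, 1.0))

# silu(x) = x * sigmoid(x)
print(F.silu(x))
print(x * torch.sigmoid(x))
```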