Commit 1ad360e9 by Ting PAN

Add tests of operator spec for AutoGraph

Summary:
This commit tests the correctness of shape inference and data type
blended by autograph module.
1 parent 1bd78a3c
Showing with 655 additions and 752 deletions
...@@ -7,16 +7,16 @@ namespace kernel { ...@@ -7,16 +7,16 @@ namespace kernel {
namespace { namespace {
template <typename Tx, typename Ty> template <typename InputT, typename OutputT>
void _ChannelNormalize( void _ChannelNormalize(
const int axis, const int axis,
const int num_dims, const int num_dims,
const int64_t* x_strides, const int64_t* x_strides,
const int64_t* y_dims, const int64_t* y_dims,
const Tx* x, const InputT* x,
const float* mean, const float* mean,
const float* std, const float* std,
Ty* y) { OutputT* y) {
const auto count = const auto count =
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t idx(num_dims, 0); vec64_t idx(num_dims, 0);
...@@ -27,7 +27,8 @@ void _ChannelNormalize( ...@@ -27,7 +27,8 @@ void _ChannelNormalize(
xi += idx[d] * x_strides[d]; xi += idx[d] * x_strides[d];
if (d == axis) wi = idx[d]; if (d == axis) wi = idx[d];
} }
y[yi] = ((Ty)x[xi] - (Ty)mean[wi]) / (Ty)std[wi]; y[yi] =
convert::To<OutputT>((convert::To<float>(x[xi]) - mean[wi]) / std[wi]);
math::utils::IncreaseIndexInDims(num_dims, y_dims, idx.data()); math::utils::IncreaseIndexInDims(num_dims, y_dims, idx.data());
} }
} }
...@@ -36,83 +37,43 @@ void _ChannelNormalize( ...@@ -36,83 +37,43 @@ void _ChannelNormalize(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> #define DEFINE_KERNEL_LAUNCHER(InputT, OutputT) \
void ChannelNormalize<float16, float16, CPUContext>(
const int axis,
const int num_dims,
const int64_t* x_strides,
const int64_t* y_dims,
const float16* x,
const float* mean,
const float* std,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_KERNEL_LAUNCHER(Tx, Ty) \
template <> \ template <> \
void ChannelNormalize<Tx, Ty, CPUContext>( \ void ChannelNormalize<InputT, OutputT, CPUContext>( \
const int axis, \ const int axis, \
const int num_dims, \ const int num_dims, \
const int64_t* x_strides, \ const int64_t* x_strides, \
const int64_t* y_dims, \ const int64_t* y_dims, \
const Tx* x, \ const InputT* x, \
const float* mean, \ const float* mean, \
const float* std, \ const float* std, \
Ty* y, \ OutputT* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_ChannelNormalize(axis, num_dims, x_strides, y_dims, x, mean, std, y); \ _ChannelNormalize(axis, num_dims, x_strides, y_dims, x, mean, std, y); \
} }
#define DEFINE_FP16_KERNEL_LAUNCHER(T) \ DEFINE_KERNEL_LAUNCHER(int8_t, float16);
template <> \
void ChannelNormalize<float16, T, CPUContext>( \
const int axis, \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const float16* x, \
const float* mean, \
const float* std, \
T* y, \
CPUContext* ctx) { \
CPU_FP16_NOT_SUPPORTED; \
} \
template <> \
void ChannelNormalize<T, float16, CPUContext>( \
const int axis, \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const T* x, \
const float* mean, \
const float* std, \
float16* y, \
CPUContext* ctx) { \
CPU_FP16_NOT_SUPPORTED; \
}
DEFINE_KERNEL_LAUNCHER(int8_t, float); DEFINE_KERNEL_LAUNCHER(int8_t, float);
DEFINE_KERNEL_LAUNCHER(int8_t, double); DEFINE_KERNEL_LAUNCHER(int8_t, double);
DEFINE_KERNEL_LAUNCHER(uint8_t, float16);
DEFINE_KERNEL_LAUNCHER(uint8_t, float); DEFINE_KERNEL_LAUNCHER(uint8_t, float);
DEFINE_KERNEL_LAUNCHER(uint8_t, double); DEFINE_KERNEL_LAUNCHER(uint8_t, double);
DEFINE_KERNEL_LAUNCHER(int, float16);
DEFINE_KERNEL_LAUNCHER(int, float); DEFINE_KERNEL_LAUNCHER(int, float);
DEFINE_KERNEL_LAUNCHER(int, double); DEFINE_KERNEL_LAUNCHER(int, double);
DEFINE_KERNEL_LAUNCHER(int64_t, float16);
DEFINE_KERNEL_LAUNCHER(int64_t, float); DEFINE_KERNEL_LAUNCHER(int64_t, float);
DEFINE_KERNEL_LAUNCHER(int64_t, double); DEFINE_KERNEL_LAUNCHER(int64_t, double);
DEFINE_KERNEL_LAUNCHER(float16, float16);
DEFINE_KERNEL_LAUNCHER(float16, float);
DEFINE_KERNEL_LAUNCHER(float16, double);
DEFINE_KERNEL_LAUNCHER(float, float16);
DEFINE_KERNEL_LAUNCHER(float, float); DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(float, double); DEFINE_KERNEL_LAUNCHER(float, double);
DEFINE_KERNEL_LAUNCHER(double, float16);
DEFINE_KERNEL_LAUNCHER(double, float); DEFINE_KERNEL_LAUNCHER(double, float);
DEFINE_KERNEL_LAUNCHER(double, double); DEFINE_KERNEL_LAUNCHER(double, double);
DEFINE_FP16_KERNEL_LAUNCHER(int8_t);
DEFINE_FP16_KERNEL_LAUNCHER(uint8_t);
DEFINE_FP16_KERNEL_LAUNCHER(int);
DEFINE_FP16_KERNEL_LAUNCHER(int64_t);
DEFINE_FP16_KERNEL_LAUNCHER(float);
DEFINE_FP16_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_FP16_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -7,51 +7,51 @@ namespace kernel { ...@@ -7,51 +7,51 @@ namespace kernel {
namespace { namespace {
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void _NLLLoss( void _NLLLoss(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* loss, LogitT* loss,
LogitType* mask) { LogitT* mask) {
std::array<int, 2> idx = {0, 0}; std::array<int, 2> idx = {0, 0};
std::array<int, 2> dims = {outer_dim, inner_dim}; std::array<int, 2> dims = {outer_dim, inner_dim};
int count = dims[0] * dims[1], k; int count = dims[0] * dims[1], k;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const int label = (int)target[i]; const int label = (int)target[i];
if (label == ignore_index) { if (label == ignore_index) {
loss[i] = mask[i] = LogitType(0); loss[i] = mask[i] = LogitT(0);
} else { } else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1]; k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
loss[i] = -logit[k], mask[i] = LogitType(1); loss[i] = -logit[k], mask[i] = LogitT(1);
} }
math::utils::IncreaseIndexInDims(2, dims.data(), idx.data()); math::utils::IncreaseIndexInDims(2, dims.data(), idx.data());
} }
} }
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void _NLLLossGrad( void _NLLLossGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* dlogit, LogitT* dlogit,
LogitType* mask) { LogitT* mask) {
std::array<int, 2> idx = {0, 0}; std::array<int, 2> idx = {0, 0};
std::array<int, 2> dims = {outer_dim, inner_dim}; std::array<int, 2> dims = {outer_dim, inner_dim};
int count = dims[0] * dims[1], k; int count = dims[0] * dims[1], k;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const int label = (int)target[i]; const int label = (int)target[i];
if (label == ignore_index) { if (label == ignore_index) {
mask[i] = LogitType(0); mask[i] = LogitT(0);
} else { } else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1]; k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
dlogit[k] = LogitType(-1), mask[i] = LogitType(1); dlogit[k] = LogitT(-1), mask[i] = LogitT(1);
} }
math::utils::IncreaseIndexInDims(2, dims.data(), idx.data()); math::utils::IncreaseIndexInDims(2, dims.data(), idx.data());
} }
...@@ -61,17 +61,17 @@ void _NLLLossGrad( ...@@ -61,17 +61,17 @@ void _NLLLossGrad(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, LogitType, TargetType) \ #define DEFINE_KERNEL_LAUNCHER(name, LogitT, TargetT) \
template <> \ template <> \
void name<LogitType, TargetType, CPUContext>( \ void name<LogitT, TargetT, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
const int ignore_index, \ const int ignore_index, \
const LogitType* logit, \ const LogitT* logit, \
const TargetType* target, \ const TargetT* target, \
LogitType* loss, \ LogitT* loss, \
LogitType* mask, \ LogitT* mask, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name( \ _##name( \
outer_dim, \ outer_dim, \
......
...@@ -9,48 +9,48 @@ namespace kernel { ...@@ -9,48 +9,48 @@ namespace kernel {
namespace { namespace {
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
__global__ void _NLLLoss( __global__ void _NLLLoss(
const int nthreads, const int nthreads,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* loss, LogitT* loss,
LogitType* mask) { LogitT* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / inner_dim; const int i = yi / inner_dim;
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int label = target[i * inner_dim + j]; const int label = target[i * inner_dim + j];
if (label == ignore_index) { if (label == ignore_index) {
loss[yi] = mask[yi] = LogitType(0); loss[yi] = mask[yi] = LogitT(0);
} else { } else {
loss[yi] = -logit[(i * axis_dim + label) * inner_dim + j]; loss[yi] = -logit[(i * axis_dim + label) * inner_dim + j];
mask[yi] = LogitType(1); mask[yi] = LogitT(1);
} }
} }
} }
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
__global__ void _NLLLossGrad( __global__ void _NLLLossGrad(
const int nthreads, const int nthreads,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* dlogit, LogitT* dlogit,
LogitType* mask) { LogitT* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / inner_dim; const int i = yi / inner_dim;
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int label = target[i * inner_dim + j]; const int label = target[i * inner_dim + j];
if (label == ignore_index) { if (label == ignore_index) {
mask[yi] = LogitType(0); mask[yi] = LogitT(0);
} else { } else {
dlogit[(i * axis_dim + label) * inner_dim + j] = LogitType(-1); dlogit[(i * axis_dim + label) * inner_dim + j] = LogitT(-1);
mask[yi] = LogitType(1); mask[yi] = LogitT(1);
} }
} }
} }
...@@ -59,17 +59,17 @@ __global__ void _NLLLossGrad( ...@@ -59,17 +59,17 @@ __global__ void _NLLLossGrad(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, LogitType, TargetType) \ #define DEFINE_KERNEL_LAUNCHER(name, LogitT, TargetT) \
template <> \ template <> \
void name<LogitType, TargetType, CUDAContext>( \ void name<LogitT, TargetT, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
const int ignore_index, \ const int ignore_index, \
const LogitType* logit, \ const LogitT* logit, \
const TargetType* target, \ const TargetT* target, \
LogitType* loss, \ LogitT* loss, \
LogitType* mask, \ LogitT* mask, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const auto nthreads = outer_dim * inner_dim; \ const auto nthreads = outer_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
......
...@@ -7,19 +7,19 @@ namespace kernel { ...@@ -7,19 +7,19 @@ namespace kernel {
namespace { namespace {
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void _SigmoidFocalLoss( void _SigmoidFocalLoss(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const LogitType pos_alpha, const LogitT pos_alpha,
const LogitType neg_alpha, const LogitT neg_alpha,
const LogitType gamma, const LogitT gamma,
const int negative_index, const int negative_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* loss, LogitT* loss,
LogitType* mask) { LogitT* mask) {
std::array<int, 3> idx = {0, 0, 0}; std::array<int, 3> idx = {0, 0, 0};
std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim}; std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim};
const int count = dims[0] * dims[1] * dims[2]; const int count = dims[0] * dims[1] * dims[2];
...@@ -27,23 +27,21 @@ void _SigmoidFocalLoss( ...@@ -27,23 +27,21 @@ void _SigmoidFocalLoss(
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const int t = (int)target[idx[0] * inner_dim + idx[2]]; const int t = (int)target[idx[0] * inner_dim + idx[2]];
// "0" is reserved for target if negative index is zero // "0" is reserved for target if negative index is zero
LogitType c1 = (LogitType)(t == (idx[1] + (negative_index ? 0 : 1))); LogitT c1 = (LogitT)(t == (idx[1] + (negative_index ? 0 : 1)));
LogitType c2 = LogitT c2 = (LogitT)((t >= 0) & (t != (idx[1] + (negative_index ? 0 : 1))));
(LogitType)((t >= 0) & (t != (idx[1] + (negative_index ? 0 : 1)))); LogitT p = LogitT(1) / (LogitT(1) + std::exp(-logit[i]));
LogitType p = LogitType(1) / (LogitType(1) + std::exp(-logit[i]));
// (1 - p)^{gamma} * log(p) // (1 - p)^{gamma} * log(p)
LogitType pos_term = std::pow(LogitType(1) - p, gamma) * LogitT pos_term =
std::log(std::max(p, (LogitType)FLT_MIN)); std::pow(LogitT(1) - p, gamma) * std::log(std::max(p, (LogitT)FLT_MIN));
// p^{gamma} * log(1 - p) // p^{gamma} * log(1 - p)
LogitType neg_term = std::pow(p, gamma) * LogitT neg_term = std::pow(p, gamma) *
(-logit[i] * (logit[i] >= 0) - (-logit[i] * (logit[i] >= 0) -
std::log( std::log(
LogitType(1) + LogitT(1) + std::exp(logit[i] - 2 * logit[i] * (logit[i] >= 0))));
std::exp(logit[i] - 2 * logit[i] * (logit[i] >= 0))));
loss[i] = LogitType(0); loss[i] = LogitT(0);
loss[i] += -c1 * pos_term * pos_alpha; loss[i] += -c1 * pos_term * pos_alpha;
loss[i] += -c2 * neg_term * neg_alpha; loss[i] += -c2 * neg_term * neg_alpha;
mask[i] = c1; mask[i] = c1;
...@@ -52,19 +50,19 @@ void _SigmoidFocalLoss( ...@@ -52,19 +50,19 @@ void _SigmoidFocalLoss(
} }
} }
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void _SigmoidFocalLossGrad( void _SigmoidFocalLossGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const LogitType pos_alpha, const LogitT pos_alpha,
const LogitType neg_alpha, const LogitT neg_alpha,
const LogitType gamma, const LogitT gamma,
const int negative_index, const int negative_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* dx, LogitT* dx,
LogitType* mask) { LogitT* mask) {
std::array<int, 3> idx = {0, 0, 0}; std::array<int, 3> idx = {0, 0, 0};
std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim}; std::array<int, 3> dims = {outer_dim, axis_dim, inner_dim};
const int count = dims[0] * dims[1] * dims[2]; const int count = dims[0] * dims[1] * dims[2];
...@@ -72,26 +70,24 @@ void _SigmoidFocalLossGrad( ...@@ -72,26 +70,24 @@ void _SigmoidFocalLossGrad(
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const int t = (int)target[idx[0] * inner_dim + idx[2]]; const int t = (int)target[idx[0] * inner_dim + idx[2]];
// "0" is reserved for target if negative index is zero // "0" is reserved for target if negative index is zero
LogitType c1 = (LogitType)(t == (idx[1] + (negative_index ? 0 : 1))); LogitT c1 = (LogitT)(t == (idx[1] + (negative_index ? 0 : 1)));
LogitType c2 = LogitT c2 = (LogitT)((t >= 0) & (t != (idx[1] + (negative_index ? 0 : 1))));
(LogitType)((t >= 0) & (t != (idx[1] + (negative_index ? 0 : 1)))); LogitT p = LogitT(1) / (LogitT(1) + std::exp(-logit[i]));
LogitType p = LogitType(1) / (LogitType(1) + std::exp(-logit[i]));
// (1 - p)^{gamma} * (1 - p - gamma * p * log(p)) // (1 - p)^{gamma} * (1 - p - gamma * p * log(p))
LogitType pos_term = std::pow(LogitType(1) - p, gamma) * LogitT pos_term = std::pow(LogitT(1) - p, gamma) *
(LogitType(1) - p - (LogitT(1) - p - p * gamma * std::log(std::max(p, (LogitT)FLT_MIN)));
p * gamma * std::log(std::max(p, (LogitType)FLT_MIN)));
// p^{gamma} * (gamma * (1 - p) * log(1-p) - p) // p^{gamma} * (gamma * (1 - p) * log(1-p) - p)
LogitType neg_term = std::pow(p, gamma) * LogitT neg_term = std::pow(p, gamma) *
((-logit[i] * (logit[i] >= 0) - ((-logit[i] * (logit[i] >= 0) -
std::log( std::log(
LogitType(1) + LogitT(1) +
std::exp(logit[i] - LogitType(2) * logit[i] * (logit[i] >= 0)))) * std::exp(logit[i] - LogitT(2) * logit[i] * (logit[i] >= 0)))) *
(1 - p) * gamma - (1 - p) * gamma -
p); p);
dx[i] = LogitType(0); dx[i] = LogitT(0);
dx[i] += -c1 * pos_term * pos_alpha; dx[i] += -c1 * pos_term * pos_alpha;
dx[i] += -c2 * neg_term * neg_alpha; dx[i] += -c2 * neg_term * neg_alpha;
mask[i] = c1; mask[i] = c1;
...@@ -104,9 +100,9 @@ void _SigmoidFocalLossGrad( ...@@ -104,9 +100,9 @@ void _SigmoidFocalLossGrad(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, LogitType, TargetType) \ #define DEFINE_KERNEL_LAUNCHER(name, LogitT, TargetT) \
template <> \ template <> \
void name<LogitType, TargetType, CPUContext>( \ void name<LogitT, TargetT, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
...@@ -114,18 +110,18 @@ void _SigmoidFocalLossGrad( ...@@ -114,18 +110,18 @@ void _SigmoidFocalLossGrad(
const float neg_alpha, \ const float neg_alpha, \
const float gamma, \ const float gamma, \
const int negative_index, \ const int negative_index, \
const LogitType* logit, \ const LogitT* logit, \
const TargetType* target, \ const TargetT* target, \
LogitType* loss, \ LogitT* loss, \
LogitType* mask, \ LogitT* mask, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name( \ _##name( \
outer_dim, \ outer_dim, \
inner_dim, \ inner_dim, \
axis_dim, \ axis_dim, \
(LogitType)pos_alpha, \ (LogitT)pos_alpha, \
(LogitType)neg_alpha, \ (LogitT)neg_alpha, \
(LogitType)gamma, \ (LogitT)gamma, \
negative_index, \ negative_index, \
logit, \ logit, \
target, \ target, \
......
...@@ -9,19 +9,19 @@ namespace kernel { ...@@ -9,19 +9,19 @@ namespace kernel {
namespace { namespace {
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
__global__ void _SigmoidFocalLoss( __global__ void _SigmoidFocalLoss(
const int nthreads, const int nthreads,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const LogitType pos_alpha, const LogitT pos_alpha,
const LogitType neg_alpha, const LogitT neg_alpha,
const LogitType gamma, const LogitT gamma,
const int negative_index, const int negative_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* loss, LogitT* loss,
LogitType* mask) { LogitT* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int k = (yi / inner_dim) % axis_dim; const int k = (yi / inner_dim) % axis_dim;
...@@ -29,40 +29,39 @@ __global__ void _SigmoidFocalLoss( ...@@ -29,40 +29,39 @@ __global__ void _SigmoidFocalLoss(
const int t = target[i * inner_dim + j]; const int t = target[i * inner_dim + j];
// "0" is reserved for target if negative index is zero // "0" is reserved for target if negative index is zero
LogitType c1 = (LogitType)(t == (k + (negative_index ? 0 : 1))); LogitT c1 = (LogitT)(t == (k + (negative_index ? 0 : 1)));
LogitType c2 = LogitT c2 = (LogitT)((t >= 0) & (t != (k + (negative_index ? 0 : 1))));
(LogitType)((t >= 0) & (t != (k + (negative_index ? 0 : 1)))); LogitT p = LogitT(1) / (LogitT(1) + exp(-logit[yi]));
LogitType p = LogitType(1) / (LogitType(1) + exp(-logit[yi]));
// (1 - p)^{gamma} * log(p) // (1 - p)^{gamma} * log(p)
LogitType pos_term = pow(LogitType(1) - p, gamma) * log(max(p, FLT_MIN)); LogitT pos_term = pow(LogitT(1) - p, gamma) * log(max(p, FLT_MIN));
// p^{gamma} * log(1 - p) // p^{gamma} * log(1 - p)
LogitType neg_term = pow(p, gamma) * LogitT neg_term = pow(p, gamma) *
(-logit[yi] * (logit[yi] >= 0) - (-logit[yi] * (logit[yi] >= 0) -
log(LogitType(1) + log(LogitT(1) +
exp(logit[yi] - LogitType(2) * logit[yi] * (logit[yi] >= 0)))); exp(logit[yi] - LogitT(2) * logit[yi] * (logit[yi] >= 0))));
loss[yi] = LogitType(0); loss[yi] = LogitT(0);
loss[yi] += -c1 * pos_term * pos_alpha; loss[yi] += -c1 * pos_term * pos_alpha;
loss[yi] += -c2 * neg_term * neg_alpha; loss[yi] += -c2 * neg_term * neg_alpha;
mask[yi] = c1; mask[yi] = c1;
} }
} }
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
__global__ void _SigmoidFocalLossGrad( __global__ void _SigmoidFocalLossGrad(
const int nthreads, const int nthreads,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const LogitType pos_alpha, const LogitT pos_alpha,
const LogitType neg_alpha, const LogitT neg_alpha,
const LogitType gamma, const LogitT gamma,
const int negative_index, const int negative_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* dx, LogitT* dx,
LogitType* mask) { LogitT* mask) {
CUDA_1D_KERNEL_LOOP(xi, nthreads) { CUDA_1D_KERNEL_LOOP(xi, nthreads) {
const int j = xi % inner_dim; const int j = xi % inner_dim;
const int k = (xi / inner_dim) % axis_dim; const int k = (xi / inner_dim) % axis_dim;
...@@ -70,24 +69,23 @@ __global__ void _SigmoidFocalLossGrad( ...@@ -70,24 +69,23 @@ __global__ void _SigmoidFocalLossGrad(
const int t = target[i * inner_dim + j]; const int t = target[i * inner_dim + j];
// "0" is reserved for target if neg index is zero // "0" is reserved for target if neg index is zero
LogitType c1 = (LogitType)(t == (k + (negative_index ? 0 : 1))); LogitT c1 = (LogitT)(t == (k + (negative_index ? 0 : 1)));
LogitType c2 = LogitT c2 = (LogitT)((t >= 0) & (t != (k + (negative_index ? 0 : 1))));
(LogitType)((t >= 0) & (t != (k + (negative_index ? 0 : 1)))); LogitT p = LogitT(1) / (LogitT(1) + exp(-logit[xi]));
LogitType p = LogitType(1) / (LogitType(1) + exp(-logit[xi]));
// (1 - p)^{gamma} * (1 - p - gamma * p * log(p)) // (1 - p)^{gamma} * (1 - p - gamma * p * log(p))
LogitType pos_term = pow(LogitType(1) - p, gamma) * LogitT pos_term = pow(LogitT(1) - p, gamma) *
(LogitType(1) - p - p * gamma * log(max(p, FLT_MIN))); (LogitT(1) - p - p * gamma * log(max(p, FLT_MIN)));
// p^{gamma} * (gamma * (1 - p) * log(1-p) - p) // p^{gamma} * (gamma * (1 - p) * log(1-p) - p)
LogitType neg_term = pow(p, gamma) * LogitT neg_term = pow(p, gamma) *
((-logit[xi] * (logit[xi] >= 0) - ((-logit[xi] * (logit[xi] >= 0) -
log(LogitType(1) + log(LogitT(1) +
exp(logit[xi] - LogitType(2) * logit[xi] * (logit[xi] >= 0)))) * exp(logit[xi] - LogitT(2) * logit[xi] * (logit[xi] >= 0)))) *
(LogitType(1) - p) * gamma - (LogitT(1) - p) * gamma -
p); p);
dx[xi] = LogitType(0); dx[xi] = LogitT(0);
dx[xi] += -c1 * pos_term * pos_alpha; dx[xi] += -c1 * pos_term * pos_alpha;
dx[xi] += -c2 * neg_term * neg_alpha; dx[xi] += -c2 * neg_term * neg_alpha;
mask[xi] = c1; mask[xi] = c1;
...@@ -98,9 +96,9 @@ __global__ void _SigmoidFocalLossGrad( ...@@ -98,9 +96,9 @@ __global__ void _SigmoidFocalLossGrad(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, LogitType, TargetType) \ #define DEFINE_KERNEL_LAUNCHER(name, LogitT, TargetT) \
template <> \ template <> \
void name<LogitType, TargetType, CUDAContext>( \ void name<LogitT, TargetT, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
...@@ -108,19 +106,19 @@ __global__ void _SigmoidFocalLossGrad( ...@@ -108,19 +106,19 @@ __global__ void _SigmoidFocalLossGrad(
const float neg_alpha, \ const float neg_alpha, \
const float gamma, \ const float gamma, \
const int negative_index, \ const int negative_index, \
const LogitType* logit, \ const LogitT* logit, \
const TargetType* target, \ const TargetT* target, \
LogitType* loss, \ LogitT* loss, \
LogitType* mask, \ LogitT* mask, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const auto nthreads = outer_dim * axis_dim * inner_dim; \ const auto nthreads = outer_dim * axis_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \ nthreads, \
inner_dim, \ inner_dim, \
axis_dim, \ axis_dim, \
(LogitType)pos_alpha, \ (LogitT)pos_alpha, \
(LogitType)neg_alpha, \ (LogitT)neg_alpha, \
(LogitType)gamma, \ (LogitT)gamma, \
negative_index, \ negative_index, \
logit, \ logit, \
target, \ target, \
......
...@@ -7,58 +7,58 @@ namespace kernel { ...@@ -7,58 +7,58 @@ namespace kernel {
namespace { namespace {
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void _SparseSoftmaxCrossEntropy( void _SparseSoftmaxCrossEntropy(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* prob, const LogitT* prob,
const TargetType* target, const TargetT* target,
LogitType* loss, LogitT* loss,
LogitType* mask) { LogitT* mask) {
std::array<int, 2> idx = {0, 0}; std::array<int, 2> idx = {0, 0};
std::array<int, 2> dims = {outer_dim, inner_dim}; std::array<int, 2> dims = {outer_dim, inner_dim};
int count = dims[0] * dims[1], k; int count = dims[0] * dims[1], k;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const int label = (int)target[i]; const int label = (int)target[i];
if (label == ignore_index) { if (label == ignore_index) {
loss[i] = mask[i] = LogitType(0); loss[i] = mask[i] = LogitT(0);
} else { } else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1]; k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
loss[i] = -std::log(std::max(prob[k], LogitType(FLT_MIN))); loss[i] = -std::log(std::max(prob[k], LogitT(FLT_MIN)));
mask[i] = LogitType(1); mask[i] = LogitT(1);
} }
math::utils::IncreaseIndexInDims(2, dims.data(), idx.data()); math::utils::IncreaseIndexInDims(2, dims.data(), idx.data());
} }
} }
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void _SparseSoftmaxCrossEntropyGrad( void _SparseSoftmaxCrossEntropyGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* prob, const LogitT* prob,
const TargetType* target, const TargetT* target,
LogitType* dx, LogitT* dx,
LogitType* mask) { LogitT* mask) {
std::array<int, 2> idx = {0, 0}; std::array<int, 2> idx = {0, 0};
std::array<int, 2> dims = {outer_dim, inner_dim}; std::array<int, 2> dims = {outer_dim, inner_dim};
int count = dims[0] * dims[1], k; int count = dims[0] * dims[1], k;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const int label = (int)target[i]; const int label = (int)target[i];
if (label == ignore_index) { if (label == ignore_index) {
LogitType* offset_dx = dx + idx[0] * axis_dim * inner_dim + idx[1]; LogitT* offset_dx = dx + idx[0] * axis_dim * inner_dim + idx[1];
for (int j = 0; j < axis_dim; ++j) { for (int j = 0; j < axis_dim; ++j) {
(*offset_dx) = LogitType(0); (*offset_dx) = LogitT(0);
offset_dx += inner_dim; offset_dx += inner_dim;
} }
mask[i] = LogitType(0); mask[i] = LogitT(0);
} else { } else {
k = (idx[0] * axis_dim + label) * inner_dim + idx[1]; k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
dx[k] -= LogitType(1); dx[k] -= LogitT(1);
mask[i] = LogitType(1); mask[i] = LogitT(1);
} }
math::utils::IncreaseIndexInDims(2, dims.data(), idx.data()); math::utils::IncreaseIndexInDims(2, dims.data(), idx.data());
} }
...@@ -68,17 +68,17 @@ void _SparseSoftmaxCrossEntropyGrad( ...@@ -68,17 +68,17 @@ void _SparseSoftmaxCrossEntropyGrad(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, LogitType, TargetType) \ #define DEFINE_KERNEL_LAUNCHER(name, LogitT, TargetT) \
template <> \ template <> \
void name<LogitType, TargetType, CPUContext>( \ void name<LogitT, TargetT, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
const int ignore_index, \ const int ignore_index, \
const LogitType* prob, \ const LogitT* prob, \
const TargetType* target, \ const TargetT* target, \
LogitType* loss, \ LogitT* loss, \
LogitType* mask, \ LogitT* mask, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name( \ _##name( \
outer_dim, \ outer_dim, \
......
...@@ -9,54 +9,54 @@ namespace kernel { ...@@ -9,54 +9,54 @@ namespace kernel {
namespace { namespace {
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
__global__ void _SparseSoftmaxCrossEntropy( __global__ void _SparseSoftmaxCrossEntropy(
const int nthreads, const int nthreads,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* prob, const LogitT* prob,
const TargetType* target, const TargetT* target,
LogitType* loss, LogitT* loss,
LogitType* mask) { LogitT* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / inner_dim; const int i = yi / inner_dim;
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int label = target[i * inner_dim + j]; const int label = target[i * inner_dim + j];
if (label == ignore_index) { if (label == ignore_index) {
loss[yi] = mask[yi] = LogitType(0); loss[yi] = mask[yi] = LogitT(0);
} else { } else {
loss[yi] = -log(max( loss[yi] = -log(
prob[(i * axis_dim + label) * inner_dim + j], LogitType(FLT_MIN))); max(prob[(i * axis_dim + label) * inner_dim + j], LogitT(FLT_MIN)));
mask[yi] = LogitType(1); mask[yi] = LogitT(1);
} }
} }
} }
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
__global__ void _SparseSoftmaxCrossEntropyGrad( __global__ void _SparseSoftmaxCrossEntropyGrad(
const int nthreads, const int nthreads,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* prob, const LogitT* prob,
const TargetType* target, const TargetT* target,
LogitType* dx, LogitT* dx,
LogitType* mask) { LogitT* mask) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / inner_dim; const int i = yi / inner_dim;
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int label = target[i * inner_dim + j]; const int label = target[i * inner_dim + j];
if (label == ignore_index) { if (label == ignore_index) {
LogitType* offset_dx = dx + i * axis_dim * inner_dim + j; LogitT* offset_dx = dx + i * axis_dim * inner_dim + j;
for (int k = 0; k < axis_dim; ++k) { for (int k = 0; k < axis_dim; ++k) {
(*offset_dx) = LogitType(0); (*offset_dx) = LogitT(0);
offset_dx += inner_dim; offset_dx += inner_dim;
} }
mask[yi] = LogitType(0); mask[yi] = LogitT(0);
} else { } else {
dx[(i * axis_dim + label) * inner_dim + j] -= LogitType(1); dx[(i * axis_dim + label) * inner_dim + j] -= LogitT(1);
mask[yi] = LogitType(1); mask[yi] = LogitT(1);
} }
} }
} }
...@@ -65,17 +65,17 @@ __global__ void _SparseSoftmaxCrossEntropyGrad( ...@@ -65,17 +65,17 @@ __global__ void _SparseSoftmaxCrossEntropyGrad(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, LogitType, TargetType) \ #define DEFINE_KERNEL_LAUNCHER(name, LogitT, TargetT) \
template <> \ template <> \
void name<LogitType, TargetType, CUDAContext>( \ void name<LogitT, TargetT, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
const int ignore_index, \ const int ignore_index, \
const LogitType* prob, \ const LogitT* prob, \
const TargetType* target, \ const TargetT* target, \
LogitType* loss, \ LogitT* loss, \
LogitType* mask, \ LogitT* mask, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const auto nthreads = outer_dim * inner_dim; \ const auto nthreads = outer_dim * inner_dim; \
_##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
......
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -9,125 +10,34 @@ namespace kernel { ...@@ -9,125 +10,34 @@ namespace kernel {
namespace { namespace {
template <typename T> template <typename T, typename AccT>
__global__ void __global__ void
_Clip(const int nthreads, const T low, const T high, const T* x, T* y) { _Clip(const int nthreads, const AccT low, const AccT high, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = max(low, min(x[i], high)); y[i] = convert::To<T>(max(low, min(convert::To<AccT>(x[i]), high)));
} }
} }
template <> template <typename T, typename AccT>
__global__ void _Clip<half>(
const int nthreads,
const half low,
const half high,
const half* x,
half* y) {
#if __CUDA_ARCH__ >= 530
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = __hlt(__ldg(x + i), high)
? (__hgt(__ldg(x + i), low) ? __ldg(x + i) : low)
: high;
}
#else
const float kLow = __half2float(low);
const float kHigh = __half2float(high);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = __float2half(max(kLow, min(__half2float(x[i]), kHigh)));
}
#endif
}
template <typename T>
__global__ void _ClipGrad( __global__ void _ClipGrad(
const int nthreads, const int nthreads,
const T low, const AccT low,
const T high, const AccT high,
const T* dy, const T* dy,
const T* x, const T* x,
T* dx) { T* dx) {
const T kZero = convert::To<T>(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350 const AccT val = convert::To<AccT>(x[i]);
dx[i] = __ldg(x + i) < low || __ldg(x + i) > high ? T(0) : dy[i]; dx[i] = val < low || val > high ? kZero : dy[i];
#else
dx[i] = x[i] < low || x[i] > high ? T(0) : dy[i];
#endif
} }
} }
template <>
__global__ void _ClipGrad<half>(
const int nthreads,
const half low,
const half high,
const half* dy,
const half* x,
half* dx) {
const half kZero = __float2half(0.f);
#if __CUDA_ARCH__ >= 530
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] =
(__hlt(__ldg(x + i), low) || __hgt(__ldg(x + i), high)) ? kZero : dy[i];
}
#elif __CUDA_ARCH__ >= 350
const float kLow = __half2float(low);
const float kHigh = __half2float(high);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = (__half2float(__ldg(x + i)) < kLow ||
__half2float(__ldg(x + i)) > kHigh)
? kZero
: dy[i];
}
#else
const float kLow = __half2float(low);
const float kHigh = __half2float(high);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = (__half2float(x[i]) < kLow || __half2float(x[i]) > kHigh) ? kZero
: dy[i];
}
#endif
}
} // namespace } // namespace
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> #define DEFINE_KERNEL_LAUNCHER(T, AccT) \
void Clip<float16, CUDAContext>(
const int count,
const float low,
const float high,
const float16* x,
float16* y,
CUDAContext* ctx) {
_Clip<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
convert::To<half>(low),
convert::To<half>(high),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
template <>
void ClipGrad<float16, CUDAContext>(
const int count,
const float low,
const float high,
const float16* dy,
const float16* x,
float16* dx,
CUDAContext* ctx) {
_ClipGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
convert::To<half>(low),
convert::To<half>(high),
reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(dx));
} // ClipGrad
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void Clip<T, CUDAContext>( \ void Clip<T, CUDAContext>( \
const int count, \ const int count, \
...@@ -136,11 +46,12 @@ void ClipGrad<float16, CUDAContext>( ...@@ -136,11 +46,12 @@ void ClipGrad<float16, CUDAContext>(
const T* x, \ const T* x, \
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_Clip<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _Clip<T, AccT> \
count, convert::To<T>(low), convert::To<T>(high), x, y); \ <<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, low, high, x, y); \
} }
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \ #define DEFINE_GRAD_KERNEL_LAUNCHER(T, AccT) \
template <> \ template <> \
void ClipGrad<T, CUDAContext>( \ void ClipGrad<T, CUDAContext>( \
const int count, \ const int count, \
...@@ -150,18 +61,21 @@ void ClipGrad<float16, CUDAContext>( ...@@ -150,18 +61,21 @@ void ClipGrad<float16, CUDAContext>(
const T* x, \ const T* x, \
T* dx, \ T* dx, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_ClipGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _ClipGrad<T, AccT> \
count, convert::To<T>(low), convert::To<T>(high), dy, x, dx); \ <<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, low, high, dy, x, dx); \
} }
DEFINE_KERNEL_LAUNCHER(int8_t); DEFINE_KERNEL_LAUNCHER(int8_t, int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t); DEFINE_KERNEL_LAUNCHER(uint8_t, uint8_t);
DEFINE_KERNEL_LAUNCHER(int); DEFINE_KERNEL_LAUNCHER(int, int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t, int64_t);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float16, float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(double); DEFINE_GRAD_KERNEL_LAUNCHER(float16, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(double, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -20,15 +20,15 @@ void _RowwiseMoments( ...@@ -20,15 +20,15 @@ void _RowwiseMoments(
#pragma omp parallel for num_threads(OMP_THREADS(cols)) #pragma omp parallel for num_threads(OMP_THREADS(cols))
#endif #endif
for (int i = 0; i < cols; ++i) { for (int i = 0; i < cols; ++i) {
T x_val; AccT x_val, m_val = AccT(0), v_val = AccT(0);
AccT m_val = AccT(0), v_val = AccT(0), mu;
for (int j = 0; j < rows; ++j) { for (int j = 0; j < rows; ++j) {
x_val = x[j * cols + i]; x_val = convert::To<AccT>(x[j * cols + i]);
m_val += x_val; m_val += x_val;
v_val += x_val * x_val; v_val += x_val * x_val;
} }
mean[i] = mu = m_val * scale; m_val *= scale;
var[i] = v_val * scale - mu * mu; mean[i] = m_val;
var[i] = v_val * scale - m_val * m_val;
} }
} }
...@@ -44,15 +44,15 @@ void _ColwiseMoments( ...@@ -44,15 +44,15 @@ void _ColwiseMoments(
#pragma omp parallel for num_threads(OMP_THREADS(rows)) #pragma omp parallel for num_threads(OMP_THREADS(rows))
#endif #endif
for (int i = 0; i < rows; ++i) { for (int i = 0; i < rows; ++i) {
T x_val; AccT x_val, m_val = AccT(0), v_val = AccT(0);
AccT m_val = AccT(0), v_val = AccT(0), mu;
for (int j = 0; j < cols; ++j) { for (int j = 0; j < cols; ++j) {
x_val = x[i * cols + j]; x_val = convert::To<AccT>(x[i * cols + j]);
m_val += x_val; m_val += x_val;
v_val += x_val * x_val; v_val += x_val * x_val;
} }
mean[i] = mu = m_val * scale; m_val *= scale;
var[i] = v_val * scale - mu * mu; mean[i] = m_val;
var[i] = v_val * scale - m_val * m_val;
} }
} }
...@@ -71,8 +71,7 @@ void _GenericMoments( ...@@ -71,8 +71,7 @@ void _GenericMoments(
#pragma omp parallel for num_threads(OMP_THREADS(rows)) #pragma omp parallel for num_threads(OMP_THREADS(rows))
#endif #endif
for (int i = 0; i < rows; ++i) { for (int i = 0; i < rows; ++i) {
T x_val; AccT x_val, m_val = AccT(0), v_val = AccT(0);
AccT m_val = AccT(0), v_val = AccT(0), mu;
int xi, c, r; int xi, c, r;
for (int j = 0; j < cols; ++j) { for (int j = 0; j < cols; ++j) {
xi = 0; xi = 0;
...@@ -81,12 +80,13 @@ void _GenericMoments( ...@@ -81,12 +80,13 @@ void _GenericMoments(
FIXED_DIVISOR_DIV_MOD(x_dims[d], c, &c, &r); FIXED_DIVISOR_DIV_MOD(x_dims[d], c, &c, &r);
xi += r * x_strides[d]; xi += r * x_strides[d];
} }
x_val = x[xi]; x_val = convert::To<AccT>(x[xi]);
m_val += x_val; m_val += x_val;
v_val += x_val * x_val; v_val += x_val * x_val;
} }
mean[i] = mu = m_val * scale; m_val *= scale;
var[i] = v_val * scale - mu * mu; mean[i] = m_val;
var[i] = v_val * scale - m_val * m_val;
} }
} }
...@@ -148,19 +148,6 @@ void _Moments( ...@@ -148,19 +148,6 @@ void _Moments(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <>
void Moments<float16, float, CPUContext>(
const int num_dims,
const int* dims,
const int num_axes,
const int* axes,
const float16* x,
float* mean,
float* var,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_KERNEL_LAUNCHER(T, AccT) \ #define DEFINE_KERNEL_LAUNCHER(T, AccT) \
template <> \ template <> \
void Moments<T, AccT, CPUContext>( \ void Moments<T, AccT, CPUContext>( \
...@@ -178,7 +165,8 @@ void Moments<float16, float, CPUContext>( ...@@ -178,7 +165,8 @@ void Moments<float16, float, CPUContext>(
DEFINE_KERNEL_LAUNCHER(int8_t, float); DEFINE_KERNEL_LAUNCHER(int8_t, float);
DEFINE_KERNEL_LAUNCHER(uint8_t, float); DEFINE_KERNEL_LAUNCHER(uint8_t, float);
DEFINE_KERNEL_LAUNCHER(int, float); DEFINE_KERNEL_LAUNCHER(int, float);
DEFINE_KERNEL_LAUNCHER(int64_t, float); DEFINE_KERNEL_LAUNCHER(int64_t, double);
DEFINE_KERNEL_LAUNCHER(float16, float);
DEFINE_KERNEL_LAUNCHER(float, float); DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(double, double); DEFINE_KERNEL_LAUNCHER(double, double);
#undef DEFINE__KERNEL_LAUNCHER #undef DEFINE__KERNEL_LAUNCHER
......
...@@ -201,7 +201,7 @@ void _Moments( ...@@ -201,7 +201,7 @@ void _Moments(
DEFINE_KERNEL_LAUNCHER(int8_t, int8_t, float); DEFINE_KERNEL_LAUNCHER(int8_t, int8_t, float);
DEFINE_KERNEL_LAUNCHER(uint8_t, uint8_t, float); DEFINE_KERNEL_LAUNCHER(uint8_t, uint8_t, float);
DEFINE_KERNEL_LAUNCHER(int, int, float); DEFINE_KERNEL_LAUNCHER(int, int, float);
DEFINE_KERNEL_LAUNCHER(int64_t, int64_t, float); DEFINE_KERNEL_LAUNCHER(int64_t, int64_t, double);
DEFINE_KERNEL_LAUNCHER(float16, half, float); DEFINE_KERNEL_LAUNCHER(float16, half, float);
DEFINE_KERNEL_LAUNCHER(float, float, float); DEFINE_KERNEL_LAUNCHER(float, float, float);
DEFINE_KERNEL_LAUNCHER(double, double, double); DEFINE_KERNEL_LAUNCHER(double, double, double);
......
...@@ -70,7 +70,7 @@ void _L1NormalizeGrad( ...@@ -70,7 +70,7 @@ void _L1NormalizeGrad(
auto X = ConstEigenStridedVectorMap<T>( auto X = ConstEigenStridedVectorMap<T>(
x + offset, 1, reduce_dim, EigenInnerStride(inner_dim)); x + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
auto norm = std::max(X.template lpNorm<1>() / normalizer, epsilon); auto norm = std::max(X.template lpNorm<1>() / normalizer, epsilon);
auto norm2 = std::pow(norm, 2); auto norm2 = std::pow(norm, T(2));
EigenStridedVectorMap<T>( EigenStridedVectorMap<T>(
dx + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) = dx + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) =
(dY / norm) - (dY / norm) -
...@@ -98,7 +98,7 @@ void _L2NormalizeGrad( ...@@ -98,7 +98,7 @@ void _L2NormalizeGrad(
auto X = ConstEigenStridedVectorMap<T>( auto X = ConstEigenStridedVectorMap<T>(
x + offset, 1, reduce_dim, EigenInnerStride(inner_dim)); x + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
auto norm = std::max(std::sqrt(X.squaredNorm() / normalizer), epsilon); auto norm = std::max(std::sqrt(X.squaredNorm() / normalizer), epsilon);
auto norm3 = std::pow(norm, 3); auto norm3 = std::pow(norm, T(3));
EigenStridedVectorMap<T>( EigenStridedVectorMap<T>(
dx + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) = dx + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) =
(dY / norm) - ((X / norm3) * dY.dot(X) / normalizer); (dY / norm) - ((X / norm3) * dY.dot(X) / normalizer);
......
...@@ -93,7 +93,7 @@ __global__ void _L1NormalizeGrad( ...@@ -93,7 +93,7 @@ __global__ void _L1NormalizeGrad(
val2 = BlockReduce<AccT>(storage).Sum(val2); val2 = BlockReduce<AccT>(storage).Sum(val2);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
norm = max(val1 / normalizer, epsilon); norm = max(val1 / normalizer, epsilon);
norm2 = pow(norm, 2); norm2 = pow(norm, AccT(2));
sum = val2 / normalizer; sum = val2 / normalizer;
} }
__syncthreads(); __syncthreads();
...@@ -130,7 +130,7 @@ __global__ void _L2NormalizeGrad( ...@@ -130,7 +130,7 @@ __global__ void _L2NormalizeGrad(
val2 = BlockReduce<AccT>(storage).Sum(val2); val2 = BlockReduce<AccT>(storage).Sum(val2);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
norm = max(sqrt(val1 / normalizer), epsilon); norm = max(sqrt(val1 / normalizer), epsilon);
norm3 = pow(norm, 3); norm3 = pow(norm, AccT(3));
sum = val2 / normalizer; sum = val2 / normalizer;
} }
__syncthreads(); __syncthreads();
......
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
......
#ifdef USE_CUDA #ifdef USE_CUDA
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename Tx, typename Ty> template <typename InputT, typename OutputT>
void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() { void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(X); CANONICALIZE_AXIS_WITH_TENSOR(X);
...@@ -35,10 +35,10 @@ void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() { ...@@ -35,10 +35,10 @@ void ChannelNormalizeOp<Context>::DoRunWithTypeAndCast() {
num_dims, num_dims,
X_strides.data(), X_strides.data(),
Y_dims.data(), Y_dims.data(),
X.template data<Tx, Context>(), X.template data<InputT, Context>(),
X_mean_.template data<float, Context>(), X_mean_.template data<float, Context>(),
X_std_.template data<float, Context>(), X_std_.template data<float, Context>(),
Y->Reshape(Y_dims)->template mutable_data<Ty, Context>(), Y->Reshape(Y_dims)->template mutable_data<OutputT, Context>(),
ctx()); ctx());
} }
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void NLLLossOp<Context>::DoRunWithType() { void NLLLossOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(X); CANONICALIZE_AXIS_WITH_TENSOR(X);
...@@ -19,19 +19,19 @@ void NLLLossOp<Context>::DoRunWithType() { ...@@ -19,19 +19,19 @@ void NLLLossOp<Context>::DoRunWithType() {
<< "\nNumber of preds must match the number of targets."; << "\nNumber of preds must match the number of targets.";
auto scratches = ctx()->workspace()->template data<Context>({ auto scratches = ctx()->workspace()->template data<Context>({
(size_t)num_preds * sizeof(LogitType), // loss (size_t)num_preds * sizeof(LogitT), // loss
(size_t)num_preds * sizeof(LogitType) + sizeof(LogitType), // mask (size_t)num_preds * sizeof(LogitT) + sizeof(LogitT), // mask
}); });
auto* loss = static_cast<LogitType*>(scratches[0]); auto* loss = static_cast<LogitT*>(scratches[0]);
auto* mask = static_cast<LogitType*>(scratches[1]); auto* mask = static_cast<LogitT*>(scratches[1]);
kernel::NLLLoss( kernel::NLLLoss(
outer_dim, outer_dim,
inner_dim, inner_dim,
X.dim(axis), X.dim(axis),
ignore_index_, ignore_index_,
X.template data<LogitType, Context>(), X.template data<LogitT, Context>(),
Input(1).template data<TargetType, Context>(), Input(1).template data<TargetT, Context>(),
loss, loss,
mask, mask,
ctx()); ctx());
...@@ -42,7 +42,7 @@ void NLLLossOp<Context>::DoRunWithType() { ...@@ -42,7 +42,7 @@ void NLLLossOp<Context>::DoRunWithType() {
math::Copy( math::Copy(
num_preds, num_preds,
loss, loss,
Y->Reshape(out_shape)->template mutable_data<LogitType, Context>(), Y->Reshape(out_shape)->template mutable_data<LogitT, Context>(),
ctx()); ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
...@@ -59,7 +59,7 @@ void NLLLossOp<Context>::DoRunWithType() { ...@@ -59,7 +59,7 @@ void NLLLossOp<Context>::DoRunWithType() {
normalizer, normalizer,
loss, loss,
mask, mask,
Y->Reshape({})->template mutable_data<LogitType, Context>(), Y->Reshape({})->template mutable_data<LogitT, Context>(),
ctx()); ctx());
} }
} }
...@@ -91,7 +91,7 @@ void NLLLossOp<Context>::RunOnDevice() { ...@@ -91,7 +91,7 @@ void NLLLossOp<Context>::RunOnDevice() {
} }
template <class Context> template <class Context>
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void NLLLossGradientOp<Context>::DoRunWithType() { void NLLLossGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &dY = Input(-1), *dX = Output(0); auto &X = Input(0), &dY = Input(-1), *dX = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(X); CANONICALIZE_AXIS_WITH_TENSOR(X);
...@@ -101,19 +101,19 @@ void NLLLossGradientOp<Context>::DoRunWithType() { ...@@ -101,19 +101,19 @@ void NLLLossGradientOp<Context>::DoRunWithType() {
auto inner_dim = dX->count(axis + 1); auto inner_dim = dX->count(axis + 1);
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
auto* dy = dY.template data<LogitType, Context>(); auto* dy = dY.template data<LogitT, Context>();
auto* dx = dX->template mutable_data<LogitType, Context>(); auto* dx = dX->template mutable_data<LogitT, Context>();
auto* mask = auto* mask =
ctx()->workspace()->template data<LogitType, Context>({num_preds + 1})[0]; ctx()->workspace()->template data<LogitT, Context>({num_preds + 1})[0];
math::Set(dX->count(), convert::To<LogitType>(0.f), dx, ctx()); math::Set(dX->count(), convert::To<LogitT>(0.f), dx, ctx());
kernel::NLLLossGrad( kernel::NLLLossGrad(
outer_dim, outer_dim,
inner_dim, inner_dim,
dX->dim(axis), dX->dim(axis),
ignore_index_, ignore_index_,
X.template data<LogitType, Context>(), X.template data<LogitT, Context>(),
Input(1).template data<TargetType, Context>(), Input(1).template data<TargetT, Context>(),
dx, dx,
mask, mask,
ctx()); ctx());
......
...@@ -28,7 +28,7 @@ class NLLLossOp final : public Operator<Context> { ...@@ -28,7 +28,7 @@ class NLLLossOp final : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void DoRunWithType(); void DoRunWithType();
protected: protected:
...@@ -47,7 +47,7 @@ class NLLLossGradientOp final : public Operator<Context> { ...@@ -47,7 +47,7 @@ class NLLLossGradientOp final : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void DoRunWithType(); void DoRunWithType();
protected: protected:
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void SigmoidFocalLossOp<Context>::DoRunWithType() { void SigmoidFocalLossOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(X); CANONICALIZE_AXIS_WITH_TENSOR(X);
...@@ -18,11 +18,11 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() { ...@@ -18,11 +18,11 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() {
<< "\nNumber of preds must match the number of targets."; << "\nNumber of preds must match the number of targets.";
auto scratches = ctx()->workspace()->template data<Context>({ auto scratches = ctx()->workspace()->template data<Context>({
X.size() * sizeof(LogitType), // loss X.size() * sizeof(LogitT), // loss
X.size() * sizeof(LogitType) + sizeof(LogitType), // mask X.size() * sizeof(LogitT) + sizeof(LogitT), // mask
}); });
auto* loss = static_cast<LogitType*>(scratches[0]); auto* loss = static_cast<LogitT*>(scratches[0]);
auto* mask = static_cast<LogitType*>(scratches[1]); auto* mask = static_cast<LogitT*>(scratches[1]);
kernel::SigmoidFocalLoss( kernel::SigmoidFocalLoss(
outer_dim, outer_dim,
...@@ -32,8 +32,8 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() { ...@@ -32,8 +32,8 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() {
neg_alpha_, neg_alpha_,
gamma_, gamma_,
negative_index_, negative_index_,
X.template data<LogitType, Context>(), X.template data<LogitT, Context>(),
Input(1).template data<TargetType, Context>(), Input(1).template data<TargetT, Context>(),
loss, loss,
mask, mask,
ctx()); ctx());
...@@ -42,7 +42,7 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() { ...@@ -42,7 +42,7 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() {
math::Copy( math::Copy(
X.count(), X.count(),
loss, loss,
Y->ReshapeLike(X)->template mutable_data<LogitType, Context>(), Y->ReshapeLike(X)->template mutable_data<LogitT, Context>(),
ctx()); ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
...@@ -59,7 +59,7 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() { ...@@ -59,7 +59,7 @@ void SigmoidFocalLossOp<Context>::DoRunWithType() {
normalizer, normalizer,
loss, loss,
mask, mask,
Y->Reshape({})->template mutable_data<LogitType, Context>(), Y->Reshape({})->template mutable_data<LogitT, Context>(),
ctx()); ctx());
} }
} }
...@@ -91,7 +91,7 @@ void SigmoidFocalLossOp<Context>::RunOnDevice() { ...@@ -91,7 +91,7 @@ void SigmoidFocalLossOp<Context>::RunOnDevice() {
} }
template <class Context> template <class Context>
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void SigmoidFocalLossGradientOp<Context>::DoRunWithType() { void SigmoidFocalLossGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &dY = Input(-1), *dX = Output(0); auto &X = Input(0), &dY = Input(-1), *dX = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(X); CANONICALIZE_AXIS_WITH_TENSOR(X);
...@@ -100,10 +100,10 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() { ...@@ -100,10 +100,10 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() {
auto outer_dim = dX->count(0, axis); auto outer_dim = dX->count(0, axis);
auto inner_dim = dX->count(axis + 1); auto inner_dim = dX->count(axis + 1);
auto* dy = dY.template data<LogitType, Context>(); auto* dy = dY.template data<LogitT, Context>();
auto* dx = dX->template mutable_data<LogitType, Context>(); auto* dx = dX->template mutable_data<LogitT, Context>();
auto* mask = ctx()->workspace()->template data<LogitType, Context>( auto* mask =
{dX->count() + 1})[0]; ctx()->workspace()->template data<LogitT, Context>({dX->count() + 1})[0];
kernel::SigmoidFocalLossGrad( kernel::SigmoidFocalLossGrad(
outer_dim, outer_dim,
...@@ -113,8 +113,8 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() { ...@@ -113,8 +113,8 @@ void SigmoidFocalLossGradientOp<Context>::DoRunWithType() {
neg_alpha_, neg_alpha_,
gamma_, gamma_,
negative_index_, negative_index_,
X.template data<LogitType, Context>(), X.template data<LogitT, Context>(),
Input(1).template data<TargetType, Context>(), Input(1).template data<TargetT, Context>(),
dx, dx,
mask, mask,
ctx()); ctx());
......
...@@ -48,7 +48,7 @@ class SigmoidFocalLossOp final : public Operator<Context> { ...@@ -48,7 +48,7 @@ class SigmoidFocalLossOp final : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void DoRunWithType(); void DoRunWithType();
protected: protected:
...@@ -88,7 +88,7 @@ class SigmoidFocalLossGradientOp final : public Operator<Context> { ...@@ -88,7 +88,7 @@ class SigmoidFocalLossGradientOp final : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void DoRunWithType(); void DoRunWithType();
protected: protected:
......
...@@ -45,7 +45,7 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> { ...@@ -45,7 +45,7 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void DoRunWithType(); void DoRunWithType();
protected: protected:
...@@ -81,7 +81,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> { ...@@ -81,7 +81,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void DoRunWithType(); void DoRunWithType();
protected: protected:
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() { void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(X); CANONICALIZE_AXIS_WITH_TENSOR(X);
...@@ -18,20 +18,20 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -18,20 +18,20 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
CHECK_EQ(num_preds, Input(1).count()) CHECK_EQ(num_preds, Input(1).count())
<< "\nNumber of preds must match the number of targets."; << "\nNumber of preds must match the number of targets.";
auto* X_prob = Buffer("prob")->ReshapeLike(X); auto* X_prob = Buffer("prob")->ReshapeLike(X);
auto* prob = X_prob->template mutable_data<LogitType, Context>(); auto* prob = X_prob->template mutable_data<LogitT, Context>();
auto scratches = ctx()->workspace()->template data<Context>({ auto scratches = ctx()->workspace()->template data<Context>({
(size_t)num_preds * sizeof(LogitType), // loss (size_t)num_preds * sizeof(LogitT), // loss
(size_t)num_preds * sizeof(LogitType) + sizeof(LogitType), // mask (size_t)num_preds * sizeof(LogitT) + sizeof(LogitT), // mask
}); });
auto* loss = static_cast<LogitType*>(scratches[0]); auto* loss = static_cast<LogitT*>(scratches[0]);
auto* mask = static_cast<LogitType*>(scratches[1]); auto* mask = static_cast<LogitT*>(scratches[1]);
kernel::Softmax( kernel::Softmax(
outer_dim, outer_dim,
inner_dim, inner_dim,
X.dim(axis), X.dim(axis),
X.template data<LogitType, Context>(), X.template data<LogitT, Context>(),
prob, prob,
ctx()); ctx());
...@@ -41,7 +41,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -41,7 +41,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
X.dim(axis), X.dim(axis),
ignore_index_, ignore_index_,
prob, prob,
Input(1).template data<TargetType, Context>(), Input(1).template data<TargetT, Context>(),
loss, loss,
mask, mask,
ctx()); ctx());
...@@ -52,7 +52,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -52,7 +52,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
math::Copy( math::Copy(
num_preds, num_preds,
loss, loss,
Y->Reshape(out_shape)->template mutable_data<LogitType, Context>(), Y->Reshape(out_shape)->template mutable_data<LogitT, Context>(),
ctx()); ctx());
} else { } else {
int64_t normalizer = 1; int64_t normalizer = 1;
...@@ -69,7 +69,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() { ...@@ -69,7 +69,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::DoRunWithType() {
normalizer, normalizer,
loss, loss,
mask, mask,
Y->Reshape({})->template mutable_data<LogitType, Context>(), Y->Reshape({})->template mutable_data<LogitT, Context>(),
ctx()); ctx());
} }
} }
...@@ -101,7 +101,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() { ...@@ -101,7 +101,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() {
} }
template <class Context> template <class Context>
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() { void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
auto &dY = Input(-1), *dX = Output(0); auto &dY = Input(-1), *dX = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(Input(0)); CANONICALIZE_AXIS_WITH_TENSOR(Input(0));
...@@ -110,11 +110,11 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() { ...@@ -110,11 +110,11 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
auto inner_dim = dX->count(axis + 1); auto inner_dim = dX->count(axis + 1);
auto num_preds = outer_dim * inner_dim; auto num_preds = outer_dim * inner_dim;
auto* prob = Buffer("prob")->template data<LogitType, Context>(); auto* prob = Buffer("prob")->template data<LogitT, Context>();
auto* dy = Input(-1).template data<LogitType, Context>(); auto* dy = Input(-1).template data<LogitT, Context>();
auto* dx = Output(0)->template mutable_data<LogitType, Context>(); auto* dx = Output(0)->template mutable_data<LogitT, Context>();
auto* mask = auto* mask =
ctx()->workspace()->template data<LogitType, Context>({num_preds + 1})[0]; ctx()->workspace()->template data<LogitT, Context>({num_preds + 1})[0];
math::Copy(dX->count(), prob, dx, ctx()); math::Copy(dX->count(), prob, dx, ctx());
...@@ -124,7 +124,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() { ...@@ -124,7 +124,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::DoRunWithType() {
dX->dim(axis), dX->dim(axis),
ignore_index_, ignore_index_,
prob, prob,
Input(1).template data<TargetType, Context>(), Input(1).template data<TargetT, Context>(),
dx, dx,
mask, mask,
ctx()); ctx());
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename Tx, typename Ty> template <typename T>
void MomentsOp<Context>::DoRunWithType() { void MomentsOp<Context>::DoRunWithType() {
using OutputT = typename math::utils::AccmulatorType<T>::type;
auto &X = Input(0), *Y1 = Output(0), *Y2 = Output(1); auto &X = Input(0), *Y1 = Output(0), *Y2 = Output(1);
// Determine the reduce axes // Determine the reduce axes
...@@ -35,13 +36,13 @@ void MomentsOp<Context>::DoRunWithType() { ...@@ -35,13 +36,13 @@ void MomentsOp<Context>::DoRunWithType() {
if (X.count() == 1) { if (X.count() == 1) {
math::Cast( math::Cast(
1, 1,
X.template data<Tx, Context>(), X.template data<T, Context>(),
Y1->Reshape(Y_shape)->template mutable_data<Ty, Context>(), Y1->Reshape(Y_shape)->template mutable_data<OutputT, Context>(),
ctx()); ctx());
math::Set( math::Set(
1, 1,
convert::To<Ty>(0.f), convert::To<OutputT>(0.f),
Y2->Reshape(Y_shape)->template mutable_data<Ty, Context>(), Y2->Reshape(Y_shape)->template mutable_data<OutputT, Context>(),
ctx()); ctx());
} else { } else {
kernel::Moments( kernel::Moments(
...@@ -49,35 +50,16 @@ void MomentsOp<Context>::DoRunWithType() { ...@@ -49,35 +50,16 @@ void MomentsOp<Context>::DoRunWithType() {
X_dims.data(), X_dims.data(),
reduce_axes.size(), reduce_axes.size(),
reduce_axes.data(), reduce_axes.data(),
X.template data<Tx, Context>(), X.template data<T, Context>(),
Y1->Reshape(Y_shape)->template mutable_data<Ty, Context>(), Y1->Reshape(Y_shape)->template mutable_data<OutputT, Context>(),
Y2->Reshape(Y_shape)->template mutable_data<Ty, Context>(), Y2->Reshape(Y_shape)->template mutable_data<OutputT, Context>(),
ctx()); ctx());
} }
} }
template <class Context> template <class Context>
void MomentsOp<Context>::RunOnDevice() { void MomentsOp<Context>::RunOnDevice() {
auto& X = Input(0); DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
if (X.template IsType<int8_t>()) {
DoRunWithType<int8_t, float>();
} else if (X.template IsType<uint8_t>()) {
DoRunWithType<uint8_t, float>();
} else if (X.template IsType<int>()) {
DoRunWithType<int, float>();
} else if (X.template IsType<int64_t>()) {
DoRunWithType<int64_t, float>();
} else if (X.template IsType<float16>()) {
DoRunWithType<float16, float>();
} else if (X.template IsType<float>()) {
DoRunWithType<float, float>();
} else if (X.template IsType<double>()) {
DoRunWithType<double, double>();
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(X.meta()),
{"int8", "uint8", "int32", "int64", "float16", "float32", "float64"});
}
} }
DEPLOY_CPU_OPERATOR(Moments); DEPLOY_CPU_OPERATOR(Moments);
......
...@@ -28,7 +28,7 @@ class MomentsOp final : public Operator<Context> { ...@@ -28,7 +28,7 @@ class MomentsOp final : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected: protected:
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void AccuracyOp<Context>::DoRunWithType() { void AccuracyOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(X); CANONICALIZE_AXIS_WITH_TENSOR(X);
...@@ -18,21 +18,21 @@ void AccuracyOp<Context>::DoRunWithType() { ...@@ -18,21 +18,21 @@ void AccuracyOp<Context>::DoRunWithType() {
int64_t acc = 0, count = 0; int64_t acc = 0, count = 0;
int64_t cols = X.count() / outer_dim; int64_t cols = X.count() / outer_dim;
auto* logit = X.template data<LogitType, CPUContext>(); auto* logit = X.template data<LogitT, CPUContext>();
auto* target = Input(1).template data<TargetType, CPUContext>(); auto* target = Input(1).template data<TargetT, CPUContext>();
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < inner_dim; ++j) { for (int j = 0; j < inner_dim; ++j) {
const int label = target[i * inner_dim + j]; const int label = target[i * inner_dim + j];
if (label == ignore_index_) continue; if (label == ignore_index_) continue;
vector<pair<LogitType, int>> vec; vector<pair<LogitT, int>> vec;
for (int k = 0; k < axis_dim; k++) for (int k = 0; k < axis_dim; k++)
vec.push_back(std::make_pair(logit[i * cols + k * inner_dim + j], k)); vec.push_back(std::make_pair(logit[i * cols + k * inner_dim + j], k));
std::partial_sort( std::partial_sort(
vec.begin(), vec.begin(),
vec.begin() + top_k_, vec.begin() + top_k_,
vec.end(), vec.end(),
std::greater<pair<LogitType, int>>()); std::greater<pair<LogitT, int>>());
for (int k = 0; k < top_k_; k++) { for (int k = 0; k < top_k_; k++) {
if (vec[k].second == label) { if (vec[k].second == label) {
acc++; acc++;
......
...@@ -28,7 +28,7 @@ class AccuracyOp final : public Operator<Context> { ...@@ -28,7 +28,7 @@ class AccuracyOp final : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename LogitType, typename TargetType> template <typename LogitT, typename TargetT>
void DoRunWithType(); void DoRunWithType();
protected: protected:
......
...@@ -8,11 +8,11 @@ namespace dragon { ...@@ -8,11 +8,11 @@ namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void BatchNormOp<Context>::TrainingImpl() { void BatchNormOp<Context>::TrainingImpl() {
using ParamType = typename math::utils::AccmulatorType<T>::type; using ParamT = typename math::utils::AccmulatorType<T>::type;
TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamT);
TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamT);
TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamT);
TENSOR_FILL_WITH_TYPE(Input(4), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(4), vec64_t({C_}), ParamT);
auto* X_mu = Buffer("X_mu")->Reshape({C_}); auto* X_mu = Buffer("X_mu")->Reshape({C_});
auto* X_rsig = Buffer("X_rsig")->Reshape({C_}); auto* X_rsig = Buffer("X_rsig")->Reshape({C_});
...@@ -20,11 +20,11 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -20,11 +20,11 @@ void BatchNormOp<Context>::TrainingImpl() {
auto* X_bias = Buffer("X_bias")->Reshape({C_}); auto* X_bias = Buffer("X_bias")->Reshape({C_});
auto* x = Input(0).template data<T, Context>(); auto* x = Input(0).template data<T, Context>();
auto* rm = Input(3).template mutable_data<ParamType, Context>(); auto* rm = Input(3).template mutable_data<ParamT, Context>();
auto* rv = Input(4).template mutable_data<ParamType, Context>(); auto* rv = Input(4).template mutable_data<ParamT, Context>();
auto* mu = X_mu->template mutable_data<ParamType, Context>(); auto* mu = X_mu->template mutable_data<ParamT, Context>();
auto* rsig = X_rsig->template mutable_data<ParamType, Context>(); auto* rsig = X_rsig->template mutable_data<ParamT, Context>();
auto* scale = X_scale->template mutable_data<ParamType, Context>(); auto* scale = X_scale->template mutable_data<ParamT, Context>();
// Compute moments // Compute moments
if (sync_stats_ > 0) { if (sync_stats_ > 0) {
...@@ -45,7 +45,7 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -45,7 +45,7 @@ void BatchNormOp<Context>::TrainingImpl() {
if (enable_nccl_) { if (enable_nccl_) {
#ifdef USE_NCCL #ifdef USE_NCCL
auto coll_comm = this->nccl_comm(); auto coll_comm = this->nccl_comm();
auto coll_dtype = this->template nccl_dtype<ParamType>(); auto coll_dtype = this->template nccl_dtype<ParamT>();
NCCL_CHECK(ncclAllReduce( NCCL_CHECK(ncclAllReduce(
(void*)mu, (void*)mu,
(void*)mu, (void*)mu,
...@@ -84,8 +84,9 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -84,8 +84,9 @@ void BatchNormOp<Context>::TrainingImpl() {
// Compute running statistics // Compute running statistics
if (is_recomputing_ == 0) { if (is_recomputing_ == 0) {
math::Axpby(C_, 1.f - momentum_, mu, momentum_, rm, ctx()); auto decay_factor = momentum();
math::Axpby(C_, 1.f - momentum_, rsig, momentum_, rv, ctx()); math::Axpby(C_, 1.f - decay_factor, mu, decay_factor, rm, ctx());
math::Axpby(C_, 1.f - decay_factor, rsig, decay_factor, rv, ctx());
} }
// Inverse stddev from variance // Inverse stddev from variance
...@@ -100,10 +101,10 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -100,10 +101,10 @@ void BatchNormOp<Context>::TrainingImpl() {
x, x,
mu, mu,
rsig, rsig,
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamT, Context>(), // gamma
Input(2).template data<ParamType, Context>(), // beta Input(2).template data<ParamT, Context>(), // beta
scale, scale,
X_bias->template mutable_data<ParamType, Context>(), X_bias->template mutable_data<ParamT, Context>(),
Output(0)->template mutable_data<T, Context>(), Output(0)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -111,17 +112,17 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -111,17 +112,17 @@ void BatchNormOp<Context>::TrainingImpl() {
template <class Context> template <class Context>
template <typename T> template <typename T>
void BatchNormOp<Context>::InferenceImpl() { void BatchNormOp<Context>::InferenceImpl() {
using ParamType = typename math::utils::AccmulatorType<T>::type; using ParamT = typename math::utils::AccmulatorType<T>::type;
TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamT);
TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamT);
TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamT);
TENSOR_FILL_WITH_TYPE(Input(4), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(4), vec64_t({C_}), ParamT);
auto* X_rsig = Buffer("X_rsig")->Reshape({C_}); auto* X_rsig = Buffer("X_rsig")->Reshape({C_});
auto* X_scale = Buffer("X_scale")->Reshape({C_}); auto* X_scale = Buffer("X_scale")->Reshape({C_});
auto* X_bias = Buffer("X_bias")->Reshape({C_}); auto* X_bias = Buffer("X_bias")->Reshape({C_});
auto* rv = Input(4).template data<ParamType, Context>(); auto* rv = Input(4).template data<ParamT, Context>();
auto* rsig = X_rsig->template mutable_data<ParamType, Context>(); auto* rsig = X_rsig->template mutable_data<ParamT, Context>();
// Inverse stddev from variance // Inverse stddev from variance
math::InvStd(C_, epsilon_, rv, rsig, ctx()); math::InvStd(C_, epsilon_, rv, rsig, ctx());
...@@ -133,12 +134,12 @@ void BatchNormOp<Context>::InferenceImpl() { ...@@ -133,12 +134,12 @@ void BatchNormOp<Context>::InferenceImpl() {
S_, S_,
data_format(), data_format(),
Input(0).template data<T, Context>(), Input(0).template data<T, Context>(),
Input(3).template data<ParamType, Context>(), Input(3).template data<ParamT, Context>(),
rsig, rsig,
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamT, Context>(), // gamma
Input(2).template data<ParamType, Context>(), // beta Input(2).template data<ParamT, Context>(), // beta
X_scale->template mutable_data<ParamType, Context>(), X_scale->template mutable_data<ParamT, Context>(),
X_bias->template mutable_data<ParamType, Context>(), X_bias->template mutable_data<ParamT, Context>(),
Output(0)->template mutable_data<T, Context>(), Output(0)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -159,17 +160,17 @@ void BatchNormOp<Context>::RunOnDevice() { ...@@ -159,17 +160,17 @@ void BatchNormOp<Context>::RunOnDevice() {
template <class Context> template <class Context>
template <typename T> template <typename T>
void BatchNormGradientOp<Context>::TrainingImpl() { void BatchNormGradientOp<Context>::TrainingImpl() {
using ParamType = typename math::utils::AccmulatorType<T>::type; using ParamT = typename math::utils::AccmulatorType<T>::type;
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig"); auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig");
auto* x = Input(0).template data<T, Context>(); auto* x = Input(0).template data<T, Context>();
auto* gamma = Input(1).template data<ParamType, Context>(); auto* gamma = Input(1).template data<ParamT, Context>();
auto* dy = Input(4).template data<T, Context>(); auto* dy = Input(4).template data<T, Context>();
auto* mu = X_mu->template data<ParamType, Context>(); auto* mu = X_mu->template data<ParamT, Context>();
auto* rsig = X_rsig->template data<ParamType, Context>(); auto* rsig = X_rsig->template data<ParamT, Context>();
auto* dgamma = dW->Reshape({C_})->template mutable_data<ParamType, Context>(); auto* dgamma = dW->Reshape({C_})->template mutable_data<ParamT, Context>();
auto* dbeta = dB->Reshape({C_})->template mutable_data<ParamType, Context>(); auto* dbeta = dB->Reshape({C_})->template mutable_data<ParamT, Context>();
// Gradient w.r.t. gamma and beta // Gradient w.r.t. gamma and beta
kernel::BatchNormWGrad( kernel::BatchNormWGrad(
...@@ -181,7 +182,7 @@ void BatchNormGradientOp<Context>::TrainingImpl() { ...@@ -181,7 +182,7 @@ void BatchNormGradientOp<Context>::TrainingImpl() {
if (enable_nccl_) { if (enable_nccl_) {
#ifdef USE_NCCL #ifdef USE_NCCL
auto coll_comm = this->nccl_comm(); auto coll_comm = this->nccl_comm();
auto coll_dtype = this->template nccl_dtype<ParamType>(); auto coll_dtype = this->template nccl_dtype<ParamT>();
NCCL_CHECK(ncclAllReduce( NCCL_CHECK(ncclAllReduce(
(void*)dgamma, (void*)dgamma,
(void*)dgamma, (void*)dgamma,
...@@ -231,18 +232,18 @@ void BatchNormGradientOp<Context>::TrainingImpl() { ...@@ -231,18 +232,18 @@ void BatchNormGradientOp<Context>::TrainingImpl() {
template <class Context> template <class Context>
template <typename T> template <typename T>
void BatchNormGradientOp<Context>::InferenceImpl() { void BatchNormGradientOp<Context>::InferenceImpl() {
using ParamType = typename math::utils::AccmulatorType<T>::type; using ParamT = typename math::utils::AccmulatorType<T>::type;
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto* X_scale = Buffer("X_scale")->Reshape({C_}); auto* X_scale = Buffer("X_scale")->Reshape({C_});
auto* rv = Input(3).template data<ParamType, Context>(); auto* rv = Input(3).template data<ParamT, Context>();
auto* rsig = X_scale->template mutable_data<ParamType, Context>(); auto* rsig = X_scale->template mutable_data<ParamT, Context>();
// Gradient w.r.t. gamma or beta if necessary // Gradient w.r.t. gamma or beta if necessary
ParamType *dgamma = nullptr, *dbeta = nullptr; ParamT *dgamma = nullptr, *dbeta = nullptr;
if (dW->has_name() || dB->has_name()) { if (dW->has_name() || dB->has_name()) {
dgamma = dW->Reshape({C_})->template mutable_data<ParamType, Context>(); dgamma = dW->Reshape({C_})->template mutable_data<ParamT, Context>();
dbeta = dB->Reshape({C_})->template mutable_data<ParamType, Context>(); dbeta = dB->Reshape({C_})->template mutable_data<ParamT, Context>();
} }
// Inverse stddev from variance // Inverse stddev from variance
...@@ -255,9 +256,9 @@ void BatchNormGradientOp<Context>::InferenceImpl() { ...@@ -255,9 +256,9 @@ void BatchNormGradientOp<Context>::InferenceImpl() {
S_, S_,
data_format(), data_format(),
Input(0).template data<T, Context>(), // x Input(0).template data<T, Context>(), // x
Input(2).template data<ParamType, Context>(), // rm Input(2).template data<ParamT, Context>(), // rm
rsig, rsig,
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamT, Context>(), // gamma
Input(4).template data<T, Context>(), // dy Input(4).template data<T, Context>(), // dy
dgamma, dgamma,
dbeta, dbeta,
......
...@@ -33,7 +33,6 @@ class BatchNormOpBase : public GenericOpBase<Context> { ...@@ -33,7 +33,6 @@ class BatchNormOpBase : public GenericOpBase<Context> {
public: public:
BatchNormOpBase(const OperatorDef& def, Workspace* ws) BatchNormOpBase(const OperatorDef& def, Workspace* ws)
: GenericOpBase<Context>(def, ws), : GenericOpBase<Context>(def, ws),
momentum_(OP_SINGLE_ARG(float, "momentum", 0.9f)),
epsilon_(OP_SINGLE_ARG(double, "epsilon", 1e-5)), epsilon_(OP_SINGLE_ARG(double, "epsilon", 1e-5)),
use_stats_(OP_SINGLE_ARG(int64_t, "use_stats", -1)), use_stats_(OP_SINGLE_ARG(int64_t, "use_stats", -1)),
sync_stats_(OP_SINGLE_ARG(int64_t, "comm", 0) > 0 ? 1 : 0) {} sync_stats_(OP_SINGLE_ARG(int64_t, "comm", 0) > 0 ? 1 : 0) {}
...@@ -57,7 +56,6 @@ class BatchNormOpBase : public GenericOpBase<Context> { ...@@ -57,7 +56,6 @@ class BatchNormOpBase : public GenericOpBase<Context> {
} }
protected: protected:
float momentum_;
double epsilon_; double epsilon_;
int64_t N_, C_, S_; int64_t N_, C_, S_;
int64_t use_stats_, sync_stats_; int64_t use_stats_, sync_stats_;
...@@ -68,7 +66,6 @@ class BatchNormOpBase : public GenericOpBase<Context> { ...@@ -68,7 +66,6 @@ class BatchNormOpBase : public GenericOpBase<Context> {
#define USE_BATCHNORM_FUNCTIONS \ #define USE_BATCHNORM_FUNCTIONS \
using BatchNormOpBase<Context>::DetermineBaseArguments; \ using BatchNormOpBase<Context>::DetermineBaseArguments; \
using BatchNormOpBase<Context>::momentum_; \
using BatchNormOpBase<Context>::epsilon_; \ using BatchNormOpBase<Context>::epsilon_; \
using BatchNormOpBase<Context>::use_stats_; \ using BatchNormOpBase<Context>::use_stats_; \
using BatchNormOpBase<Context>::sync_stats_; \ using BatchNormOpBase<Context>::sync_stats_; \
...@@ -82,7 +79,9 @@ template <class Context> ...@@ -82,7 +79,9 @@ template <class Context>
class BatchNormOp : public BatchNormOpBase<Context> { class BatchNormOp : public BatchNormOpBase<Context> {
public: public:
BatchNormOp(const OperatorDef& def, Workspace* ws) BatchNormOp(const OperatorDef& def, Workspace* ws)
: BatchNormOpBase<Context>(def, ws) {} : BatchNormOpBase<Context>(def, ws) {
INIT_OP_SINGLE_ARG_WITH_DESC(float, momentum, 0.9f);
}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_BATCHNORM_FUNCTIONS; USE_BATCHNORM_FUNCTIONS;
#ifdef USE_MPI #ifdef USE_MPI
...@@ -105,6 +104,8 @@ class BatchNormOp : public BatchNormOpBase<Context> { ...@@ -105,6 +104,8 @@ class BatchNormOp : public BatchNormOpBase<Context> {
InferenceImpl<T>(); InferenceImpl<T>();
} }
}; };
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, momentum);
}; };
template <class Context> template <class Context>
...@@ -146,11 +147,9 @@ class CuDNNBatchNormOp final : public BatchNormOpBase<Context> { ...@@ -146,11 +147,9 @@ class CuDNNBatchNormOp final : public BatchNormOpBase<Context> {
CuDNNCreateTensorDesc(&bn_desc_); CuDNNCreateTensorDesc(&bn_desc_);
CuDNNCreateTensorDesc(&input_desc_); CuDNNCreateTensorDesc(&input_desc_);
if (epsilon_ <= CUDNN_BN_MIN_EPSILON) { if (epsilon_ <= CUDNN_BN_MIN_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. \nSet it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
epsilon_ = CUDNN_BN_MIN_EPSILON; epsilon_ = CUDNN_BN_MIN_EPSILON;
} }
INIT_OP_SINGLE_ARG_WITH_DESC(float, momentum, 0.9f);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_BATCHNORM_FUNCTIONS; USE_BATCHNORM_FUNCTIONS;
...@@ -168,6 +167,7 @@ class CuDNNBatchNormOp final : public BatchNormOpBase<Context> { ...@@ -168,6 +167,7 @@ class CuDNNBatchNormOp final : public BatchNormOpBase<Context> {
protected: protected:
cudnnTensorDescriptor_t input_desc_, bn_desc_; cudnnTensorDescriptor_t input_desc_, bn_desc_;
cudnnBatchNormMode_t bn_mode_; cudnnBatchNormMode_t bn_mode_;
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, momentum);
}; };
template <class Context> template <class Context>
...@@ -178,9 +178,6 @@ class CuDNNBatchNormGradientOp final : public BatchNormGradientOp<Context> { ...@@ -178,9 +178,6 @@ class CuDNNBatchNormGradientOp final : public BatchNormGradientOp<Context> {
CuDNNCreateTensorDesc(&bn_desc_); CuDNNCreateTensorDesc(&bn_desc_);
CuDNNCreateTensorDesc(&input_desc_); CuDNNCreateTensorDesc(&input_desc_);
if (epsilon_ <= CUDNN_BN_MIN_EPSILON) { if (epsilon_ <= CUDNN_BN_MIN_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. \nSet it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
epsilon_ = CUDNN_BN_MIN_EPSILON; epsilon_ = CUDNN_BN_MIN_EPSILON;
} }
} }
...@@ -211,8 +208,12 @@ class CuDNNBatchNormGradientOp final : public BatchNormGradientOp<Context> { ...@@ -211,8 +208,12 @@ class CuDNNBatchNormGradientOp final : public BatchNormGradientOp<Context> {
cudnnBatchNormMode_t bn_mode_; cudnnBatchNormMode_t bn_mode_;
}; };
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, CuDNNBatchNormOp, momentum);
#endif // USE_CUDNN #endif // USE_CUDNN
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, BatchNormOp, momentum);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NORMALIZATION_BATCH_NORM_OP_H_ #endif // DRAGON_OPERATORS_NORMALIZATION_BATCH_NORM_OP_H_
...@@ -9,11 +9,11 @@ namespace dragon { ...@@ -9,11 +9,11 @@ namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void CuDNNBatchNormOp<Context>::DoRunWithType() { void CuDNNBatchNormOp<Context>::DoRunWithType() {
using ParamType = typename CuDNNType<T>::BNParamType; using ParamT = typename CuDNNType<T>::BNParamType;
TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamT);
TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamT);
TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamT);
TENSOR_FILL_WITH_TYPE(Input(4), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(4), vec64_t({C_}), ParamT);
// Determine the descriptors // Determine the descriptors
if (Input(0).ndim() == 2) { if (Input(0).ndim() == 2) {
...@@ -39,14 +39,14 @@ void CuDNNBatchNormOp<Context>::DoRunWithType() { ...@@ -39,14 +39,14 @@ void CuDNNBatchNormOp<Context>::DoRunWithType() {
input_desc_, input_desc_,
Output(0)->template mutable_data<T, Context>(), // y Output(0)->template mutable_data<T, Context>(), // y
bn_desc_, bn_desc_,
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamT, Context>(), // gamma
Input(2).template data<ParamType, Context>(), // beta Input(2).template data<ParamT, Context>(), // beta
is_recomputing_ > 0 ? 0.f : 1.f - this->momentum_, is_recomputing_ == 0 ? 1.f - momentum() : 0.f,
Input(3).template mutable_data<ParamType, Context>(), // rm Input(3).template mutable_data<ParamT, Context>(), // rm
Input(4).template mutable_data<ParamType, Context>(), // rv Input(4).template mutable_data<ParamT, Context>(), // rv
epsilon_, epsilon_,
X_mu->template mutable_data<ParamType, Context>(), // sm X_mu->template mutable_data<ParamT, Context>(), // sm
X_rsig->template mutable_data<ParamType, Context>())); // sv X_rsig->template mutable_data<ParamT, Context>())); // sv
} else { } else {
CUDNN_CHECK(cudnnBatchNormalizationForwardInference( CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
...@@ -58,10 +58,10 @@ void CuDNNBatchNormOp<Context>::DoRunWithType() { ...@@ -58,10 +58,10 @@ void CuDNNBatchNormOp<Context>::DoRunWithType() {
input_desc_, input_desc_,
Output(0)->template mutable_data<T, Context>(), // y Output(0)->template mutable_data<T, Context>(), // y
bn_desc_, bn_desc_,
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamT, Context>(), // gamma
Input(2).template data<ParamType, Context>(), // beta Input(2).template data<ParamT, Context>(), // beta
Input(3).template data<ParamType, Context>(), // rm Input(3).template data<ParamT, Context>(), // rm
Input(4).template data<ParamType, Context>(), // rv Input(4).template data<ParamT, Context>(), // rv
epsilon_)); epsilon_));
} }
} }
...@@ -82,7 +82,7 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() { ...@@ -82,7 +82,7 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() {
template <class Context> template <class Context>
template <typename T> template <typename T>
void CuDNNBatchNormGradientOp<Context>::TrainingImpl() { void CuDNNBatchNormGradientOp<Context>::TrainingImpl() {
using ParamType = typename CuDNNType<T>::BNParamType; using ParamT = typename CuDNNType<T>::BNParamType;
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig"); auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig");
...@@ -111,12 +111,12 @@ void CuDNNBatchNormGradientOp<Context>::TrainingImpl() { ...@@ -111,12 +111,12 @@ void CuDNNBatchNormGradientOp<Context>::TrainingImpl() {
input_desc_, input_desc_,
Output(0)->template mutable_data<T, Context>(), // dx Output(0)->template mutable_data<T, Context>(), // dx
bn_desc_, bn_desc_,
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamT, Context>(), // gamma
dW->Reshape({C_})->template mutable_data<ParamType, Context>(), // dw dW->Reshape({C_})->template mutable_data<ParamT, Context>(), // dw
dB->Reshape({C_})->template mutable_data<ParamType, Context>(), // db dB->Reshape({C_})->template mutable_data<ParamT, Context>(), // db
epsilon_, epsilon_,
X_mu->template data<ParamType, Context>(), // mu X_mu->template data<ParamT, Context>(), // mu
X_rsig->template data<ParamType, Context>())); // rsig X_rsig->template data<ParamT, Context>())); // rsig
} }
template <class Context> template <class Context>
......
...@@ -8,9 +8,9 @@ namespace dragon { ...@@ -8,9 +8,9 @@ namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void GroupNormOp<Context>::DoRunWithType() { void GroupNormOp<Context>::DoRunWithType() {
using ParamType = typename math::utils::AccmulatorType<T>::type; using ParamT = typename math::utils::AccmulatorType<T>::type;
TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamT);
TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamT);
auto* X_mu = Buffer("X_mu")->Reshape({N_, G_}); auto* X_mu = Buffer("X_mu")->Reshape({N_, G_});
auto* X_rsig = Buffer("X_rsig")->Reshape({N_, G_}); auto* X_rsig = Buffer("X_rsig")->Reshape({N_, G_});
...@@ -18,8 +18,8 @@ void GroupNormOp<Context>::DoRunWithType() { ...@@ -18,8 +18,8 @@ void GroupNormOp<Context>::DoRunWithType() {
auto* X_bias = Buffer("X_bias")->Reshape({N_, C_}); auto* X_bias = Buffer("X_bias")->Reshape({N_, C_});
auto* x = Input(0).template data<T, Context>(); auto* x = Input(0).template data<T, Context>();
auto* mu = X_mu->template mutable_data<ParamType, Context>(); auto* mu = X_mu->template mutable_data<ParamT, Context>();
auto* rsig = X_rsig->template mutable_data<ParamType, Context>(); auto* rsig = X_rsig->template mutable_data<ParamT, Context>();
// Compute the moments // Compute the moments
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
...@@ -45,10 +45,10 @@ void GroupNormOp<Context>::DoRunWithType() { ...@@ -45,10 +45,10 @@ void GroupNormOp<Context>::DoRunWithType() {
x, x,
mu, mu,
rsig, rsig,
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamT, Context>(), // gamma
Input(2).template data<ParamType, Context>(), // beta Input(2).template data<ParamT, Context>(), // beta
X_scale->template mutable_data<ParamType, Context>(), X_scale->template mutable_data<ParamT, Context>(),
X_bias->template mutable_data<ParamType, Context>(), X_bias->template mutable_data<ParamT, Context>(),
Output(0)->template mutable_data<T, Context>(), Output(0)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -63,7 +63,7 @@ void GroupNormOp<Context>::RunOnDevice() { ...@@ -63,7 +63,7 @@ void GroupNormOp<Context>::RunOnDevice() {
template <class Context> template <class Context>
template <typename T> template <typename T>
void GroupNormGradientOp<Context>::DoRunWithType() { void GroupNormGradientOp<Context>::DoRunWithType() {
using ParamType = typename math::utils::AccmulatorType<T>::type; using ParamT = typename math::utils::AccmulatorType<T>::type;
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig"); auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig");
...@@ -78,14 +78,14 @@ void GroupNormGradientOp<Context>::DoRunWithType() { ...@@ -78,14 +78,14 @@ void GroupNormGradientOp<Context>::DoRunWithType() {
S_, S_,
data_format(), data_format(),
Input(0).template data<T, Context>(), // x Input(0).template data<T, Context>(), // x
X_mu->template data<ParamType, Context>(), X_mu->template data<ParamT, Context>(),
X_rsig->template data<ParamType, Context>(), X_rsig->template data<ParamT, Context>(),
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamT, Context>(), // gamma
Input(2).template data<T, Context>(), // dy Input(2).template data<T, Context>(), // dy
X_scale->template mutable_data<ParamType, Context>(), X_scale->template mutable_data<ParamT, Context>(),
X_bias->template mutable_data<ParamType, Context>(), X_bias->template mutable_data<ParamT, Context>(),
dW->Reshape({C_})->template mutable_data<ParamType, Context>(), dW->Reshape({C_})->template mutable_data<ParamT, Context>(),
dB->Reshape({C_})->template mutable_data<ParamType, Context>(), dB->Reshape({C_})->template mutable_data<ParamT, Context>(),
dX->template mutable_data<T, Context>(), dX->template mutable_data<T, Context>(),
ctx()); ctx());
} }
......
...@@ -58,7 +58,7 @@ def dropout(inputs, ratio=0.5, **kwargs): ...@@ -58,7 +58,7 @@ def dropout(inputs, ratio=0.5, **kwargs):
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate() \ .instantiate() \
.apply([inputs], ratio, inplace=inplace) .apply([inputs], args['ratio'], inplace=inplace)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
...@@ -103,7 +103,7 @@ def drop_block2d(inputs, ratio=0.1, block_size=7, data_format='NCHW', **kwargs): ...@@ -103,7 +103,7 @@ def drop_block2d(inputs, ratio=0.1, block_size=7, data_format='NCHW', **kwargs):
.instantiate( .instantiate(
block_size=block_size, block_size=block_size,
data_format=data_format, data_format=data_format,
).apply([inputs], ratio, inplace=inplace) ).apply([inputs], args['ratio'], inplace=inplace)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
...@@ -137,7 +137,7 @@ def drop_path(inputs, ratio=0.2, **kwargs): ...@@ -137,7 +137,7 @@ def drop_path(inputs, ratio=0.2, **kwargs):
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate() \ .instantiate() \
.apply([inputs], ratio, inplace=inplace) .apply([inputs], args['ratio'], inplace=inplace)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
......
...@@ -205,9 +205,8 @@ def broadcast_to(inputs, shape, **kwargs): ...@@ -205,9 +205,8 @@ def broadcast_to(inputs, shape, **kwargs):
op_lib = array_ops_lib.Expand op_lib = array_ops_lib.Expand
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(ndim=len(args['dims'])) \
ndim=len(args['dims']), .apply([inputs], args['dims'])
).apply([inputs], args['dims'])
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
...@@ -1163,6 +1162,7 @@ def pad(inputs, pads, mode='constant', value=0, **kwargs): ...@@ -1163,6 +1162,7 @@ def pad(inputs, pads, mode='constant', value=0, **kwargs):
return op_lib.blend(**args) return op_lib.blend(**args)
@ArgHelper.desc('limit', as_target=True)
def permutation(limit, dtype='int64', **kwargs): def permutation(limit, dtype='int64', **kwargs):
r"""Return a tensor with value in the permuted range. r"""Return a tensor with value in the permuted range.
...@@ -1174,7 +1174,7 @@ def permutation(limit, dtype='int64', **kwargs): ...@@ -1174,7 +1174,7 @@ def permutation(limit, dtype='int64', **kwargs):
Parameters Parameters
---------- ----------
limit: number limit: Union[number, dragon.Tensor]
The end of interval. The end of interval.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
...@@ -1192,7 +1192,7 @@ def permutation(limit, dtype='int64', **kwargs): ...@@ -1192,7 +1192,7 @@ def permutation(limit, dtype='int64', **kwargs):
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate(dtype=dtype) \ .instantiate(dtype=dtype) \
.apply(limit, trainable=trainable) .apply(args['limit'], trainable=trainable)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
......
...@@ -49,10 +49,11 @@ def assign(inputs, starts=None, sizes=None, **kwargs): ...@@ -49,10 +49,11 @@ def assign(inputs, starts=None, sizes=None, **kwargs):
inputs[1] = ops.scalar_to_tensor(inputs[1], inputs[0].dtype) inputs[1] = ops.scalar_to_tensor(inputs[1], inputs[0].dtype)
op_lib = control_flow_ops_lib.Assign op_lib = control_flow_ops_lib.Assign
if context.executing_eagerly(): if context.executing_eagerly():
starts = args['starts'] if starts is not None else [0]
sizes = args['sizes'] if sizes is not None else [-1]
return op_lib \ return op_lib \
.instantiate( .instantiate(ndim=len(starts)) \
ndim=len(starts) if starts is not None else 0, .apply(inputs, starts, sizes, inplace=inplace)
).apply(inputs, starts, sizes, inplace=inplace)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
......
...@@ -23,6 +23,7 @@ from dragon.core.util import nest ...@@ -23,6 +23,7 @@ from dragon.core.util import nest
@OpSchema.num_inputs(5) @OpSchema.num_inputs(5)
@ArgHelper.desc('momentum', as_target=False)
def batch_norm( def batch_norm(
inputs, inputs,
axis=-1, axis=-1,
...@@ -40,7 +41,8 @@ def batch_norm( ...@@ -40,7 +41,8 @@ def batch_norm(
The running average of statistics are calculated as: The running average of statistics are calculated as:
.. math:: x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}} .. math:: x_{\text{running}} = \text{momentum} * x_{\text{running}} +
(1 - \text{momentum}) * x_{\text{batch}}
Parameters Parameters
---------- ----------
...@@ -48,8 +50,8 @@ def batch_norm( ...@@ -48,8 +50,8 @@ def batch_norm(
The tensor ``x``, ``gamma``, ``beta``, ``mean`` and ``var``. The tensor ``x``, ``gamma``, ``beta``, ``mean`` and ``var``.
axis : int, optional, default=-1 axis : int, optional, default=-1
The channel axis. The channel axis.
momentum : float, optional, default=0.9 momentum : Union[float, dragon.Tensor], optional
The momentum for running average. The value to :math:`\text{momentum}`.
epsilon : float, optional, default=1e-5 epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`. The value to :math:`\epsilon`.
use_stats : int, optional, default=-1 use_stats : int, optional, default=-1
...@@ -62,16 +64,15 @@ def batch_norm( ...@@ -62,16 +64,15 @@ def batch_norm(
""" """
args = ArgHelper.parse(locals()) args = ArgHelper.parse(locals())
args['momentum'], args['epsilon'] = float(momentum), float(epsilon) args['epsilon'] = float(epsilon)
op_lib = normalization_ops_lib.BatchNorm op_lib = normalization_ops_lib.BatchNorm
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(
axis=axis, axis=axis,
momentum=args['momentum'],
epsilon=args['epsilon'], epsilon=args['epsilon'],
use_stats=use_stats, use_stats=use_stats,
).apply(inputs) ).apply(inputs, args['momentum'])
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
...@@ -304,6 +305,7 @@ def local_response_norm( ...@@ -304,6 +305,7 @@ def local_response_norm(
@OpSchema.num_inputs(5) @OpSchema.num_inputs(5)
@ArgHelper.desc('momentum', as_target=False)
def sync_batch_norm( def sync_batch_norm(
inputs, inputs,
axis=-1, axis=-1,
...@@ -322,7 +324,8 @@ def sync_batch_norm( ...@@ -322,7 +324,8 @@ def sync_batch_norm(
The running average of statistics are calculated as: The running average of statistics are calculated as:
.. math:: x_{\text{running}} = \text{momentum} * x_{\text{running}} + (1 - \text{momentum}) * x_{\text{stat}} .. math:: x_{\text{running}} = \text{momentum} * x_{\text{running}} +
(1 - \text{momentum}) * x_{\text{batch}}
Parameters Parameters
---------- ----------
...@@ -330,8 +333,8 @@ def sync_batch_norm( ...@@ -330,8 +333,8 @@ def sync_batch_norm(
The tensor ``x``, ``gamma``, ``beta``, ``mean`` and ``var``. The tensor ``x``, ``gamma``, ``beta``, ``mean`` and ``var``.
axis : int, optional, default=-1 axis : int, optional, default=-1
The channel axis. The channel axis.
momentum : float, optional, default=0.9 momentum : Union[float, dragon.Tensor], optional
The momentum for average. The value to :math:`\text{momentum}`.
epsilon : float, optional, default=1e-5 epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`. The value to :math:`\epsilon`.
use_stats : int, optional, default=-1 use_stats : int, optional, default=-1
...@@ -346,7 +349,7 @@ def sync_batch_norm( ...@@ -346,7 +349,7 @@ def sync_batch_norm(
""" """
args = ArgHelper.parse(locals()) args = ArgHelper.parse(locals())
args['momentum'], args['epsilon'] = float(momentum), float(epsilon) args['epsilon'] = float(epsilon)
if process_group is None: if process_group is None:
process_group = distributed.get_group() process_group = distributed.get_group()
if process_group is None: if process_group is None:
...@@ -356,11 +359,10 @@ def sync_batch_norm( ...@@ -356,11 +359,10 @@ def sync_batch_norm(
return op_lib \ return op_lib \
.instantiate( .instantiate(
axis=axis, axis=axis,
momentum=args['momentum'],
epsilon=args['epsilon'], epsilon=args['epsilon'],
use_stats=use_stats, use_stats=use_stats,
process_group=process_group, process_group=process_group,
).apply(inputs) ).apply(inputs, args['momentum'])
else: else:
args.update(process_group.arguments) args.update(process_group.arguments)
return op_lib.blend(**args) return op_lib.blend(**args)
...@@ -23,7 +23,6 @@ class BatchNorm(Operator): ...@@ -23,7 +23,6 @@ class BatchNorm(Operator):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(BatchNorm, self).__init__(key, dev, **kwargs) super(BatchNorm, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1) self.axis = kwargs.get('axis', -1)
self.momentum = kwargs.get('momentum', 0.9)
self.epsilon = kwargs.get('epsilon', 1e-5) self.epsilon = kwargs.get('epsilon', 1e-5)
self.use_stats = kwargs.get('use_stats', 0) self.use_stats = kwargs.get('use_stats', 0)
if self.use_stats not in (0, 1): if self.use_stats not in (0, 1):
...@@ -34,14 +33,21 @@ class BatchNorm(Operator): ...@@ -34,14 +33,21 @@ class BatchNorm(Operator):
'op_type': 'BatchNorm', 'op_type': 'BatchNorm',
'arguments': { 'arguments': {
'axis': self.axis, 'axis': self.axis,
'momentum': self.momentum,
'epsilon': self.epsilon, 'epsilon': self.epsilon,
'use_stats': self.use_stats, 'use_stats': self.use_stats,
'momentum_desc': '${HANDLE}/momentum',
}, },
} }
def forward(self, inputs): def setup(self, ws, handle, momentum):
return self.dispatch(inputs, [self.alloc()]) self.feed_arg(ws, '%s/momentum' % handle, momentum, 'float32')
def forward(self, inputs, momentum):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda ws, handle:
self.setup(ws, handle, momentum),
)
class GroupNorm(Operator): class GroupNorm(Operator):
......
...@@ -118,6 +118,7 @@ class ArgHelper(object): ...@@ -118,6 +118,7 @@ class ArgHelper(object):
if 'extra_inputs' not in arguments: if 'extra_inputs' not in arguments:
arguments['extra_inputs'] = [] arguments['extra_inputs'] = []
arguments['extra_inputs'] += [arg] arguments['extra_inputs'] += [arg]
if name in arguments:
arguments.pop(name) arguments.pop(name)
arguments[name + '_desc'] = arg.id arguments[name + '_desc'] = arg.id
return arguments return arguments
...@@ -141,5 +142,6 @@ class ArgHelper(object): ...@@ -141,5 +142,6 @@ class ArgHelper(object):
descs.append(ele.id) descs.append(ele.id)
else: else:
descs.append(Tensor.from_value(ele, dtype, 'DescConst').id) descs.append(Tensor.from_value(ele, dtype, 'DescConst').id)
if name in arguments:
arguments.pop(name) arguments.pop(name)
arguments[name + '_descs'] = descs arguments[name + '_descs'] = descs
...@@ -176,9 +176,12 @@ def conv2d_transpose( ...@@ -176,9 +176,12 @@ def conv2d_transpose(
raise ValueError('Unsupported padding algorithm: %s' % padding) raise ValueError('Unsupported padding algorithm: %s' % padding)
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: %s' % data_format) raise ValueError('Unsupported data format: %s' % data_format)
if 'SAME' in padding and output_shape is None:
raise ValueError('Excepted <output_shape> for same padding.')
if output_shape is not None and 'SAME' not in padding: if output_shape is not None and 'SAME' not in padding:
args['padding'] = 'SAME' args['padding'] = 'SAME'
for key in ('kernel_shape', 'strides', 'pads', 'dilations'): for key in ('kernel_shape', 'strides', 'pads', 'dilations'):
if key in args and args[key] is not None:
if key == 'pads': if key == 'pads':
args[key] = _normalize_pads(args[key], 2) args[key] = _normalize_pads(args[key], 2)
else: else:
......
...@@ -26,7 +26,7 @@ def dropout_exporter(op_def, context): ...@@ -26,7 +26,7 @@ def dropout_exporter(op_def, context):
drop_ratio = arg.f drop_ratio = arg.f
elif arg.name == 'prob_desc': elif arg.name == 'prob_desc':
drop_ratio = helper.fetch_argument(op_def, arg, context.ws) drop_ratio = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'ratio', drop_ratio) helper.add_attribute(node, 'ratio', float(drop_ratio))
return node, const_tensors return node, const_tensors
......
...@@ -26,6 +26,9 @@ def batch_norm_exporter(op_def, context): ...@@ -26,6 +26,9 @@ def batch_norm_exporter(op_def, context):
helper.add_attribute(node, 'epsilon', arg.f) helper.add_attribute(node, 'epsilon', arg.f)
elif arg.name == 'momentum': elif arg.name == 'momentum':
helper.add_attribute(node, 'momentum', arg.f) helper.add_attribute(node, 'momentum', arg.f)
elif arg.name == 'momentum_desc':
momentum = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'momentum', float(momentum))
# Weight, bias, running mean and running variance # Weight, bias, running mean and running variance
const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]] const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
......
...@@ -123,23 +123,51 @@ CONVERSIONS_DECL float16 To<float16, half>(half val) { ...@@ -123,23 +123,51 @@ CONVERSIONS_DECL float16 To<float16, half>(half val) {
} }
template <> template <>
CONVERSIONS_DECL half To<half, float>(float val) { CONVERSIONS_DECL half To<half, float16>(float16 val) {
return __float2half(val); return __half_raw{val.x};
} }
template <> template <>
CONVERSIONS_DECL half To<half, float16>(float16 val) { CONVERSIONS_DECL half2 To<half2, float16>(float16 val) {
return __half_raw{val.x}; return half2(__half2_raw{val.x, val.x});
} }
template <> template <>
CONVERSIONS_DECL half2 To<half2, float>(float val) { CONVERSIONS_DECL half To<half, float>(float val) {
return __float2half2_rn(val); #if CUDA_VERSION_MIN(9, 2, 0)
return __float2half(val);
#else
#if defined(__CUDA_ARCH__)
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short*>(&(var)))
__half ret;
asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(ret)) : "f"(val));
return ret;
#undef __HALF_TO_US
#else
return To<half>(To<float16>(val));
#endif
#endif
} }
template <> template <>
CONVERSIONS_DECL half2 To<half2, float16>(float16 val) { CONVERSIONS_DECL half2 To<half2, float>(float val) {
return half2(__half2_raw{val.x, val.x}); #if CUDA_VERSION_MIN(9, 2, 0)
return __float2half2_rn(val);
#else
#if defined(__CUDA_ARCH__)
#define __HALF2_TO_UI(var) *(reinterpret_cast<unsigned int*>(&(var)))
__half2 ret;
asm("{.reg .f16 low;\n"
" cvt.rn.f16.f32 low, %1;\n"
" mov.b32 %0, {low,low};}\n"
: "=r"(__HALF2_TO_UI(ret))
: "f"(val));
return ret;
#undef __HALF2_TO_UI
#else
return To<half2>(To<float16>(val));
#endif
#endif
} }
#endif // USE_CUDA #endif // USE_CUDA
......
...@@ -162,23 +162,17 @@ __global__ void _InvStd(const int n, const T eps, const T* x, T* y) { ...@@ -162,23 +162,17 @@ __global__ void _InvStd(const int n, const T eps, const T* x, T* y) {
} }
} }
template <> __global__ void _InvStd(const int n, const float eps, const half* x, half* y) {
__global__ void
_InvStd<half>(const int n, const half eps, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, n) { CUDA_1D_KERNEL_LOOP(i, n) {
#if __CUDA_ARCH__ >= 530 y[i] = __float2half(rsqrt(__half2float(x[i]) + eps));
y[i] = hrsqrt(__hadd(x[i], eps));
#endif
} }
} }
template <>
__global__ void __global__ void
_InvStd<half2>(const int n, const half2 eps, const half2* x, half2* y) { _InvStd(const int n, const float eps, const half2* x, half2* y) {
CUDA_1D_KERNEL_LOOP(i, n) { CUDA_1D_KERNEL_LOOP(i, n) {
#if __CUDA_ARCH__ >= 530 const float2 val = __half22float2(x[i]);
y[i] = h2rsqrt(__hadd2(x[i], eps)); y[i] = __floats2half2_rn(rsqrt(val.x + eps), rsqrt(val.y + eps));
#endif
} }
} }
...@@ -206,19 +200,15 @@ __global__ void _Powx(const int n, const T exponent, const T* x, T* y) { ...@@ -206,19 +200,15 @@ __global__ void _Powx(const int n, const T exponent, const T* x, T* y) {
__global__ void __global__ void
_Powx(const int n, const float exponent, const half* x, half* y) { _Powx(const int n, const float exponent, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, n) { CUDA_1D_KERNEL_LOOP(i, n) {
#if __CUDA_ARCH__ >= 530
y[i] = __float2half(pow(__half2float(x[i]), exponent)); y[i] = __float2half(pow(__half2float(x[i]), exponent));
#endif
} }
} }
__global__ void __global__ void
_Powx(const int n, const float exponent, const half2* x, half2* y) { _Powx(const int n, const float exponent, const half2* x, half2* y) {
CUDA_1D_KERNEL_LOOP(i, n) { CUDA_1D_KERNEL_LOOP(i, n) {
#if __CUDA_ARCH__ >= 530
const float2 val = __half22float2(x[i]); const float2 val = __half22float2(x[i]);
y[i] = __floats2half2_rn(pow(val.x, exponent), pow(val.y, exponent)); y[i] = __floats2half2_rn(pow(val.x, exponent), pow(val.y, exponent));
#endif
} }
} }
...@@ -269,20 +259,16 @@ __global__ void _Square(const int n, const T* x, T* y) { ...@@ -269,20 +259,16 @@ __global__ void _Square(const int n, const T* x, T* y) {
template <typename T> template <typename T>
__global__ void _NotZero(const int nthreads, const T* x, bool* y) { __global__ void _NotZero(const int nthreads, const T* x, bool* y) {
const T kZero = T(0);
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = x[i] != kZero ? true : false; y[i] = x[i] != T(0) ? true : false;
} }
} }
template <> template <>
__global__ void _NotZero<half>(const int nthreads, const half* x, bool* y) { __global__ void _NotZero<half>(const int nthreads, const half* x, bool* y) {
#if __CUDA_ARCH__ >= 530
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = __hne(x[i], kZero) ? true : false; y[i] = __half2float(x[i]) != 0.f ? true : false;
} }
#endif
} }
template <typename T> template <typename T>
...@@ -560,15 +546,12 @@ DRAGON_API void InvStd<float16, CUDAContext>( ...@@ -560,15 +546,12 @@ DRAGON_API void InvStd<float16, CUDAContext>(
if ((n & 1) == 0) { if ((n & 1) == 0) {
_InvStd<<<CUDA_BLOCKS(n >> 1), CUDA_THREADS, 0, ctx->cuda_stream()>>>( _InvStd<<<CUDA_BLOCKS(n >> 1), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n >> 1, n >> 1,
convert::To<half2>(eps), eps,
reinterpret_cast<const half2*>(x), reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y)); reinterpret_cast<half2*>(y));
} else { } else {
_InvStd<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( _InvStd<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n, n, eps, reinterpret_cast<const half*>(x), reinterpret_cast<half*>(y));
convert::To<half>(eps),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
} }
} }
......
...@@ -26,7 +26,7 @@ namespace math { ...@@ -26,7 +26,7 @@ namespace math {
template <typename T> template <typename T>
struct MaxFunctor { struct MaxFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ T operator()(const T& lhs, const T& rhs) const { inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs < rhs ? rhs : lhs; return lhs < rhs ? rhs : lhs;
} }
...@@ -39,7 +39,7 @@ struct MaxFunctor { ...@@ -39,7 +39,7 @@ struct MaxFunctor {
template <> template <>
struct MaxFunctor<float16> { struct MaxFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ float16 inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const { operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -62,7 +62,7 @@ struct MaxFunctor<float16> { ...@@ -62,7 +62,7 @@ struct MaxFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct MaxFunctor<half> { struct MaxFunctor<half> {
inline __device__ half operator()(const half& lhs, const half& rhs) const { inline __device__ half operator()(const half& lhs, const half& rhs) const {
...@@ -87,7 +87,7 @@ struct MaxFunctor<half2> { ...@@ -87,7 +87,7 @@ struct MaxFunctor<half2> {
template <typename T> template <typename T>
struct MinFunctor { struct MinFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ T operator()(const T& lhs, const T& rhs) const { inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs < rhs ? lhs : rhs; return lhs < rhs ? lhs : rhs;
} }
...@@ -100,7 +100,7 @@ struct MinFunctor { ...@@ -100,7 +100,7 @@ struct MinFunctor {
template <> template <>
struct MinFunctor<float16> { struct MinFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ float16 inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const { operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -123,7 +123,7 @@ struct MinFunctor<float16> { ...@@ -123,7 +123,7 @@ struct MinFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct MinFunctor<half> { struct MinFunctor<half> {
inline __device__ half operator()(const half& lhs, const half& rhs) const { inline __device__ half operator()(const half& lhs, const half& rhs) const {
...@@ -148,7 +148,7 @@ struct MinFunctor<half2> { ...@@ -148,7 +148,7 @@ struct MinFunctor<half2> {
template <typename T> template <typename T>
struct PlusFunctor { struct PlusFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ T operator()(const T& lhs, const T& rhs) const { inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs + rhs; return lhs + rhs;
} }
...@@ -161,7 +161,7 @@ struct PlusFunctor { ...@@ -161,7 +161,7 @@ struct PlusFunctor {
template <> template <>
struct PlusFunctor<float16> { struct PlusFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ float16 inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const { operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -183,7 +183,7 @@ struct PlusFunctor<float16> { ...@@ -183,7 +183,7 @@ struct PlusFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct PlusFunctor<half> { struct PlusFunctor<half> {
inline __device__ half operator()(const half& lhs, const half& rhs) const { inline __device__ half operator()(const half& lhs, const half& rhs) const {
...@@ -211,7 +211,7 @@ struct PlusFunctor<half2> { ...@@ -211,7 +211,7 @@ struct PlusFunctor<half2> {
template <typename T> template <typename T>
struct MinusFunctor { struct MinusFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ T operator()(const T& lhs, const T& rhs) const { inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs - rhs; return lhs - rhs;
} }
...@@ -224,7 +224,7 @@ struct MinusFunctor { ...@@ -224,7 +224,7 @@ struct MinusFunctor {
template <> template <>
struct MinusFunctor<float16> { struct MinusFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ float16 inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const { operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -246,7 +246,7 @@ struct MinusFunctor<float16> { ...@@ -246,7 +246,7 @@ struct MinusFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct MinusFunctor<half> { struct MinusFunctor<half> {
inline __device__ half operator()(const half& lhs, const half& rhs) const { inline __device__ half operator()(const half& lhs, const half& rhs) const {
...@@ -274,7 +274,7 @@ struct MinusFunctor<half2> { ...@@ -274,7 +274,7 @@ struct MinusFunctor<half2> {
template <typename T> template <typename T>
struct MultipliesFunctor { struct MultipliesFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ T operator()(const T& lhs, const T& rhs) const { inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs * rhs; return lhs * rhs;
} }
...@@ -287,7 +287,7 @@ struct MultipliesFunctor { ...@@ -287,7 +287,7 @@ struct MultipliesFunctor {
template <> template <>
struct MultipliesFunctor<float16> { struct MultipliesFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ float16 inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const { operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -309,7 +309,7 @@ struct MultipliesFunctor<float16> { ...@@ -309,7 +309,7 @@ struct MultipliesFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct MultipliesFunctor<half> { struct MultipliesFunctor<half> {
inline __device__ half operator()(const half& lhs, const half& rhs) const { inline __device__ half operator()(const half& lhs, const half& rhs) const {
...@@ -337,7 +337,7 @@ struct MultipliesFunctor<half2> { ...@@ -337,7 +337,7 @@ struct MultipliesFunctor<half2> {
template <typename T> template <typename T>
struct DividesFunctor { struct DividesFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ T operator()(const T& lhs, const T& rhs) const { inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs / rhs; return lhs / rhs;
} }
...@@ -350,7 +350,7 @@ struct DividesFunctor { ...@@ -350,7 +350,7 @@ struct DividesFunctor {
template <> template <>
struct DividesFunctor<float16> { struct DividesFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ float16 inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const { operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -372,7 +372,7 @@ struct DividesFunctor<float16> { ...@@ -372,7 +372,7 @@ struct DividesFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct DividesFunctor<half> { struct DividesFunctor<half> {
inline __device__ half operator()(const half& lhs, const half& rhs) const { inline __device__ half operator()(const half& lhs, const half& rhs) const {
...@@ -396,7 +396,7 @@ struct DividesFunctor<half2> { ...@@ -396,7 +396,7 @@ struct DividesFunctor<half2> {
template <typename T> template <typename T>
struct PowFunctor { struct PowFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ T operator()(const T& lhs, const T& rhs) const { inline __device__ T operator()(const T& lhs, const T& rhs) const {
return pow(lhs, rhs); return pow(lhs, rhs);
} }
...@@ -409,7 +409,7 @@ struct PowFunctor { ...@@ -409,7 +409,7 @@ struct PowFunctor {
template <> template <>
struct PowFunctor<float16> { struct PowFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ float16 inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const { operator()(const float16& lhs, const float16& rhs) const {
half ret = __float2half( half ret = __float2half(
...@@ -425,7 +425,7 @@ struct PowFunctor<float16> { ...@@ -425,7 +425,7 @@ struct PowFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct PowFunctor<half> { struct PowFunctor<half> {
inline __device__ half operator()(const half& lhs, const half& rhs) const { inline __device__ half operator()(const half& lhs, const half& rhs) const {
...@@ -449,7 +449,7 @@ struct PowFunctor<half2> { ...@@ -449,7 +449,7 @@ struct PowFunctor<half2> {
template <typename T> template <typename T>
struct EqualFunctor { struct EqualFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const T& lhs, const T& rhs) const { inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs == rhs; return lhs == rhs;
} }
...@@ -462,7 +462,7 @@ struct EqualFunctor { ...@@ -462,7 +462,7 @@ struct EqualFunctor {
template <> template <>
struct EqualFunctor<float16> { struct EqualFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const float16& lhs, const float16& rhs) inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const { const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -481,7 +481,7 @@ struct EqualFunctor<float16> { ...@@ -481,7 +481,7 @@ struct EqualFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct EqualFunctor<half> { struct EqualFunctor<half> {
inline __device__ bool operator()(const half& lhs, const half& rhs) const { inline __device__ bool operator()(const half& lhs, const half& rhs) const {
...@@ -496,7 +496,7 @@ struct EqualFunctor<half> { ...@@ -496,7 +496,7 @@ struct EqualFunctor<half> {
template <typename T> template <typename T>
struct NotEqualFunctor { struct NotEqualFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const T& lhs, const T& rhs) const { inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs != rhs; return lhs != rhs;
} }
...@@ -509,7 +509,7 @@ struct NotEqualFunctor { ...@@ -509,7 +509,7 @@ struct NotEqualFunctor {
template <> template <>
struct NotEqualFunctor<float16> { struct NotEqualFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const float16& lhs, const float16& rhs) inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const { const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -528,7 +528,7 @@ struct NotEqualFunctor<float16> { ...@@ -528,7 +528,7 @@ struct NotEqualFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct NotEqualFunctor<half> { struct NotEqualFunctor<half> {
inline __device__ bool operator()(const half& lhs, const half& rhs) const { inline __device__ bool operator()(const half& lhs, const half& rhs) const {
...@@ -543,7 +543,7 @@ struct NotEqualFunctor<half> { ...@@ -543,7 +543,7 @@ struct NotEqualFunctor<half> {
template <typename T> template <typename T>
struct GreaterFunctor { struct GreaterFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const T& lhs, const T& rhs) const { inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs > rhs; return lhs > rhs;
} }
...@@ -556,7 +556,7 @@ struct GreaterFunctor { ...@@ -556,7 +556,7 @@ struct GreaterFunctor {
template <> template <>
struct GreaterFunctor<float16> { struct GreaterFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const float16& lhs, const float16& rhs) inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const { const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -575,7 +575,7 @@ struct GreaterFunctor<float16> { ...@@ -575,7 +575,7 @@ struct GreaterFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct GreaterFunctor<half> { struct GreaterFunctor<half> {
inline __device__ bool operator()(const half& lhs, const half& rhs) const { inline __device__ bool operator()(const half& lhs, const half& rhs) const {
...@@ -590,7 +590,7 @@ struct GreaterFunctor<half> { ...@@ -590,7 +590,7 @@ struct GreaterFunctor<half> {
template <typename T> template <typename T>
struct LessFunctor { struct LessFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const T& lhs, const T& rhs) const { inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs < rhs; return lhs < rhs;
} }
...@@ -603,7 +603,7 @@ struct LessFunctor { ...@@ -603,7 +603,7 @@ struct LessFunctor {
template <> template <>
struct LessFunctor<float16> { struct LessFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const float16& lhs, const float16& rhs) inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const { const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -622,7 +622,7 @@ struct LessFunctor<float16> { ...@@ -622,7 +622,7 @@ struct LessFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct LessFunctor<half> { struct LessFunctor<half> {
inline __device__ bool operator()(const half& lhs, const half& rhs) const { inline __device__ bool operator()(const half& lhs, const half& rhs) const {
...@@ -637,7 +637,7 @@ struct LessFunctor<half> { ...@@ -637,7 +637,7 @@ struct LessFunctor<half> {
template <typename T> template <typename T>
struct GreaterEqualFunctor { struct GreaterEqualFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const T& lhs, const T& rhs) const { inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs >= rhs; return lhs >= rhs;
} }
...@@ -650,7 +650,7 @@ struct GreaterEqualFunctor { ...@@ -650,7 +650,7 @@ struct GreaterEqualFunctor {
template <> template <>
struct GreaterEqualFunctor<float16> { struct GreaterEqualFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const float16& lhs, const float16& rhs) inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const { const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -669,7 +669,7 @@ struct GreaterEqualFunctor<float16> { ...@@ -669,7 +669,7 @@ struct GreaterEqualFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct GreaterEqualFunctor<half> { struct GreaterEqualFunctor<half> {
inline __device__ bool operator()(const half& lhs, const half& rhs) const { inline __device__ bool operator()(const half& lhs, const half& rhs) const {
...@@ -684,7 +684,7 @@ struct GreaterEqualFunctor<half> { ...@@ -684,7 +684,7 @@ struct GreaterEqualFunctor<half> {
template <typename T> template <typename T>
struct LessEqualFunctor { struct LessEqualFunctor {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const T& lhs, const T& rhs) const { inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs <= rhs; return lhs <= rhs;
} }
...@@ -697,7 +697,7 @@ struct LessEqualFunctor { ...@@ -697,7 +697,7 @@ struct LessEqualFunctor {
template <> template <>
struct LessEqualFunctor<float16> { struct LessEqualFunctor<float16> {
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
inline __device__ bool operator()(const float16& lhs, const float16& rhs) inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const { const {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
...@@ -716,7 +716,7 @@ struct LessEqualFunctor<float16> { ...@@ -716,7 +716,7 @@ struct LessEqualFunctor<float16> {
#endif #endif
}; };
#if defined(__CUDACC__) #if defined(__CUDA_ARCH__)
template <> template <>
struct LessEqualFunctor<half> { struct LessEqualFunctor<half> {
inline __device__ bool operator()(const half& lhs, const half& rhs) const { inline __device__ bool operator()(const half& lhs, const half& rhs) const {
......
...@@ -239,8 +239,8 @@ void ReduceSum<float16, CUDAContext>( ...@@ -239,8 +239,8 @@ void ReduceSum<float16, CUDAContext>(
num_axes, \ num_axes, \
axes, \ axes, \
Reducer<AccT>(), \ Reducer<AccT>(), \
AccT(kInit), \ convert::To<AccT>(kInit), \
AccT(scale), \ convert::To<AccT>(scale), \
x, \ x, \
y, \ y, \
ctx); \ ctx); \
......
...@@ -301,16 +301,16 @@ void ChannelAffine( ...@@ -301,16 +301,16 @@ void ChannelAffine(
/* array.channel_normalize */ /* array.channel_normalize */
template <typename Tx, typename Ty, class Context> template <typename InputT, typename OutputT, class Context>
void ChannelNormalize( void ChannelNormalize(
const int axis, const int axis,
const int num_dims, const int num_dims,
const int64_t* x_strides, const int64_t* x_strides,
const int64_t* y_dims, const int64_t* y_dims,
const Tx* x, const InputT* x,
const float* mean, const float* mean,
const float* std, const float* std,
Ty* y, OutputT* y,
Context* ctx); Context* ctx);
/* array.channel_shuffle */ /* array.channel_shuffle */
...@@ -648,28 +648,28 @@ void BroadcastLossGrad( ...@@ -648,28 +648,28 @@ void BroadcastLossGrad(
/* loss.nll_loss */ /* loss.nll_loss */
template <typename LogitType, typename TargetType, class Context> template <typename LogitT, typename TargetT, class Context>
void NLLLoss( void NLLLoss(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* loss, LogitT* loss,
LogitType* mask, LogitT* mask,
Context* ctx); Context* ctx);
template <typename LogitType, typename TargetType, class Context> template <typename LogitT, typename TargetT, class Context>
void NLLLossGrad( void NLLLossGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* dlogit, LogitT* dlogit,
LogitType* mask, LogitT* mask,
Context* ctx); Context* ctx);
/* loss.sigmoid_ce_loss */ /* loss.sigmoid_ce_loss */
...@@ -694,7 +694,7 @@ void SigmoidCrossEntropyGrad( ...@@ -694,7 +694,7 @@ void SigmoidCrossEntropyGrad(
/* loss.sigmoid_focal_loss */ /* loss.sigmoid_focal_loss */
template <typename LogitType, typename TargetType, class Context> template <typename LogitT, typename TargetT, class Context>
void SigmoidFocalLoss( void SigmoidFocalLoss(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
...@@ -703,13 +703,13 @@ void SigmoidFocalLoss( ...@@ -703,13 +703,13 @@ void SigmoidFocalLoss(
const float neg_alpha, const float neg_alpha,
const float gamma, const float gamma,
const int negative_index, const int negative_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* loss, LogitT* loss,
LogitType* mask, LogitT* mask,
Context* ctx); Context* ctx);
template <typename LogitType, typename TargetType, class Context> template <typename LogitT, typename TargetT, class Context>
void SigmoidFocalLossGrad( void SigmoidFocalLossGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
...@@ -718,10 +718,10 @@ void SigmoidFocalLossGrad( ...@@ -718,10 +718,10 @@ void SigmoidFocalLossGrad(
const float neg_alpha, const float neg_alpha,
const float gamma, const float gamma,
const int negative_index, const int negative_index,
const LogitType* logit, const LogitT* logit,
const TargetType* target, const TargetT* target,
LogitType* dlogit, LogitT* dlogit,
LogitType* mask, LogitT* mask,
Context* ctx); Context* ctx);
/* loss.smooth_l1_loss */ /* loss.smooth_l1_loss */
...@@ -754,28 +754,28 @@ void SoftmaxCrossEntropy( ...@@ -754,28 +754,28 @@ void SoftmaxCrossEntropy(
/* loss.sparse_softmax_cross_entropy */ /* loss.sparse_softmax_cross_entropy */
template <typename LogitType, typename TargetType, class Context> template <typename LogitT, typename TargetT, class Context>
void SparseSoftmaxCrossEntropy( void SparseSoftmaxCrossEntropy(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* prob, const LogitT* prob,
const TargetType* target, const TargetT* target,
LogitType* loss, LogitT* loss,
LogitType* mask, LogitT* mask,
Context* ctx); Context* ctx);
template <typename LogitType, typename TargetType, class Context> template <typename LogitT, typename TargetT, class Context>
void SparseSoftmaxCrossEntropyGrad( void SparseSoftmaxCrossEntropyGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int ignore_index, const int ignore_index,
const LogitType* prob, const LogitT* prob,
const TargetType* target, const TargetT* target,
LogitType* dx, LogitT* dx,
LogitType* mask, LogitT* mask,
Context* ctx); Context* ctx);
/* math.abs */ /* math.abs */
......
...@@ -55,7 +55,7 @@ class BatchNormalization(Layer): ...@@ -55,7 +55,7 @@ class BatchNormalization(Layer):
axis : int, optional, default=-1 axis : int, optional, default=-1
The channel axis. The channel axis.
momentum : float, optional, default=0.99 momentum : float, optional, default=0.99
The momentum of moving average. The decay factor of running average.
epsilon : float, optional, default=1e-3 epsilon : float, optional, default=1e-3
The epsilon value. The epsilon value.
center : bool, optional, default=True center : bool, optional, default=True
......
...@@ -41,8 +41,8 @@ def batch_normalization( ...@@ -41,8 +41,8 @@ def batch_normalization(
The moving average of stats are calculated as: The moving average of stats are calculated as:
.. math:: .. math:: x_{\text{running}} = \text{momentum} * x_{\text{running}} +
x_{moving} \leftarrow momentum * x_{moving} + (1 - momentum) * x_{stat} (1 - \text{momentum}) * x_{\text{batch}}
Parameters Parameters
---------- ----------
...@@ -58,10 +58,10 @@ def batch_normalization( ...@@ -58,10 +58,10 @@ def batch_normalization(
The :math:`\gamma` tensor. The :math:`\gamma` tensor.
axis : int, optional, default=-1 axis : int, optional, default=-1
The channel axis. The channel axis.
momentum : float, optional, default=0.9 momentum : Union[float, dragon.Tensor], optional
The momentum of moving average. The value to :math:`\text{momentum}`.
variance_epsilon : float, optional, default=1e-5 variance_epsilon : float, optional, default=1e-5
The value of epsilon. The value to :math:`\epsilon`.
trainable : bool, optional, default=False trainable : bool, optional, default=False
The optional training flag. The optional training flag.
name : str, optional name : str, optional
......
...@@ -50,7 +50,7 @@ class BatchNorm(layer.Layer): ...@@ -50,7 +50,7 @@ class BatchNorm(layer.Layer):
Parameters Parameters
---------- ----------
decay : float, optional, default=0.9 decay : float, optional, default=0.9
The decay factor for moving average. The decay factor of running average.
epsilon : float, optional, default=1e-5 epsilon : float, optional, default=1e-5
The epsilon. The epsilon.
act : callable, optional act : callable, optional
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Code generator for Runtime API.""" """Code generator for Runtime API."""
from __future__ import absolute_import from __future__ import absolute_import
......
...@@ -89,7 +89,8 @@ def batch_norm( ...@@ -89,7 +89,8 @@ def batch_norm(
The moving average of stats are calculated as: The moving average of stats are calculated as:
.. math:: x_{moving} \leftarrow (1 - momentum) * x_{moving} + momentum * x_{stat} .. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} +
\text{momentum} * x_{\text{batch}}
Parameters Parameters
---------- ----------
...@@ -124,9 +125,9 @@ def batch_norm( ...@@ -124,9 +125,9 @@ def batch_norm(
.instantiate( .instantiate(
input.device, input.device,
training=training, training=training,
momentum=momentum,
epsilon=eps, epsilon=eps,
).apply(input, running_mean, running_var, weight, bias) ).apply(input, running_mean, running_var,
weight, bias, momentum)
def binary_cross_entropy_with_logits( def binary_cross_entropy_with_logits(
...@@ -1598,7 +1599,7 @@ def sync_batch_norm( ...@@ -1598,7 +1599,7 @@ def sync_batch_norm(
The moving average of stats are calculated as: The moving average of stats are calculated as:
.. math:: .. math::
x_{moving} \leftarrow (1 - momentum) * x_{moving} + momentum * x_{stat} x_{moving} \leftarrow (1 - momentum) * x_{moving} + momentum * x_{\text{batch}}
Additionally, you can specify ``process_group`` to perform synchronization. Additionally, you can specify ``process_group`` to perform synchronization.
......
...@@ -111,24 +111,31 @@ class BatchNorm(function.Function): ...@@ -111,24 +111,31 @@ class BatchNorm(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(BatchNorm, self).__init__(key, dev, **kwargs) super(BatchNorm, self).__init__(key, dev, **kwargs)
self.momentum = kwargs.get('momentum', 0.1)
self.epsilon = kwargs.get('epsilon', 1e-5) self.epsilon = kwargs.get('epsilon', 1e-5)
self.training = kwargs.get('training', False) self.training = kwargs.get('training', False)
self.track_stats = kwargs.get('track_stats', True)
def setup(self, ws, handle, momentum):
self.feed_arg(ws, '{}/momentum'.format(handle), 1.0 - momentum, 'float32')
def attributes(self): def attributes(self):
return { return {
'op_type': 'BatchNorm', 'op_type': 'BatchNorm',
'arguments': { 'arguments': {
'axis': 1, 'axis': 1,
'momentum': 1. - self.momentum,
'epsilon': self.epsilon, 'epsilon': self.epsilon,
'use_stats': int(not self.training), 'use_stats': int(not self.training),
'momentum_desc': '${HANDLE}/momentum',
} }
} }
def forward(self, input, running_mean, running_var, weight, bias): def forward(self, input, running_mean, running_var, weight, bias, momentum):
inputs = [input, weight, bias, running_mean, running_var] inputs = [input, weight, bias, running_mean, running_var]
return self.dispatch(inputs, [self.alloc()]) return self.dispatch(
inputs, [self.alloc()],
callback=lambda ws, handle:
self.setup(ws, handle, momentum),
)
class Conv2d(_ConvNd): class Conv2d(_ConvNd):
......
...@@ -25,6 +25,8 @@ from dragon.vm.torch.core.tensor import Tensor ...@@ -25,6 +25,8 @@ from dragon.vm.torch.core.tensor import Tensor
class _BatchNorm(Module): class _BatchNorm(Module):
"""BatchNorm base module."""
def __init__( def __init__(
self, self,
num_features, num_features,
...@@ -45,20 +47,26 @@ class _BatchNorm(Module): ...@@ -45,20 +47,26 @@ class _BatchNorm(Module):
else: else:
self.register_buffer('weight', init_funcs.ones(num_features)) self.register_buffer('weight', init_funcs.ones(num_features))
self.register_buffer('bias', init_funcs.zeros(num_features)) self.register_buffer('bias', init_funcs.zeros(num_features))
if self.track_running_stats:
self.num_batches_tracked = 0
else:
self.num_batches_tracked = None
self.register_buffer('running_mean', init_funcs.zeros(num_features)) self.register_buffer('running_mean', init_funcs.zeros(num_features))
self.register_buffer('running_var', init_funcs.ones(num_features)) self.register_buffer('running_var', init_funcs.ones(num_features))
self.inputs = [self.running_mean, self.running_var, self.weight, self.bias] self.inputs = [self.running_mean, self.running_var, self.weight, self.bias]
self.reset_parameters() self.reset_parameters()
def reset_parameters(self):
if self.affine:
self.weight.data.one_()
self.bias.data.zero_()
def reset_running_stats(self): def reset_running_stats(self):
if self.track_running_stats: if self.track_running_stats:
self.running_mean.zero_() self.running_mean.zero_()
self.running_var.fill_(1) self.running_var.fill_(1)
self.num_batches_tracked = 0
def reset_parameters(self):
self.reset_running_stats()
if self.affine:
self.weight.data.one_()
self.bias.data.zero_()
def extra_repr(self): def extra_repr(self):
return '{num_features}, ' \ return '{num_features}, ' \
...@@ -72,7 +80,7 @@ class _BatchNorm(Module): ...@@ -72,7 +80,7 @@ class _BatchNorm(Module):
return F.batch_norm( return F.batch_norm(
input, *self.inputs, input, *self.inputs,
training=self.training, training=self.training,
momentum=self.momentum, momentum=self._get_momentum(),
eps=self.eps eps=self.eps
) )
...@@ -82,6 +90,19 @@ class _BatchNorm(Module): ...@@ -82,6 +90,19 @@ class _BatchNorm(Module):
return self # Float32 parameters are required. return self # Float32 parameters are required.
return super(_BatchNorm, self)._apply(fn) return super(_BatchNorm, self)._apply(fn)
def _get_momentum(self):
"""Return the current momentum value."""
momentum = 0.0 if self.momentum is None else self.momentum
if self.track_running_stats:
if self.training:
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None:
momentum = 1.0 / float(self.num_batches_tracked)
else:
momentum = 0.0
return momentum
class BatchNorm1d(_BatchNorm): class BatchNorm1d(_BatchNorm):
r"""Apply the batch normalization over 2d input. r"""Apply the batch normalization over 2d input.
...@@ -93,7 +114,8 @@ class BatchNorm1d(_BatchNorm): ...@@ -93,7 +114,8 @@ class BatchNorm1d(_BatchNorm):
The running average of statistics are calculated as: The running average of statistics are calculated as:
.. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}} .. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} +
\text{momentum} * x_{\text{batch}}
See Also See Also
-------- --------
...@@ -109,16 +131,16 @@ class BatchNorm1d(_BatchNorm): ...@@ -109,16 +131,16 @@ class BatchNorm1d(_BatchNorm):
affine=True, affine=True,
track_running_stats=True, track_running_stats=True,
): ):
"""Create a ``BatchNorm1d`` module. r"""Create a ``BatchNorm1d`` module.
Parameters Parameters
---------- ----------
num_features : int num_features : int
The number of channels. The number of channels.
eps : float, optional, default=1e-5 eps : float, optional, default=1e-5
The epsilon value. The value to :math:`\epsilon`.
momentum : float, optional, default=0.1 momentum : float, optional, default=0.1
The momentum of moving average. The value to :math:`\text{momentum}`.
affine : bool, optional, default=True affine : bool, optional, default=True
**True** to apply a affine transformation. **True** to apply a affine transformation.
track_running_stats : bool, optional, default=True track_running_stats : bool, optional, default=True
...@@ -142,7 +164,8 @@ class BatchNorm2d(_BatchNorm): ...@@ -142,7 +164,8 @@ class BatchNorm2d(_BatchNorm):
The running average of statistics are calculated as: The running average of statistics are calculated as:
.. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}} .. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} +
\text{momentum} * x_{\text{batch}}
See Also See Also
-------- --------
...@@ -158,16 +181,16 @@ class BatchNorm2d(_BatchNorm): ...@@ -158,16 +181,16 @@ class BatchNorm2d(_BatchNorm):
affine=True, affine=True,
track_running_stats=True, track_running_stats=True,
): ):
"""Create a ``BatchNorm2d`` module. r"""Create a ``BatchNorm2d`` module.
Parameters Parameters
---------- ----------
num_features : int num_features : int
The number of channels. The number of channels.
eps : float, optional, default=1e-5 eps : float, optional, default=1e-5
The epsilon value. The value to :math:`\epsilon`.
momentum : float, optional, default=0.1 momentum : float, optional, default=0.1
The momentum of moving average. The value to :math:`\text{momentum}`.
affine : bool, optional, default=True affine : bool, optional, default=True
**True** to apply a affine transformation. **True** to apply a affine transformation.
track_running_stats : bool, optional, default=True track_running_stats : bool, optional, default=True
...@@ -191,7 +214,8 @@ class BatchNorm3d(_BatchNorm): ...@@ -191,7 +214,8 @@ class BatchNorm3d(_BatchNorm):
The running average of statistics are calculated as: The running average of statistics are calculated as:
.. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}} .. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} +
\text{momentum} * x_{\text{batch}}
See Also See Also
-------- --------
...@@ -207,16 +231,16 @@ class BatchNorm3d(_BatchNorm): ...@@ -207,16 +231,16 @@ class BatchNorm3d(_BatchNorm):
affine=True, affine=True,
track_running_stats=True, track_running_stats=True,
): ):
"""Create a ``BatchNorm3d`` module. r"""Create a ``BatchNorm3d`` module.
Parameters Parameters
---------- ----------
num_features : int num_features : int
The number of channels. The number of channels.
eps : float, optional, default=1e-5 eps : float, optional, default=1e-5
The epsilon value. The value to :math:`\epsilon`.
momentum : float, optional, default=0.1 momentum : float, optional, default=0.1
The momentum of moving average. The value to :math:`\text{momentum}`.
affine : bool, optional, default=True affine : bool, optional, default=True
**True** to apply a affine transformation. **True** to apply a affine transformation.
track_running_stats : bool, optional, default=True track_running_stats : bool, optional, default=True
...@@ -240,7 +264,8 @@ class SyncBatchNorm(_BatchNorm): ...@@ -240,7 +264,8 @@ class SyncBatchNorm(_BatchNorm):
The running average of statistics are calculated as: The running average of statistics are calculated as:
.. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} + \text{momentum} * x_{\text{stat}} .. math:: x_{\text{running}} = (1 - \text{momentum}) * x_{\text{running}} +
\text{momentum} * x_{\text{batch}}
Additionally, specify ``process_group`` to perform synchronization. Additionally, specify ``process_group`` to perform synchronization.
...@@ -261,16 +286,16 @@ class SyncBatchNorm(_BatchNorm): ...@@ -261,16 +286,16 @@ class SyncBatchNorm(_BatchNorm):
track_running_stats=True, track_running_stats=True,
process_group=None, process_group=None,
): ):
"""Create a ``SyncBatchNorm`` module. r"""Create a ``SyncBatchNorm`` module.
Parameters Parameters
---------- ----------
num_features : int num_features : int
The number of channels. The number of channels.
eps : float, optional, default=1e-5 eps : float, optional, default=1e-5
The epsilon value. The value to :math:`\epsilon`.
momentum : float, optional, default=0.1 momentum : float, optional, default=0.1
The momentum of moving average. The value to :math:`\text{momentum}`.
affine : bool, optional, default=True affine : bool, optional, default=True
**True** to apply a affine transformation. **True** to apply a affine transformation.
track_running_stats : bool, optional, default=True track_running_stats : bool, optional, default=True
...@@ -292,7 +317,7 @@ class SyncBatchNorm(_BatchNorm): ...@@ -292,7 +317,7 @@ class SyncBatchNorm(_BatchNorm):
return F.sync_batch_norm( return F.sync_batch_norm(
input, *self.inputs, input, *self.inputs,
training=self.training, training=self.training,
momentum=self.momentum, momentum=self._get_momentum(),
eps=self.eps, eps=self.eps,
process_group=self.process_group process_group=self.process_group
) )
...@@ -300,6 +325,6 @@ class SyncBatchNorm(_BatchNorm): ...@@ -300,6 +325,6 @@ class SyncBatchNorm(_BatchNorm):
return F.batch_norm( return F.batch_norm(
input, *self.inputs, input, *self.inputs,
training=self.training, training=self.training,
momentum=self.momentum, momentum=self._get_momentum(),
eps=self.eps eps=self.eps
) )
...@@ -61,7 +61,7 @@ class AffineChannel(Module): ...@@ -61,7 +61,7 @@ class AffineChannel(Module):
fix_bias=False, fix_bias=False,
inplace=False, inplace=False,
): ):
"""Create an ``Affine`` module. """Create an ``AffineChannel`` module.
Parameters Parameters
---------- ----------
...@@ -141,7 +141,7 @@ class GroupNorm(Module): ...@@ -141,7 +141,7 @@ class GroupNorm(Module):
eps=1e-5, eps=1e-5,
affine=True, affine=True,
): ):
"""Create a ``GroupNorm`` module. r"""Create a ``GroupNorm`` module.
Parameters Parameters
---------- ----------
...@@ -150,7 +150,7 @@ class GroupNorm(Module): ...@@ -150,7 +150,7 @@ class GroupNorm(Module):
num_channels : int num_channels : int
The number of channels. The number of channels.
eps : float, optional, default=1e-5 eps : float, optional, default=1e-5
The epsilon value. The value to :math:`\epsilon`.
affine : bool, optional, default=True affine : bool, optional, default=True
**True** to apply a affine transformation. **True** to apply a affine transformation.
...@@ -228,11 +228,11 @@ class LocalResponseNorm(Module): ...@@ -228,11 +228,11 @@ class LocalResponseNorm(Module):
size : int, required size : int, required
The number of neighbouring channels to sum over. The number of neighbouring channels to sum over.
alpha : float, optional, default=0.0001 alpha : float, optional, default=0.0001
The scale value :math:`\alpha`. The value to :math:`\alpha`.
beta : float, optional, default=0.75 beta : float, optional, default=0.75
The exponent value :math:`\beta`. The value to :math:`\beta`.
k : float, optional, default=1. k : float, optional, default=1.
The bias constant :math:`k`. The value to :math:`k`.
""" """
super(LocalResponseNorm, self).__init__() super(LocalResponseNorm, self).__init__()
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!