Commit 746f2cbb by Ting PAN

Add FP16 support for DepthwiseConv2d && SyncBN Operator

Summary:
This commit adds pseudo FP16 kernels with FP32 conversions
for the DepthwiseConv2d and SyncBN operators.
1 parent d56e67d1
Showing with 1962 additions and 2407 deletions
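The "pseudo FP16" approach keeps tensors in FP16 storage but performs the arithmetic in FP32, converting on load and store; the kernels in this diff do so with __half2float/__float2half and convert::To. A minimal CUDA kernel following that pattern (an illustrative sketch with hypothetical names, not code from this commit) could look like:

#include <cuda_fp16.h>

__global__ void
_ScaleHalf(const int nthreads, const float alpha, const half* x, half* y) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < nthreads) {
    // Load FP16, compute in FP32, store FP16.
    y[i] = __float2half(alpha * __half2float(x[i]));
  }
}

// Example launch: _ScaleHalf<<<(n + 255) / 256, 256, 0, stream>>>(n, 2.f, x, y);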
......@@ -81,6 +81,9 @@ dragon
`function(...) <dragon/function.html>`_
: Compile a function and return an executable.
`get_num_threads(...) <dragon/get_num_threads.html>`_
: Return the number of threads for cpu parallelism.
`get_workspace(...) <dragon/get_workspace.html>`_
: Return the current default workspace.
......@@ -138,6 +141,9 @@ dragon
`reshape(...) <dragon/reshape.html>`_
: Change the dimensions of input.
`set_num_threads(...) <dragon/set_num_threads.html>`_
: Set the number of threads for cpu parallelism.
`shape(...) <dragon/shape.html>`_
: Return the shape of input.
......@@ -204,6 +210,7 @@ dragon
dragon/fill
dragon/flatten
dragon/function
dragon/get_num_threads
dragon/get_workspace
dragon/gradients
dragon/graph_mode
......@@ -223,6 +230,7 @@ dragon
dragon/repeat
dragon/reset_workspace
dragon/reshape
dragon/set_num_threads
dragon/shape
dragon/slice
dragon/sort
......
get_num_threads
===============

.. autofunction:: dragon.get_num_threads

.. raw:: html

  <style>
  h1:before {
    content: "dragon.";
    color: #103d3e;
  }
  </style>

set_num_threads
===============

.. autofunction:: dragon.set_num_threads

.. raw:: html

  <style>
  h1:before {
    content: "dragon.";
    color: #103d3e;
  }
  </style>
......@@ -18,7 +18,7 @@
#include "dragon/core/operator_schema.h"
#include "dragon/core/registry.h"
#include "dragon/core/tensor.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
namespace dragon {
......
......@@ -19,6 +19,11 @@
#include "dragon/core/typeid.h"
#ifndef HFLT_MAX
#define HFLT_MAX 65504.F
#define HFLT_MIN 6.10e-5F
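// Largest finite FP16 value and (approximately) the smallest positive normal
// FP16 value, respectively.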
#endif
namespace dragon {
typedef std::vector<int> vec32_t;
......
......@@ -34,7 +34,7 @@ void _DropBlock2dNCHW(
}
} // Share the mask between channels
}
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
......@@ -65,7 +65,7 @@ void _DropBlock2dNHWC(
}
} // Share the mask between channels
}
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
......
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -86,7 +85,7 @@ void DropPath<float16, CUDAContext>(
const auto nthreads = rows * cols; \
const auto thresh = 1.f - (1.f / scale); \
_DropPath<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, thresh, cast::to<T>(scale), x, mask, y); \
nthreads, cols, thresh, convert::To<T>(scale), x, mask, y); \
}
DEFINE_KERNEL_LAUNCHER(float);
......
#include "dragon/utils/cast.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -72,7 +71,7 @@ void _Dropout<float16>(
const uint8_t* mask, \
T* y, \
CPUContext* ctx) { \
_ApplyMask(count, cast::to<T>(scale), x, mask, y); \
_ApplyMask(count, convert::To<T>(scale), x, mask, y); \
} \
template <> \
void Dropout<T, CPUContext>( \
......@@ -84,7 +83,8 @@ void _Dropout<float16>(
T* y, \
uint32_t* r, \
CPUContext* ctx) { \
_Dropout(count, cast::to<T>(ratio), cast::to<T>(scale), x, mask, y, ctx); \
_Dropout( \
count, convert::To<T>(ratio), convert::To<T>(scale), x, mask, y, ctx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -113,7 +112,7 @@ void Dropout<float16, CUDAContext>(
T* y, \
CUDAContext* ctx) { \
_ApplyMask<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, cast::to<T>(scale), x, mask, y); \
count, convert::To<T>(scale), x, mask, y); \
} \
template <> \
void Dropout<T, CUDAContext>( \
......@@ -128,7 +127,7 @@ void Dropout<float16, CUDAContext>(
math::Random(count, r, ctx); \
auto threshold = static_cast<uint32_t>(UINT_MAX * ratio); \
_Dropout<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, threshold, cast::to<T>(scale), x, r, mask, y); \
count, threshold, convert::To<T>(scale), x, r, mask, y); \
}
DEFINE_KERNEL_LAUNCHER(float);
......
#include "dragon/utils/cast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -50,7 +50,7 @@ void _EluGrad<float16>(
template <> \
void Elu<T, CPUContext>( \
const int count, const float alpha, const T* x, T* y, CPUContext* ctx) { \
_Elu(count, cast::to<T>(alpha), x, y); \
_Elu(count, convert::To<T>(alpha), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
......@@ -62,7 +62,7 @@ void _EluGrad<float16>(
const T* y, \
T* dx, \
CPUContext* ctx) { \
_EluGrad(count, cast::to<T>(alpha), dy, y, dx); \
_EluGrad(count, convert::To<T>(alpha), dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
......
#include "dragon/utils/cast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -65,7 +65,7 @@ void _HardSigmoidGrad<float16>(
const T* x, \
T* y, \
CPUContext* ctx) { \
_HardSigmoid(count, cast::to<T>(alpha), cast::to<T>(beta), x, y); \
_HardSigmoid(count, convert::To<T>(alpha), convert::To<T>(beta), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
......@@ -77,7 +77,7 @@ void _HardSigmoidGrad<float16>(
const T* y, \
T* dx, \
CPUContext* ctx) { \
_HardSigmoidGrad(count, cast::to<T>(alpha), dy, y, dx); \
_HardSigmoidGrad(count, convert::To<T>(alpha), dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
......
#include "dragon/utils/cast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -68,7 +68,7 @@ void _HardSwishGrad<float16>(
const T* x, \
T* y, \
CPUContext* ctx) { \
_HardSwish(count, cast::to<T>(alpha), cast::to<T>(beta), x, y); \
_HardSwish(count, convert::To<T>(alpha), convert::To<T>(beta), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
......@@ -81,7 +81,8 @@ void _HardSwishGrad<float16>(
const T* x, \
T* dx, \
CPUContext* ctx) { \
_HardSwishGrad(count, cast::to<T>(alpha), cast::to<T>(beta), dy, x, dx); \
_HardSwishGrad( \
count, convert::To<T>(alpha), convert::To<T>(beta), dy, x, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......
#include "dragon/utils/cast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -87,7 +87,7 @@ void _ReluNGrad<float16>(
template <> \
void Relu<T, CPUContext>( \
const int count, const float alpha, const T* x, T* y, CPUContext* ctx) { \
_Relu(count, cast::to<T>(alpha), x, y); \
_Relu(count, convert::To<T>(alpha), x, y); \
} \
template <> \
void ReluN<T, CPUContext>( \
......@@ -96,7 +96,7 @@ void _ReluNGrad<float16>(
const T* x, \
T* y, \
CPUContext* ctx) { \
_ReluN(count, cast::to<T>(max_value), x, y); \
_ReluN(count, convert::To<T>(max_value), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
......@@ -108,7 +108,7 @@ void _ReluNGrad<float16>(
const T* y, \
T* dx, \
CPUContext* ctx) { \
_ReluGrad(count, cast::to<T>(alpha), dy, y, dx); \
_ReluGrad(count, convert::To<T>(alpha), dy, y, dx); \
} \
template <> \
void ReluNGrad<T, CPUContext>( \
......@@ -118,7 +118,7 @@ void _ReluNGrad<float16>(
const T* y, \
T* dx, \
CPUContext* ctx) { \
_ReluNGrad(count, cast::to<T>(max_value), dy, y, dx); \
_ReluNGrad(count, convert::To<T>(max_value), dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -287,13 +287,13 @@ void ReluN<float16, CUDAContext>(
0,
ctx->cuda_stream()>>>(
count >> 1,
cast::to<half>(max_value),
convert::To<half>(max_value),
reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y));
} else {
_ReluN<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
cast::to<half>(max_value),
convert::To<half>(max_value),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
......@@ -339,14 +339,14 @@ void ReluNGrad<float16, CUDAContext>(
0,
ctx->cuda_stream()>>>(
count >> 1,
cast::to<half2>(max_value),
convert::To<half2>(max_value),
reinterpret_cast<const half2*>(dy),
reinterpret_cast<const half2*>(y),
reinterpret_cast<half2*>(dx));
} else {
_ReluNGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
cast::to<half>(max_value),
convert::To<half>(max_value),
reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(y),
reinterpret_cast<half*>(dx));
......@@ -362,7 +362,7 @@ void ReluNGrad<float16, CUDAContext>(
T* y, \
CUDAContext* ctx) { \
_Relu<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, cast::to<T>(alpha), x, y); \
count, convert::To<T>(alpha), x, y); \
} \
template <> \
void ReluN<T, CUDAContext>( \
......@@ -372,7 +372,7 @@ void ReluNGrad<float16, CUDAContext>(
T* y, \
CUDAContext* ctx) { \
_ReluN<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, cast::to<T>(max_value), x, y); \
count, convert::To<T>(max_value), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
......@@ -385,7 +385,7 @@ void ReluNGrad<float16, CUDAContext>(
T* dx, \
CUDAContext* ctx) { \
_ReluGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, cast::to<T>(alpha), dy, y, dx); \
count, convert::To<T>(alpha), dy, y, dx); \
} \
template <> \
void ReluNGrad<T, CUDAContext>( \
......@@ -396,7 +396,7 @@ void ReluNGrad<float16, CUDAContext>(
T* dx, \
CUDAContext* ctx) { \
_ReluNGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, cast::to<T>(max_value), dy, y, dx); \
count, convert::To<T>(max_value), dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float);
......
#include "dragon/utils/cast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -66,7 +66,7 @@ void _SeluGrad<float16>(
const T* x, \
T* y, \
CPUContext* ctx) { \
_Selu(count, cast::to<T>(alpha), cast::to<T>(gamma), x, y); \
_Selu(count, convert::To<T>(alpha), convert::To<T>(gamma), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
......@@ -79,7 +79,7 @@ void _SeluGrad<float16>(
const T* y, \
T* dx, \
      CPUContext* ctx) { \
_SeluGrad(count, cast::to<T>(alpha), cast::to<T>(gamma), dy, y, dx); \
_SeluGrad(count, convert::To<T>(alpha), convert::To<T>(gamma), dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -200,7 +199,7 @@ void Softmax<float16, CUDAContext>(
rows,
cols,
inner_dim,
cast::to<half>(std::numeric_limits<float>::lowest()),
convert::To<half>(std::numeric_limits<float>::lowest()),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
......@@ -35,7 +35,7 @@ __global__ void _Tanh<half2>(const int nthreads, const half2* x, half2* y) {
template <typename T>
__global__ void _TanhGrad(const int nthreads, const T* dy, const T* y, T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = dy[i] * (T(1) - utils::math::Square(y[i]));
dx[i] = dy[i] * (T(1) - math::utils::Square(y[i]));
}
}
......@@ -44,7 +44,7 @@ __global__ void
_TanhGrad<half>(const int nthreads, const half* dy, const half* y, half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = __float2half(
__half2float(dy[i]) * (1.f - utils::math::Square(__half2float(y[i]))));
__half2float(dy[i]) * (1.f - math::utils::Square(__half2float(y[i]))));
}
}
......@@ -58,8 +58,8 @@ __global__ void _TanhGrad<half2>(
const float2 val = __half22float2(y[i]);
const float2 grad = __half22float2(dy[i]);
dx[i] = __floats2half2_rn(
grad.x * (1.f - utils::math::Square(val.x)),
grad.y * (1.f - utils::math::Square(val.y)));
grad.x * (1.f - math::utils::Square(val.x)),
grad.y * (1.f - math::utils::Square(val.y)));
}
}
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
......@@ -28,7 +28,7 @@ void _ChannelNormalize(
if (d == axis) wi = idx[d];
}
y[yi] = ((Ty)x[xi] - (Ty)mean[wi]) / (Ty)std[wi];
utils::math::IncreaseIndexInDims(num_dims, y_dims, idx.data());
math::utils::IncreaseIndexInDims(num_dims, y_dims, idx.data());
}
}
......
......@@ -26,7 +26,7 @@ void _CumSum(
} else {
y[i] = exclusive ? T(0) : x[i];
}
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
......
#include "dragon/utils/cast.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -11,7 +10,7 @@ namespace {
template <typename T>
void _SetEye(const int n, const int m, const int k, T* y) {
for (int i = 0; i < n; ++i) {
y[i * m + k + i] = cast::to<T>(1.f);
y[i * m + k + i] = convert::To<T>(1.f);
}
}
......@@ -23,7 +22,7 @@ void _SetEye(const int n, const int m, const int k, T* y) {
template <> \
void Eye<T, CPUContext>( \
const int n, const int m, const int k, T* y, CPUContext* ctx) { \
math::Set(n* m, cast::to<T>(0.f), y, ctx); \
math::Set(n* m, convert::To<T>(0.f), y, ctx); \
if (k > 0) { \
if (m - k > 0) _SetEye(m - k, m, k, y); \
} else { \
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -37,7 +36,7 @@ void Eye<float16, CUDAContext>(
const int k,
float16* y,
CUDAContext* ctx) {
math::Set(n * m, cast::to<float16>(0.f), y, ctx);
math::Set(n * m, convert::To<float16>(0.f), y, ctx);
if (k > 0) {
if (m - k > 0) {
_SetEye<<<CUDA_BLOCKS(m - k), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
......
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -16,12 +16,12 @@ void _RowwiseLinSpace(
T* y) {
for (int i = 0; i < cols; ++i) {
const auto delta = (stop[i] - start[i]) / double(rows - 1);
y[i] = cast::to<T>(start[i]);
y[i] = convert::To<T>(start[i]);
if (rows > 1) {
y[i + (rows - 1) * cols] = cast::to<T>(stop[i]);
y[i + (rows - 1) * cols] = convert::To<T>(stop[i]);
}
for (int j = 1; j < rows - 1; ++j) {
y[i + j * cols] = cast::to<T>(start[i] + double(j) * delta);
y[i + j * cols] = convert::To<T>(start[i] + double(j) * delta);
}
}
}
......@@ -36,12 +36,12 @@ void _ColwiseLinSpace(
for (int i = 0; i < rows; ++i) {
const auto delta = (stop[i] - start[i]) / double(cols - 1);
auto* offset_y = y + i * cols;
offset_y[0] = cast::to<T>(start[i]);
offset_y[0] = convert::To<T>(start[i]);
if (cols > 1) {
offset_y[cols - 1] = cast::to<T>(stop[i]);
offset_y[cols - 1] = convert::To<T>(stop[i]);
}
for (int j = 1; j < cols - 1; ++j) {
offset_y[j] = cast::to<T>(start[i] + double(j) * delta);
offset_y[j] = convert::To<T>(start[i] + double(j) * delta);
}
}
}
......
#include "dragon/utils/cast.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -61,7 +60,7 @@ void _MaskedSelectGrad(
const ValueType* dy, \
ValueType* dx, \
CPUContext* ctx) { \
math::Set(count, cast::to<ValueType>(0.f), dx, ctx); \
math::Set(count, convert::To<ValueType>(0.f), dx, ctx); \
_MaskedSelectGrad(num_selected, index, dy, dx); \
}
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -61,7 +61,7 @@ __global__ void _MaskedSelectGrad(
const ValueType* dy, \
ValueType* dx, \
CUDAContext* ctx) { \
math::Set(count, cast::to<ValueType>(0.f), dx, ctx); \
math::Set(count, convert::To<ValueType>(0.f), dx, ctx); \
_MaskedSelectGrad<<< \
CUDA_BLOCKS(num_selected), \
CUDA_THREADS, \
......
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/cast.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -30,7 +29,7 @@ void _ConstPad(
xi += r * x_strides[d];
}
y[yi] = d >= 0 ? value : x[xi];
utils::math::IncreaseIndexInDims(num_dims, y_dims, index.data());
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
......@@ -56,7 +55,7 @@ void _ReflectPad(
xi += r * x_strides[d];
}
y[yi] = x[xi];
utils::math::IncreaseIndexInDims(num_dims, y_dims, index.data());
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
......@@ -80,7 +79,7 @@ void _EdgePad(
xi += r * x_strides[d];
}
y[yi] = x[xi];
utils::math::IncreaseIndexInDims(num_dims, y_dims, index.data());
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
......@@ -115,7 +114,14 @@ void _EdgePad(
T* y, \
CPUContext* ctx) { \
_ConstPad( \
num_dims, x_dims, x_strides, y_dims, pads, cast::to<T>(value), x, y); \
num_dims, \
x_dims, \
x_strides, \
y_dims, \
pads, \
convert::To<T>(value), \
x, \
y); \
}
DEFINE_CONST_KERNEL_LAUNCHER(bool);
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -114,7 +113,7 @@ __global__ void _EdgePad(
X_strides, \
Y_dims, \
X_pads, \
cast::to<T>(value), \
convert::To<T>(value), \
x, \
y); \
}
......
#include "dragon/utils/cast.h"
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -14,7 +14,7 @@ void _Range(const int count, const double start, const double delta, T* y) {
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
y[i] = cast::to<T>(start + double(i) * delta);
y[i] = convert::To<T>(start + double(i) * delta);
}
}
......
......@@ -26,7 +26,7 @@ void _ReduceSumGrad(
yi += (index[d] % y_dims[d]) * y_strides[d];
}
dx[xi] = dy[yi] * scale;
utils::math::IncreaseIndexInDims(num_dims, x_dims, index.data());
math::utils::IncreaseIndexInDims(num_dims, x_dims, index.data());
}
}
......
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
......@@ -25,7 +25,7 @@ void _Slice(
xi += (index[d] + starts[d]) * x_strides[d];
}
y[yi] = x[xi];
utils::math::IncreaseIndexInDims(num_dims, y_dims, index.data());
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
......@@ -47,7 +47,7 @@ void _SliceGrad(
xi += (index[d] + starts[d]) * x_strides[d];
}
dx[xi] = dy[yi];
utils::math::IncreaseIndexInDims(num_dims, y_dims, index.data());
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
......
......@@ -25,7 +25,7 @@ void _Tile(
xi += (index[d] % x_dims[d]) * x_strides[d];
}
y[i] = x[xi];
utils::math::IncreaseIndexInDims(num_dims, y_dims, index.data());
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
......
......@@ -162,7 +162,7 @@ __global__ void _SelectViaDeviceSort(
/* ------------------- Launcher Separator ------------------- */
#define PLACE_BLOCK_SORT_CASE(T, items_per_thread) \
#define BLOCKSORT_KERNEL(T, items_per_thread) \
_SelectViaBlockSort<T, items_per_thread> \
<<<CUDA_2D_BLOCKS(rows), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, \
......@@ -175,15 +175,15 @@ __global__ void _SelectViaDeviceSort(
reinterpret_cast<T*>(value), \
index)
#define PLACE_BLOCK_SORT_CASES(T) \
#define DISPATCH_BLOCKSORT_KERNEL(T) \
if (cols <= CUDA_THREADS) { \
PLACE_BLOCK_SORT_CASE(T, 1); \
BLOCKSORT_KERNEL(T, 1); \
} else if (cols <= CUDA_THREADS * 2) { \
PLACE_BLOCK_SORT_CASE(T, 2); \
BLOCKSORT_KERNEL(T, 2); \
} else if (cols <= CUDA_THREADS * 4) { \
PLACE_BLOCK_SORT_CASE(T, 4); \
BLOCKSORT_KERNEL(T, 4); \
} else if (cols <= CUDA_THREADS * 8) { \
PLACE_BLOCK_SORT_CASE(T, 8); \
BLOCKSORT_KERNEL(T, 8); \
} else { \
LOG(FATAL) << "Too larger dimension (> " << CUDA_THREADS * 8 \
<< ") to launch the cuda kernel"; \
......@@ -238,7 +238,7 @@ __global__ void _SelectViaDeviceSort(
return; \
} \
T2 init = largest > 0 ? kLowest : kMax; \
PLACE_BLOCK_SORT_CASES(T2); \
DISPATCH_BLOCKSORT_KERNEL(T2); \
}
DEFINE_KERNEL_LAUNCHER(
......@@ -277,8 +277,8 @@ DEFINE_KERNEL_LAUNCHER(
std::numeric_limits<double>::lowest(),
std::numeric_limits<double>::max());
#undef PLACE_BLOCK_SORT_CASE
#undef PLACE_BLOCK_SORT_CASES
#undef BLOCKSORT_KERNEL
#undef DISPATCH_BLOCKSORT_KERNEL
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -24,7 +24,7 @@ void _Transpose(
xi += index[d] * x_strides[d];
}
y[yi] = x[xi];
utils::math::IncreaseIndexInDims(num_dims, y_dims, index.data());
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
......@@ -45,7 +45,7 @@ void _TransposeGrad(
xi += index[d] * x_strides[d];
}
dx[xi] = dy[yi];
utils::math::IncreaseIndexInDims(num_dims, y_dims, index.data());
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
......
......@@ -25,7 +25,7 @@ void _Assign(
yi += (index[d] + starts[d]) * y_strides[d];
}
y[yi] = x[i];
utils::math::IncreaseIndexInDims(num_dims, x_dims, index.data());
math::utils::IncreaseIndexInDims(num_dims, x_dims, index.data());
}
}
......
......@@ -19,7 +19,7 @@ void _BroadcastLossGrad(
const int count = outer_dim * axis_dim * inner_dim;
for (int i = 0; i < count; ++i) {
dx[i] *= dy[idx[0] * inner_dim + idx[2]];
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
......@@ -93,7 +93,7 @@ void BroadcastLossGrad<float16, CPUContext>(
num_masks > 0 && normalizer < 0.f \
? (float)math::Sum(num_masks, 1.f, mask, ctx) \
: normalizer); \
math::Scale(count, cast::to<float>(dy[0]) / inv_scale, dx, dx, ctx); \
math::Scale(count, convert::To<float>(dy[0]) / inv_scale, dx, dx, ctx); \
} \
template <> \
void BroadcastLossGrad<T, CPUContext>( \
......
......@@ -28,7 +28,7 @@ void _NLLLoss(
k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
loss[i] = -logit[k], mask[i] = LogitType(1);
}
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(2, dims.data(), idx.data());
}
}
......@@ -53,7 +53,7 @@ void _NLLLossGrad(
k = (idx[0] * axis_dim + label) * inner_dim + idx[1];
dlogit[k] = LogitType(-1), mask[i] = LogitType(1);
}
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(2, dims.data(), idx.data());
}
}
......
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
......@@ -48,7 +48,7 @@ void _SigmoidFocalLoss(
loss[i] += -c2 * neg_term * neg_alpha;
mask[i] = c1;
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
......@@ -96,7 +96,7 @@ void _SigmoidFocalLossGrad(
dx[i] += -c2 * neg_term * neg_alpha;
mask[i] = c1;
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
......@@ -29,7 +29,7 @@ void _SparseSoftmaxCrossEntropy(
loss[i] = -std::log(std::max(prob[k], LogitType(FLT_MIN)));
mask[i] = LogitType(1);
}
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(2, dims.data(), idx.data());
}
}
......@@ -60,7 +60,7 @@ void _SparseSoftmaxCrossEntropyGrad(
dx[k] -= LogitType(1);
mask[i] = LogitType(1);
}
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(2, dims.data(), idx.data());
}
}
......
#include "dragon/utils/cast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -22,15 +22,15 @@ void _Clip<float16>(
const float16 high,
const float16* x,
float16* y) {
auto lowf = cast::to<float>(low);
auto highf = cast::to<float>(high);
auto lowf = convert::To<float>(low);
auto highf = convert::To<float>(high);
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
auto val = cast::to<float>(x[i]);
auto val = convert::To<float>(x[i]);
val = std::max(lowf, std::min(val, highf));
y[i] = cast::to<float16>(val);
y[i] = convert::To<float16>(val);
}
}
......@@ -56,14 +56,14 @@ void _ClipGrad<float16>(
const float16* dy,
const float16* x,
float16* dx) {
auto lowf = cast::to<float>(low);
auto highf = cast::to<float>(high);
auto kZero = cast::to<float16>(0.f);
auto lowf = convert::To<float>(low);
auto highf = convert::To<float>(high);
auto kZero = convert::To<float16>(0.f);
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
auto val = cast::to<float>(x[i]);
auto val = convert::To<float>(x[i]);
dx[i] = (val < lowf || val > highf) ? kZero : dy[i];
}
} // ClipGrad
......@@ -81,7 +81,7 @@ void _ClipGrad<float16>(
const T* x, \
T* y, \
CPUContext* ctx) { \
_Clip(count, cast::to<T>(low), cast::to<T>(high), x, y); \
_Clip(count, convert::To<T>(low), convert::To<T>(high), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
......@@ -94,7 +94,7 @@ void _ClipGrad<float16>(
const T* x, \
T* dx, \
CPUContext* ctx) { \
_ClipGrad(count, cast::to<T>(low), cast::to<T>(high), dy, x, dx); \
_ClipGrad(count, convert::To<T>(low), convert::To<T>(high), dy, x, dx); \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -104,8 +103,8 @@ void Clip<float16, CUDAContext>(
CUDAContext* ctx) {
_Clip<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
cast::to<half>(low),
cast::to<half>(high),
convert::To<half>(low),
convert::To<half>(high),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
......@@ -121,8 +120,8 @@ void ClipGrad<float16, CUDAContext>(
CUDAContext* ctx) {
_ClipGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
cast::to<half>(low),
cast::to<half>(high),
convert::To<half>(low),
convert::To<half>(high),
reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(dx));
......@@ -138,7 +137,7 @@ void ClipGrad<float16, CUDAContext>(
T* y, \
CUDAContext* ctx) { \
_Clip<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, cast::to<T>(low), cast::to<T>(high), x, y); \
count, convert::To<T>(low), convert::To<T>(high), x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
......@@ -152,7 +151,7 @@ void ClipGrad<float16, CUDAContext>(
T* dx, \
CUDAContext* ctx) { \
_ClipGrad<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, cast::to<T>(low), cast::to<T>(high), dy, x, dx); \
count, convert::To<T>(low), convert::To<T>(high), dy, x, dx); \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
......@@ -70,7 +70,7 @@ template <typename T>
__global__ void
_ReciprocalGrad(const int nthreads, const T* dy, const T* y, T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = -dy[i] * utils::math::Square(y[i]);
dx[i] = -dy[i] * math::utils::Square(y[i]);
}
}
......@@ -82,7 +82,7 @@ __global__ void _ReciprocalGrad<half>(
half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = __float2half(
-__half2float(dy[i]) * utils::math::Square(__half2float(y[i])));
-__half2float(dy[i]) * math::utils::Square(__half2float(y[i])));
}
}
......@@ -103,7 +103,7 @@ __global__ void _ReciprocalGrad<half2>(
template <typename T>
__global__ void _RsqrtGrad(const int nthreads, const T* dy, const T* y, T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = T(-0.5) * dy[i] * utils::math::Cube(y[i]);
dx[i] = T(-0.5) * dy[i] * math::utils::Cube(y[i]);
}
}
......@@ -112,7 +112,7 @@ __global__ void
_RsqrtGrad<half>(const int nthreads, const half* dy, const half* y, half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = __float2half(
-0.5f * __half2float(dy[i]) * utils::math::Cube(__half2float(y[i])));
-0.5f * __half2float(dy[i]) * math::utils::Cube(__half2float(y[i])));
}
}
......
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -106,14 +106,14 @@ void _Moments(
y_dims[axes[i]] = 1;
// Case #1: Rowwise Reduce
if (utils::math::IsRowwiseReduce(
if (math::utils::IsRowwiseReduce(
num_dims, dims, y_dims.data(), &rows, &cols)) {
_RowwiseMoments(rows, cols, x, mean, var);
return;
}
// Case #2: Colwise Reduce
if (utils::math::IsColwiseReduce(
if (math::utils::IsColwiseReduce(
num_dims, dims, y_dims.data(), &rows, &cols)) {
_ColwiseMoments(rows, cols, x, mean, var);
return;
......@@ -121,8 +121,8 @@ void _Moments(
// Case #3: Generic Reduce
vec32_t axesT(num_dims), stridesT(num_dims), dimsT(num_dims);
utils::math::TransposeAxesForReduce(num_dims, num_axes, axes, axesT.data());
utils::math::ComputeTransposeStrides(
math::utils::TransposeAxesForReduce(num_dims, num_axes, axes, axesT.data());
math::utils::ComputeTransposeStrides(
num_dims, dims, axesT.data(), stridesT.data());
rows = cols = 1;
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -28,10 +28,10 @@ __global__ void _RowwiseMoments(
const int xi = j * cols + i;
#if __CUDA_ARCH__ >= 350
m_val += __ldg(x + xi);
v_val += utils::math::Square(__ldg(x + xi));
v_val += math::utils::Square(__ldg(x + xi));
#else
m_val += x[xi];
v_val += utils::math::Square(x[xi]);
v_val += math::utils::Square(x[xi]);
#endif
}
m_val = BlockReduce<Ty>(m_storage).Sum(m_val);
......@@ -59,7 +59,7 @@ __global__ void _RowwiseMoments<half, float>(
CUDA_2D_KERNEL_LOOP2(j, rows) {
const int xi = j * cols + i;
m_val += __half2float(__ldg(x + xi));
v_val += utils::math::Square(__half2float(__ldg(x + xi)));
v_val += math::utils::Square(__half2float(__ldg(x + xi)));
}
m_val = BlockReduce<float>(m_storage).Sum(m_val);
v_val = BlockReduce<float>(v_storage).Sum(v_val);
......@@ -87,10 +87,10 @@ __global__ void _ColwiseMoments(
const int xi = i * cols + j;
#if __CUDA_ARCH__ >= 350
m_val += __ldg(x + xi);
v_val += utils::math::Square(__ldg(x + xi));
v_val += math::utils::Square(__ldg(x + xi));
#else
m_val += x[xi];
v_val += utils::math::Square(x[xi]);
v_val += math::utils::Square(x[xi]);
#endif
}
m_val = BlockReduce<Ty>(m_storage).Sum(m_val);
......@@ -118,7 +118,7 @@ __global__ void _ColwiseMoments<half, float>(
CUDA_2D_KERNEL_LOOP2(j, cols) {
const int xi = i * cols + j;
m_val += __half2float(__ldg(x + xi));
v_val += utils::math::Square(__half2float(__ldg(x + xi)));
v_val += math::utils::Square(__half2float(__ldg(x + xi)));
}
m_val = BlockReduce<float>(m_storage).Sum(m_val);
v_val = BlockReduce<float>(v_storage).Sum(v_val);
......@@ -154,10 +154,10 @@ __global__ void _GenericMoments(
}
#if __CUDA_ARCH__ >= 350
m_val += __ldg(x + xi);
v_val += utils::math::Square(__ldg(x + xi));
v_val += math::utils::Square(__ldg(x + xi));
#else
m_val += x[xi];
v_val += utils::math::Square(x[xi]);
v_val += math::utils::Square(x[xi]);
#endif
}
m_val = BlockReduce<Ty>(m_storage).Sum(m_val);
......@@ -194,10 +194,10 @@ __global__ void _GenericMoments(
}
#if __CUDA_ARCH__ >= 350
m_val += __half2float(__ldg(x + xi));
v_val += utils::math::Square(__half2float(__ldg(x + xi)));
v_val += math::utils::Square(__half2float(__ldg(x + xi)));
#else
m_val += __half2float(x[xi]);
v_val += utils::math::Square(__half2float(x[xi]));
v_val += math::utils::Square(__half2float(x[xi]));
#endif
}
m_val = BlockReduce<float>(m_storage).Sum(m_val);
......@@ -226,7 +226,7 @@ void _Moments(
y_dims[axes[i]] = 1;
/*! Case #1: Rowwise Reduce */
if (utils::math::IsRowwiseReduce(
if (math::utils::IsRowwiseReduce(
num_dims, dims, y_dims.data(), &rows, &cols)) {
_RowwiseMoments<<<
CUDA_2D_BLOCKS(cols),
......@@ -237,7 +237,7 @@ void _Moments(
}
/*! Case #2: Colwise Reduce */
if (utils::math::IsColwiseReduce(
if (math::utils::IsColwiseReduce(
num_dims, dims, y_dims.data(), &rows, &cols)) {
_ColwiseMoments<<<
CUDA_2D_BLOCKS(rows),
......@@ -250,8 +250,8 @@ void _Moments(
/*! Case #3: Generic Reduce */
CUDA_TENSOR_DIMS_CHECK(num_dims);
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> axesT, stridesT, dimsT;
utils::math::TransposeAxesForReduce(num_dims, num_axes, axes, axesT.data);
utils::math::ComputeTransposeStrides(
math::utils::TransposeAxesForReduce(num_dims, num_axes, axes, axesT.data);
math::utils::ComputeTransposeStrides(
num_dims, dims, axesT.data, stridesT.data);
rows = cols = 1;
......
#include "dragon/core/memory.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -9,22 +9,22 @@ namespace kernel {
namespace {
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
void _BatchNormExpectation(
const std::array<int, 3>& dims,
const Tp denorm,
const Tx* x,
Tp* ex,
Tp* ex2) {
const AccT denorm,
const T* x,
AccT* ex,
AccT* ex2) {
const int kCDim = kOrder == StorageOrder::NCHW ? 1 : 2;
const int count = dims[0] * dims[1] * dims[2];
std::array<int, 3> idx = {0, 0, 0};
for (int i = 0; i < count; ++i) {
const Tx x_val = x[i];
const T x_val = x[i];
const int pi = idx[kCDim];
ex[pi] += x_val;
ex2[pi] += x_val * x_val;
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
for (int i = 0; i < dims[kCDim]; ++i) {
ex[i] = ex[i] * denorm;
......@@ -32,16 +32,68 @@ void _BatchNormExpectation(
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T>
void _BatchNormFusedParams(
const int C,
const T* mu,
const T* rsig,
const T* gamma,
const T* beta,
T* scale,
T* bias) {
EigenVectorArrayMap<T> scale_arr(scale, C);
scale_arr = ConstEigenVectorArrayMap<T>(gamma, C) *
ConstEigenVectorArrayMap<T>(rsig, C);
EigenVectorArrayMap<T>(bias, C) = ConstEigenVectorArrayMap<T>(beta, C) -
scale_arr * ConstEigenVectorArrayMap<T>(mu, C);
}
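// Note: _BatchNormFusedParams folds the normalization into a per-channel
// affine transform:
//   y = gamma * (x - mu) * rsig + beta
//     = (gamma * rsig) * x + (beta - (gamma * rsig) * mu)
//     = scale * x + bias,
// which the _BatchNormAffine* routines below apply with one multiply-add
// per element.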
template <typename T, typename AccT>
void _BatchNormAffineNCHW(
const int N,
const int C,
const int S,
const T* x,
const AccT* scale,
const AccT* bias,
T* y) {
for (int i = 0; i < N; ++i) {
for (int j = 0; j < C; ++j) {
EigenVectorArrayMap<T>(y, S) =
ConstEigenVectorArrayMap<T>(x, S) * scale[j] + bias[j];
x += S;
y += S;
}
}
}
template <typename T, typename AccT>
void _BatchNormAffineNHWC(
const int N,
const int C,
const int S,
const T* x,
const AccT* scale,
const AccT* bias,
T* y) {
const auto NS = N * S;
ConstEigenVectorArrayMap<AccT> scale_arr(scale, C);
ConstEigenVectorArrayMap<AccT> bias_arr(bias, C);
EigenArrayMap<T>(y, C, NS) =
(ConstEigenArrayMap<T>(x, C, NS).colwise() * scale_arr).colwise() +
bias_arr;
}
template <typename T, typename AccT, StorageOrder kOrder>
void _BatchNormInternalGrad(
const std::array<int, 3>& dims,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tx* dy,
Tp* dgamma,
Tp* dbeta) {
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const T* dy,
AccT* dgamma,
AccT* dbeta) {
const int kCDim = kOrder == StorageOrder::NCHW ? 1 : 2;
const int count = dims[0] * dims[1] * dims[2];
std::array<int, 3> idx = {0, 0, 0};
......@@ -49,43 +101,43 @@ void _BatchNormInternalGrad(
const int pi = idx[kCDim];
dgamma[pi] += dy[i] * (x[i] - mu[pi]) * rsig[pi];
dbeta[pi] += dy[i];
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
void _BatchNormTrainingGrad(
const std::array<int, 3>& dims,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tp* dgamma,
const Tp* dbeta,
const Tx* dy,
Tx* dx) {
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const AccT* dgamma,
const AccT* dbeta,
const T* dy,
T* dx) {
const int kCDim = kOrder == StorageOrder::NCHW ? 1 : 2;
const int count = dims[0] * dims[1] * dims[2];
const Tp denom = Tp(1) / static_cast<Tp>(count / dims[kCDim]);
const AccT denom = AccT(1) / static_cast<AccT>(count / dims[kCDim]);
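  // Training-mode gradient:
  //   dx = gamma * rsig * (dy - (x_norm * dgamma + dbeta) / (N * S)),
  // where x_norm = (x - mu) * rsig and dgamma/dbeta are the per-channel
  // reductions computed by _BatchNormInternalGrad.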
std::array<int, 3> idx = {0, 0, 0};
for (int i = 0; i < count; ++i) {
const int pi = idx[kCDim];
const Tp x_norm = (x[i] - mu[pi]) * rsig[pi];
const AccT x_norm = (x[i] - mu[pi]) * rsig[pi];
dx[i] = gamma[pi] * rsig[pi] *
(dy[i] - (x_norm * dgamma[pi] + dbeta[pi]) * denom);
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
void _BatchNormWGrad(
const std::array<int, 3>& dims,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tx* dy,
Tp* dgamma,
Tp* dbeta) {
const T* x,
const AccT* mu,
const AccT* rsig,
const T* dy,
AccT* dgamma,
AccT* dbeta) {
const int kCDim = kOrder == StorageOrder::NCHW ? 1 : 2;
const int count = dims[0] * dims[1] * dims[2];
std::array<int, 3> idx = {0, 0, 0};
......@@ -93,33 +145,33 @@ void _BatchNormWGrad(
const int pi = idx[kCDim];
dgamma[pi] += dy[i] * (x[i] - mu[pi]) * rsig[pi];
dbeta[pi] += dy[i];
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
void _BatchNormInferenceGrad(
const int N,
const int C,
const int S,
const Tp* rsig,
const Tp* gamma,
const Tx* dy,
Tx* dx) {
const AccT* rsig,
const AccT* gamma,
const T* dy,
T* dx) {
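  // Inference-mode gradient: the statistics are constants here, so
  // dx = gamma * rsig * dy per channel.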
if (kOrder == StorageOrder::NCHW) {
const int CS = C * S;
for (int i = 0; i < N; ++i) {
EigenArrayMap<Tx>(dx + i * CS, S, C) =
(ConstEigenArrayMap<Tx>(dy + i * CS, S, C).rowwise() *
(ConstEigenVectorArrayMap<Tp>(gamma, C) *
ConstEigenVectorArrayMap<Tp>(rsig, C))
EigenArrayMap<T>(dx + i * CS, S, C) =
(ConstEigenArrayMap<T>(dy + i * CS, S, C).rowwise() *
(ConstEigenVectorArrayMap<AccT>(gamma, C) *
ConstEigenVectorArrayMap<AccT>(rsig, C))
.transpose());
}
} else if (kOrder == StorageOrder::NHWC) {
EigenArrayMap<Tx>(dx, C, N * S) =
(ConstEigenArrayMap<Tx>(dy, C, N * S).colwise() *
(ConstEigenVectorArrayMap<Tp>(gamma, C) *
ConstEigenVectorArrayMap<Tp>(rsig, C)));
EigenArrayMap<T>(dx, C, N * S) =
(ConstEigenArrayMap<T>(dy, C, N * S).colwise() *
(ConstEigenVectorArrayMap<AccT>(gamma, C) *
ConstEigenVectorArrayMap<AccT>(rsig, C)));
}
}
......@@ -127,141 +179,223 @@ void _BatchNormInferenceGrad(
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_GRAD_KERNEL_LAUNCHER(Tx, Tp) \
template <>
void BatchNorm<float16, float, CPUContext>(
const int N,
const int C,
const int S,
const string& data_format,
const float16* x,
const float* mu,
const float* rsig,
const float* gamma,
const float* beta,
float* scale,
float* bias,
float16* y,
    CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
template <>
void BatchNormExpectation<float16, float, CPUContext>(
const int N,
const int C,
const int S,
const float denorm,
const string& data_format,
const float16* x,
float* ex,
float* ex2,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
template <>
void BatchNormInternalGrad<float16, float, CPUContext>(
const int N,
const int C,
const int S,
const string& data_format,
const float16* x,
const float* mu,
const float* rsig,
const float* gamma,
const float16* dy,
float* dgamma,
float* dbeta,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
} // BatchNormInternalGrad
template <>
void BatchNormTrainingGrad<float16, float, CPUContext>(
const int N,
const int C,
const int S,
const string& data_format,
const float16* x,
const float* mu,
const float* rsig,
const float* gamma,
const float* dgamma,
const float* dbeta,
const float16* dy,
float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
} // BatchNormTrainingGrad
template <>
void BatchNormInferenceGrad<float16, float, CPUContext>(
const int N,
const int C,
const int S,
const string& data_format,
const float16* x,
const float* mu,
const float* rsig,
const float* gamma,
const float16* dy,
float* dgamma,
float* dbeta,
float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
} // BatchNormInferenceGrad
#define DEFINE_KERNEL_LAUNCHER(T, AccT) \
template <> \
void BatchNormExpectation<Tx, Tp, CPUContext>( \
void BatchNorm<T, AccT, CPUContext>( \
const int N, \
const int C, \
const int S, \
const Tp denorm, \
const string& data_format, \
const Tx* x, \
Tp* ex, \
Tp* ex2, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const AccT* beta, \
AccT* scale, \
AccT* bias, \
T* y, \
CPUContext* ctx) { \
math::Set(C, Tp(0), ex, ctx); \
math::Set(C, Tp(0), ex2, ctx); \
_BatchNormFusedParams(C, mu, rsig, gamma, beta, scale, bias); \
if (data_format == "NCHW") { \
_BatchNormExpectation<Tx, Tp, StorageOrder::NCHW>( \
{N, C, S}, denorm, x, ex, ex2); \
_BatchNormAffineNCHW(N, C, S, x, scale, bias, y); \
} else if (data_format == "NHWC") { \
_BatchNormExpectation<Tx, Tp, StorageOrder::NHWC>( \
{N, S, C}, denorm, x, ex, ex2); \
} \
_BatchNormAffineNHWC(N, C, S, x, scale, bias, y); \
} \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T, AccT) \
template <> \
void BatchNormInternalGrad<Tx, Tp, CPUContext>( \
void BatchNormExpectation<T, AccT, CPUContext>( \
const int N, \
const int C, \
const int S, \
const AccT denorm, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tx* dy, \
Tp* dgamma, \
Tp* dbeta, \
const T* x, \
AccT* ex, \
AccT* ex2, \
CPUContext* ctx) { \
math::Set(C, Tp(0), dgamma, ctx); \
math::Set(C, Tp(0), dbeta, ctx); \
math::Set(C, AccT(0), ex, ctx); \
math::Set(C, AccT(0), ex2, ctx); \
if (data_format == "NCHW") { \
_BatchNormInternalGrad<Tx, Tp, StorageOrder::NCHW>( \
{N, C, S}, x, mu, rsig, gamma, dy, dgamma, dbeta); \
_BatchNormExpectation<T, AccT, StorageOrder::NCHW>( \
{N, C, S}, denorm, x, ex, ex2); \
} else if (data_format == "NHWC") { \
_BatchNormInternalGrad<Tx, Tp, StorageOrder::NHWC>( \
{N, S, C}, x, mu, rsig, gamma, dy, dgamma, dbeta); \
_BatchNormExpectation<T, AccT, StorageOrder::NHWC>( \
{N, S, C}, denorm, x, ex, ex2); \
} \
} \
template <> \
void BatchNormTrainingGrad<Tx, Tp, CPUContext>( \
void BatchNormInternalGrad<T, AccT, CPUContext>( \
const int N, \
const int C, \
const int S, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tp* dgamma, \
const Tp* dbeta, \
const Tx* dy, \
Tx* dx, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const T* dy, \
AccT* dgamma, \
AccT* dbeta, \
CPUContext* ctx) { \
math::Set(C, AccT(0), dgamma, ctx); \
math::Set(C, AccT(0), dbeta, ctx); \
if (data_format == "NCHW") { \
_BatchNormTrainingGrad<Tx, Tp, StorageOrder::NCHW>( \
{N, C, S}, x, mu, rsig, gamma, dgamma, dbeta, dy, dx); \
_BatchNormInternalGrad<T, AccT, StorageOrder::NCHW>( \
{N, C, S}, x, mu, rsig, gamma, dy, dgamma, dbeta); \
} else if (data_format == "NHWC") { \
_BatchNormTrainingGrad<Tx, Tp, StorageOrder::NHWC>( \
{N, S, C}, x, mu, rsig, gamma, dgamma, dbeta, dy, dx); \
_BatchNormInternalGrad<T, AccT, StorageOrder::NHWC>( \
{N, S, C}, x, mu, rsig, gamma, dy, dgamma, dbeta); \
} \
} \
template <> \
void BatchNormBackwardTraining<Tx, Tp, CPUContext>( \
void BatchNormTrainingGrad<T, AccT, CPUContext>( \
const int N, \
const int C, \
const int S, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tx* dy, \
Tx* dx, \
Tp* dgamma, \
Tp* dbeta, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const AccT* dgamma, \
const AccT* dbeta, \
const T* dy, \
T* dx, \
CPUContext* ctx) { \
math::Set(C, Tp(0), dgamma, ctx); \
math::Set(C, Tp(0), dbeta, ctx); \
if (data_format == "NCHW") { \
_BatchNormInternalGrad<Tx, Tp, StorageOrder::NCHW>( \
{N, C, S}, x, mu, rsig, gamma, dy, dgamma, dbeta); \
_BatchNormTrainingGrad<Tx, Tp, StorageOrder::NCHW>( \
_BatchNormTrainingGrad<T, AccT, StorageOrder::NCHW>( \
{N, C, S}, x, mu, rsig, gamma, dgamma, dbeta, dy, dx); \
} else if (data_format == "NHWC") { \
_BatchNormInternalGrad<Tx, Tp, StorageOrder::NHWC>( \
{N, S, C}, x, mu, rsig, gamma, dy, dgamma, dbeta); \
_BatchNormTrainingGrad<Tx, Tp, StorageOrder::NHWC>( \
_BatchNormTrainingGrad<T, AccT, StorageOrder::NHWC>( \
{N, S, C}, x, mu, rsig, gamma, dgamma, dbeta, dy, dx); \
} \
} \
template <> \
void BatchNormBackwardInference<Tx, Tp, CPUContext>( \
void BatchNormInferenceGrad<T, AccT, CPUContext>( \
const int N, \
const int C, \
const int S, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tx* dy, \
Tx* dx, \
Tp* dgamma, \
Tp* dbeta, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const T* dy, \
AccT* dgamma, \
AccT* dbeta, \
T* dx, \
CPUContext* ctx) { \
if (data_format == "NCHW") { \
if (dgamma != nullptr) { \
math::Set(C, Tp(0), dgamma, ctx); \
math::Set(C, Tp(0), dbeta, ctx); \
_BatchNormWGrad<Tx, Tp, StorageOrder::NCHW>( \
math::Set(C, AccT(0), dgamma, ctx); \
math::Set(C, AccT(0), dbeta, ctx); \
_BatchNormWGrad<T, AccT, StorageOrder::NCHW>( \
{N, C, S}, x, mu, rsig, dy, dgamma, dbeta); \
} \
_BatchNormInferenceGrad<Tx, Tp, StorageOrder::NCHW>( \
_BatchNormInferenceGrad<T, AccT, StorageOrder::NCHW>( \
N, C, S, rsig, gamma, dy, dx); \
} else if (data_format == "NHWC") { \
if (dgamma != nullptr) { \
math::Set(C, Tp(0), dgamma, ctx); \
math::Set(C, Tp(0), dbeta, ctx); \
_BatchNormWGrad<Tx, Tp, StorageOrder::NHWC>( \
math::Set(C, AccT(0), dgamma, ctx); \
math::Set(C, AccT(0), dbeta, ctx); \
_BatchNormWGrad<T, AccT, StorageOrder::NHWC>( \
{N, S, C}, x, mu, rsig, dy, dgamma, dbeta); \
} \
_BatchNormInferenceGrad<Tx, Tp, StorageOrder::NHWC>( \
_BatchNormInferenceGrad<T, AccT, StorageOrder::NHWC>( \
N, C, S, rsig, gamma, dy, dx); \
} \
}
DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -10,35 +10,39 @@ namespace dragon {
namespace kernel {
#if __CUDA_ARCH__ >= 350
#define LOAD(x, i) __ldg(x + i)
#define LDG(x, i) __ldg(x + i)
#define LDG2(x, i) convert::To<AccT>(__ldg(x + i))
#else
#define LOAD(x, i) x[i]
#define LDG(x, i) x[i]
#define LDG2(x, i) convert::To<AccT>(x[i])
#endif
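// LDG loads an operand directly (via the read-only cache on sm_35+), while
// LDG2 also converts it to the accumulation type AccT (e.g. half -> float),
// so the arithmetic in the kernels below stays in AccT.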
namespace {
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _BatchNormExpectation(
const int N,
const int C,
const int S,
const Tp denorm,
const Tx* x,
Tp* ex,
Tp* ex2) {
const AccT denorm,
const T* x,
AccT* ex,
AccT* ex2) {
const int outer_dim = N * S;
__shared__ typename BlockReduce<Tp>::TempStorage ex_storage;
__shared__ typename BlockReduce<Tp>::TempStorage ex2_storage;
__shared__ union {
typename BlockReduce<AccT>::TempStorage ex;
typename BlockReduce<AccT>::TempStorage ex2;
} storage;
CUDA_2D_KERNEL_LOOP1(i, C) {
Tp ex_val = Tp(0), ex2_val = Tp(0);
AccT ex_val = AccT(0), ex2_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S
: j * C + i;
ex_val += LOAD(x, xi);
ex2_val += utils::math::Square(LOAD(x, xi));
ex_val += LDG2(x, xi);
ex2_val += math::utils::Square(LDG2(x, xi));
}
ex_val = BlockReduce<Tp>(ex_storage).Reduce(ex_val, cub::Sum());
ex2_val = BlockReduce<Tp>(ex2_storage).Reduce(ex2_val, cub::Sum());
ex_val = BlockReduce<AccT>(storage.ex).Reduce(ex_val, cub::Sum());
ex2_val = BlockReduce<AccT>(storage.ex2).Reduce(ex2_val, cub::Sum());
if (threadIdx.x == 0) {
ex[i] = ex_val * denorm;
ex2[i] = ex2_val * denorm;
......@@ -46,31 +50,64 @@ __global__ void _BatchNormExpectation(
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T>
__global__ void _BatchNormFusedParams(
const int C,
const T* mu,
const T* rsig,
const T* gamma,
const T* beta,
T* scale,
T* bias) {
CUDA_1D_KERNEL_LOOP(i, C) {
const T scale_val = scale[i] = gamma[i] * rsig[i];
bias[i] = fma(-scale_val, mu[i], beta[i]);
}
}
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _BatchNormAffine(
const int nthreads,
const int C,
const int S,
const T* x,
const AccT* scale,
const AccT* bias,
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
y[i] = convert::To<T>(
fma(convert::To<AccT>(x[i]), LDG(scale, pi), LDG(bias, pi)));
}
}
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _BatchNormInternalGrad(
const int N,
const int C,
const int S,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tx* dy,
Tp* dgamma,
Tp* dbeta) {
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const T* dy,
AccT* dgamma,
AccT* dbeta) {
const int outer_dim = N * S;
__shared__ typename BlockReduce<Tp>::TempStorage dg_storage;
__shared__ typename BlockReduce<Tp>::TempStorage db_storage;
__shared__ union {
typename BlockReduce<AccT>::TempStorage dg;
typename BlockReduce<AccT>::TempStorage db;
} storage;
CUDA_2D_KERNEL_LOOP1(i, C) {
Tp dg_val = Tp(0), db_val = Tp(0);
AccT dg_val = AccT(0), db_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S
: j * C + i;
dg_val += LOAD(dy, xi) * (LOAD(x, xi) - LOAD(mu, i)) * LOAD(rsig, i);
db_val += LOAD(dy, xi);
dg_val += LDG2(dy, xi) * (LDG2(x, xi) - LDG(mu, i)) * LDG(rsig, i);
db_val += LDG2(dy, xi);
}
dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
dg_val = BlockReduce<AccT>(storage.dg).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<AccT>(storage.db).Reduce(db_val, cub::Sum());
if (threadIdx.x == 0) {
dgamma[i] = dg_val;
dbeta[i] = db_val;
......@@ -78,53 +115,56 @@ __global__ void _BatchNormInternalGrad(
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _BatchNormTrainingGrad(
const int nthreads,
const int N,
const int C,
const int S,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tp* dgamma,
const Tp* dbeta,
const Tx* dy,
Tx* dx) {
const Tp denom = Tp(1) / Tp(N * S);
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const AccT* dgamma,
const AccT* dbeta,
const T* dy,
T* dx) {
const AccT denom = AccT(1) / AccT(N * S);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
const Tp x_norm = (LOAD(x, i) - LOAD(mu, pi)) * LOAD(rsig, pi);
dx[i] = LOAD(gamma, pi) * LOAD(rsig, pi) *
(LOAD(dy, i) - fma(x_norm, LOAD(dgamma, pi), LOAD(dbeta, pi)) * denom);
const AccT xnorm = (LDG2(x, i) - LDG(mu, pi)) * LDG(rsig, pi);
dx[i] = convert::To<T>(
LDG(gamma, pi) * LDG(rsig, pi) *
(LDG2(dy, i) - fma(xnorm, LDG(dgamma, pi), LDG(dbeta, pi)) * denom));
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _BatchNormWGrad(
const int N,
const int C,
const int S,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tx* dy,
Tp* dgamma,
Tp* dbeta) {
const T* x,
const AccT* mu,
const AccT* rsig,
const T* dy,
AccT* dgamma,
AccT* dbeta) {
const int outer_dim = N * S;
__shared__ typename BlockReduce<Tp>::TempStorage dg_storage;
__shared__ typename BlockReduce<Tp>::TempStorage db_storage;
__shared__ union {
typename BlockReduce<AccT>::TempStorage dg;
typename BlockReduce<AccT>::TempStorage db;
} storage;
CUDA_2D_KERNEL_LOOP1(i, C) {
Tp dg_val = Tp(0), db_val = Tp(0);
AccT dg_val = AccT(0), db_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S
: j * C + i;
dg_val += LOAD(dy, xi) * (LOAD(x, xi) - LOAD(mu, i)) * LOAD(rsig, i);
db_val += LOAD(dy, xi);
dg_val += LDG2(dy, xi) * (LDG2(x, xi) - LDG(mu, i)) * LDG(rsig, i);
db_val += LDG2(dy, xi);
}
dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
    dg_val = BlockReduce<AccT>(storage.dg).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<AccT>(storage.db).Reduce(db_val, cub::Sum());
if (threadIdx.x == 0) {
dgamma[i] = dg_val;
dbeta[i] = db_val;
......@@ -132,171 +172,223 @@ __global__ void _BatchNormWGrad(
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _BatchNormInferenceGrad(
const int nthreads,
const int C,
const int S,
const Tp* rsig,
const Tp* gamma,
const Tx* dy,
Tx* dx) {
const AccT* rsig,
const AccT* gamma,
const T* dy,
T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
dx[i] = LOAD(gamma, pi) * LOAD(dy, i) * LOAD(rsig, pi);
dx[i] = convert::To<T>(LDG(gamma, pi) * LDG2(dy, i) * LDG(rsig, pi));
}
}
#undef LOAD
#undef LDG
#undef LDG2
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_GRAD_KERNEL_LAUNCHER(Tx, Tp) \
#define DISPATCH_BATCHNORM_KERNEL(name, T, AccT, nblocks, nthreads, ...) \
if (data_format == "NCHW") { \
name<T, AccT, StorageOrder::NCHW> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else if (data_format == "NHWC") { \
name<T, AccT, StorageOrder::NHWC> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else { \
LOG(FATAL) << "Unknown DataFormat: " << data_format; \
}
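// Launches `name<T, AccT, StorageOrder::*>` on ctx's CUDA stream according to
// data_format, or aborts on an unknown layout.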
#define DEFINE_KERNEL_LAUNCHER(T, ScalarT, AccT) \
template <> \
void BatchNormExpectation<Tx, Tp, CUDAContext>( \
void BatchNorm<T, AccT, CUDAContext>( \
const int N, \
const int C, \
const int S, \
const Tp denorm, \
const string& data_format, \
const Tx* x, \
Tp* ex, \
Tp* ex2, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const AccT* beta, \
AccT* scale, \
AccT* bias, \
T* y, \
CUDAContext* ctx) { \
if (data_format == "NCHW") { \
_BatchNormExpectation<Tx, Tp, StorageOrder::NCHW> \
<<<CUDA_2D_BLOCKS(C), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, C, S, denorm, x, ex, ex2); \
} else if (data_format == "NHWC") { \
_BatchNormExpectation<Tx, Tp, StorageOrder::NHWC> \
<<<CUDA_2D_BLOCKS(C), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, C, S, denorm, x, ex, ex2); \
} \
} \
const auto nthreads = N * C * S; \
_BatchNormFusedParams<<< \
CUDA_BLOCKS(C), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(C, mu, rsig, gamma, beta, scale, bias); \
DISPATCH_BATCHNORM_KERNEL( \
_BatchNormAffine, \
ScalarT, \
AccT, \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
nthreads, \
C, \
S, \
reinterpret_cast<const ScalarT*>(x), \
scale, \
bias, \
reinterpret_cast<ScalarT*>(y)); \
}
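The forward launcher first folds the statistics and affine parameters into per-channel scale/bias, so _BatchNormAffine is a single fused multiply-add per element. _BatchNormFusedParams itself is outside this hunk; the sketch below assumes the conventional fold scale = gamma * rsig and bias = beta - scale * mu:

#include <vector>

// Reference batch-norm forward (NCHW): fold parameters, then apply
// y = x * scale[c] + bias[c].
void BatchNormForwardRef(int N, int C, int S,
                         const std::vector<float>& x,
                         const std::vector<float>& mu,
                         const std::vector<float>& rsig,
                         const std::vector<float>& gamma,
                         const std::vector<float>& beta,
                         std::vector<float>& y) {
  std::vector<float> scale(C), bias(C);
  for (int c = 0; c < C; ++c) {
    scale[c] = gamma[c] * rsig[c];         // assumed fold, see note above
    bias[c] = beta[c] - scale[c] * mu[c];
  }
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      for (int s = 0; s < S; ++s) {
        const int i = (n * C + c) * S + s;
        y[i] = x[i] * scale[c] + bias[c];
      }
    }
  }
}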
#define DEFINE_GRAD_KERNEL_LAUNCHER(T, ScalarT, AccT) \
template <> \
void BatchNormInternalGrad<Tx, Tp, CUDAContext>( \
void BatchNormExpectation<T, AccT, CUDAContext>( \
const int N, \
const int C, \
const int S, \
const AccT denorm, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tx* dy, \
Tp* dgamma, \
Tp* dbeta, \
const T* x, \
AccT* ex, \
AccT* ex2, \
CUDAContext* ctx) { \
if (data_format == "NCHW") { \
_BatchNormInternalGrad<Tx, Tp, StorageOrder::NCHW> \
<<<CUDA_2D_BLOCKS(C), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, C, S, x, mu, rsig, gamma, dy, dgamma, dbeta); \
} else if (data_format == "NHWC") { \
_BatchNormInternalGrad<Tx, Tp, StorageOrder::NHWC> \
<<<CUDA_2D_BLOCKS(C), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, C, S, x, mu, rsig, gamma, dy, dgamma, dbeta); \
} \
DISPATCH_BATCHNORM_KERNEL( \
_BatchNormExpectation, \
ScalarT, \
AccT, \
CUDA_2D_BLOCKS(C), \
CUDA_THREADS, \
N, \
C, \
S, \
denorm, \
reinterpret_cast<const ScalarT*>(x), \
ex, \
ex2); \
} \
template <> \
void BatchNormTrainingGrad<Tx, Tp, CUDAContext>( \
void BatchNormInternalGrad<T, AccT, CUDAContext>( \
const int N, \
const int C, \
const int S, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tp* dgamma, \
const Tp* dbeta, \
const Tx* dy, \
Tx* dx, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const T* dy, \
AccT* dgamma, \
AccT* dbeta, \
CUDAContext* ctx) { \
const int nthreads = N * C * S; \
if (data_format == "NCHW") { \
_BatchNormTrainingGrad<Tx, Tp, StorageOrder::NCHW> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, N, C, S, x, mu, rsig, gamma, dgamma, dbeta, dy, dx); \
} else if (data_format == "NHWC") { \
_BatchNormTrainingGrad<Tx, Tp, StorageOrder::NHWC> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, N, C, S, x, mu, rsig, gamma, dgamma, dbeta, dy, dx); \
} \
DISPATCH_BATCHNORM_KERNEL( \
_BatchNormInternalGrad, \
ScalarT, \
AccT, \
CUDA_2D_BLOCKS(C), \
CUDA_THREADS, \
N, \
C, \
S, \
reinterpret_cast<const ScalarT*>(x), \
mu, \
rsig, \
gamma, \
reinterpret_cast<const ScalarT*>(dy), \
dgamma, \
dbeta); \
} \
template <> \
void BatchNormBackwardTraining<Tx, Tp, CUDAContext>( \
void BatchNormTrainingGrad<T, AccT, CUDAContext>( \
const int N, \
const int C, \
const int S, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tx* dy, \
Tx* dx, \
Tp* dgamma, \
Tp* dbeta, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const AccT* dgamma, \
const AccT* dbeta, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
const int nthreads = N * C * S; \
if (data_format == "NCHW") { \
_BatchNormInternalGrad<Tx, Tp, StorageOrder::NCHW> \
<<<CUDA_2D_BLOCKS(C), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, C, S, x, mu, rsig, gamma, dy, dgamma, dbeta); \
_BatchNormTrainingGrad<Tx, Tp, StorageOrder::NCHW> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, N, C, S, x, mu, rsig, gamma, dgamma, dbeta, dy, dx); \
} else if (data_format == "NHWC") { \
_BatchNormInternalGrad<Tx, Tp, StorageOrder::NHWC> \
<<<CUDA_2D_BLOCKS(C), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, C, S, x, mu, rsig, gamma, dy, dgamma, dbeta); \
_BatchNormTrainingGrad<Tx, Tp, StorageOrder::NHWC> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, N, C, S, x, mu, rsig, gamma, dgamma, dbeta, dy, dx); \
} \
const auto nthreads = N * C * S; \
DISPATCH_BATCHNORM_KERNEL( \
_BatchNormTrainingGrad, \
ScalarT, \
AccT, \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
nthreads, \
N, \
C, \
S, \
reinterpret_cast<const ScalarT*>(x), \
mu, \
rsig, \
gamma, \
dgamma, \
dbeta, \
reinterpret_cast<const ScalarT*>(dy), \
reinterpret_cast<ScalarT*>(dx)); \
} \
template <> \
void BatchNormBackwardInference<Tx, Tp, CUDAContext>( \
void BatchNormInferenceGrad<T, AccT, CUDAContext>( \
const int N, \
const int C, \
const int S, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tx* dy, \
Tx* dx, \
Tp* dgamma, \
Tp* dbeta, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const T* dy, \
AccT* dgamma, \
AccT* dbeta, \
T* dx, \
CUDAContext* ctx) { \
const int nthreads = N * C * S; \
if (data_format == "NCHW") { \
const auto nthreads = N * C * S; \
if (dgamma != nullptr) { \
_BatchNormWGrad<Tx, Tp, StorageOrder::NCHW> \
<<<CUDA_2D_BLOCKS(C), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, C, S, x, mu, rsig, dy, dgamma, dbeta); \
} \
_BatchNormInferenceGrad<Tx, Tp, StorageOrder::NCHW> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, C, S, rsig, gamma, dy, dx); \
} else if (data_format == "NHWC") { \
if (dgamma != nullptr) { \
_BatchNormWGrad<Tx, Tp, StorageOrder::NHWC> \
<<<CUDA_2D_BLOCKS(C), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, C, S, x, mu, rsig, dy, dgamma, dbeta); \
} \
_BatchNormInferenceGrad<Tx, Tp, StorageOrder::NHWC> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, C, S, rsig, gamma, dy, dx); \
DISPATCH_BATCHNORM_KERNEL( \
_BatchNormWGrad, \
ScalarT, \
AccT, \
CUDA_2D_BLOCKS(C), \
CUDA_THREADS, \
N, \
C, \
S, \
reinterpret_cast<const ScalarT*>(x), \
mu, \
rsig, \
reinterpret_cast<const ScalarT*>(dy), \
dgamma, \
dbeta); \
} \
DISPATCH_BATCHNORM_KERNEL( \
_BatchNormInferenceGrad, \
ScalarT, \
AccT, \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
nthreads, \
C, \
S, \
rsig, \
gamma, \
reinterpret_cast<const ScalarT*>(dy), \
reinterpret_cast<ScalarT*>(dx)); \
}
DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(float16, half, float);
DEFINE_KERNEL_LAUNCHER(float, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float16, half, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float, float);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
#undef DISPATCH_BATCHNORM_KERNEL
} // namespace kernel
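The float16 instantiations above implement the pseudo-FP16 scheme from the commit summary: the public float16 buffers are reinterpreted as device half (ScalarT), every load is widened to float (AccT) before any arithmetic, and only the final store narrows back. The self-contained C++ sketch below illustrates that discipline; to stay short it uses a bfloat16-style truncated 16-bit storage type and hypothetical ToFloat/FromFloat helpers instead of IEEE half and Dragon's convert::To:

#include <cstdint>
#include <cstring>
#include <vector>

// Stand-in 16-bit storage type (bfloat16-style truncation), used only to
// keep the sketch self-contained; the kernels above store IEEE half.
struct Storage16 { uint16_t bits; };

inline float ToFloat(Storage16 v) {
  uint32_t u = static_cast<uint32_t>(v.bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

inline Storage16 FromFloat(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return Storage16{static_cast<uint16_t>(u >> 16)};  // truncate the mantissa
}

// Pseudo-FP16 reduction: 16-bit loads, FP32 accumulation, one 16-bit store.
Storage16 SumPseudoFP16(const std::vector<Storage16>& x) {
  float acc = 0.f;            // AccT
  for (Storage16 v : x) {
    acc += ToFloat(v);        // widen every element before the add
  }
  return FromFloat(acc);      // narrow only once, at the end
}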
......
#include "dragon/core/memory.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -33,49 +33,49 @@ void _GroupNormFusedParams(
}
}
template <typename Tx, typename Tp>
void _GroupNormForwardNCHW(
template <typename T, typename AccT>
void _GroupNormNCHW(
const int N,
const int C,
const int S,
const Tx* x,
const Tp* scale,
const Tp* bias,
Tx* y) {
EigenArrayMap<Tx>(y, S, N * C) =
(ConstEigenArrayMap<Tx>(x, S, N * C).rowwise() *
ConstEigenVectorArrayMap<Tp>(scale, N * C).transpose())
const T* x,
const AccT* scale,
const AccT* bias,
T* y) {
EigenArrayMap<T>(y, S, N * C) =
(ConstEigenArrayMap<T>(x, S, N * C).rowwise() *
ConstEigenVectorArrayMap<AccT>(scale, N * C).transpose())
.rowwise() +
ConstEigenVectorArrayMap<Tp>(bias, N * C).transpose();
ConstEigenVectorArrayMap<AccT>(bias, N * C).transpose();
}
template <typename Tx, typename Tp>
void _GroupNormForwardNHWC(
template <typename T, typename AccT>
void _GroupNormNHWC(
const int N,
const int C,
const int S,
const Tx* x,
const Tp* scale,
const Tp* bias,
Tx* y) {
const T* x,
const AccT* scale,
const AccT* bias,
T* y) {
const int SC = S * C;
for (int i = 0; i < N; ++i) {
EigenArrayMap<Tx>(y + i * SC, C, S) =
(ConstEigenArrayMap<Tx>(x + i * SC, C, S).colwise() *
ConstEigenVectorArrayMap<Tp>(scale + i * C, C))
EigenArrayMap<T>(y + i * SC, C, S) =
(ConstEigenArrayMap<T>(x + i * SC, C, S).colwise() *
ConstEigenVectorArrayMap<AccT>(scale + i * C, C))
.colwise() +
ConstEigenVectorArrayMap<Tp>(bias + i * C, C);
ConstEigenVectorArrayMap<AccT>(bias + i * C, C);
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
void _GroupNormInternalGrad(
const std::array<int, 4>& dims,
const Tx* x,
const Tp* gamma,
const Tx* dy,
Tp* ds,
Tp* db) {
const T* x,
const AccT* gamma,
const T* dy,
AccT* ds,
AccT* db) {
const int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
const int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
const int count = dims[0] * dims[1] * dims[2] * dims[3];
......@@ -85,39 +85,39 @@ void _GroupNormInternalGrad(
const int gi = idx[kGDim] * dims[kDDim] + idx[kDDim];
ds[mi] += gamma[gi] * dy[i] * x[i];
db[mi] += gamma[gi] * dy[i];
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
void _GroupNormGrad(
const std::array<int, 4>& dims,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tp* ds,
const Tp* db,
const Tx* dy,
Tx* dx,
Tp* dgamma,
Tp* dbeta) {
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const AccT* ds,
const AccT* db,
const T* dy,
AccT* dgamma,
AccT* dbeta,
T* dx) {
const int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
const int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
const int count = dims[0] * dims[1] * dims[2] * dims[3];
const int S = kOrder == StorageOrder::NCHW ? dims[3] : dims[1];
const Tp denom = Tp(1) / static_cast<Tp>(dims[kDDim] * S);
const AccT denom = AccT(1) / static_cast<AccT>(dims[kDDim] * S);
std::array<int, 4> idx = {0, 0, 0, 0};
for (int i = 0; i < count; ++i) {
const int mi = idx[0] * dims[kGDim] + idx[kGDim];
const int gi = idx[kGDim] * dims[kDDim] + idx[kDDim];
const Tp u = (db[mi] * mu[mi] - ds[mi]) * (x[i] - mu[mi]) *
utils::math::Cube(rsig[mi]);
const Tp v = db[mi] * rsig[mi];
const AccT u = (db[mi] * mu[mi] - ds[mi]) * (x[i] - mu[mi]) *
math::utils::Cube(rsig[mi]);
const AccT v = db[mi] * rsig[mi];
dx[i] = gamma[gi] * dy[i] * rsig[mi] + (u - v) * denom;
dgamma[gi] += dy[i] * (x[i] - mu[mi]) * rsig[mi];
dbeta[gi] += dy[i];
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
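The same backward math, written as plain FP32 loops over an NCHW tensor, makes the reordered (dgamma, dbeta, dx) outputs easy to verify: ds/db are the per-(n, g) sums of gamma*dy*x and gamma*dy, and dx combines them with u = (db*mu - ds)*(x - mu)*rsig^3 and v = db*rsig as above. A reference sketch (illustrative names, not Dragon API):

#include <algorithm>
#include <vector>

// GroupNorm backward, NCHW: x, dy, dx have N*G*D*S elements;
// mu, rsig, ds, db have N*G elements; gamma, dgamma, dbeta have G*D elements.
void GroupNormGradRef(int N, int G, int D, int S,
                      const std::vector<float>& x,
                      const std::vector<float>& mu,
                      const std::vector<float>& rsig,
                      const std::vector<float>& gamma,
                      const std::vector<float>& dy,
                      std::vector<float>& ds, std::vector<float>& db,
                      std::vector<float>& dgamma, std::vector<float>& dbeta,
                      std::vector<float>& dx) {
  const float denom = 1.f / static_cast<float>(D * S);
  std::fill(ds.begin(), ds.end(), 0.f);
  std::fill(db.begin(), db.end(), 0.f);
  std::fill(dgamma.begin(), dgamma.end(), 0.f);
  std::fill(dbeta.begin(), dbeta.end(), 0.f);
  for (int n = 0; n < N; ++n) {
    for (int g = 0; g < G; ++g) {
      const int mi = n * G + g;
      for (int d = 0; d < D; ++d) {      // first pass: internal sums
        const int gi = g * D + d;
        for (int s = 0; s < S; ++s) {
          const int i = ((n * G + g) * D + d) * S + s;
          ds[mi] += gamma[gi] * dy[i] * x[i];
          db[mi] += gamma[gi] * dy[i];
        }
      }
      for (int d = 0; d < D; ++d) {      // second pass: dx and param grads
        const int gi = g * D + d;
        for (int s = 0; s < S; ++s) {
          const int i = ((n * G + g) * D + d) * S + s;
          const float r = rsig[mi];
          const float u =
              (db[mi] * mu[mi] - ds[mi]) * (x[i] - mu[mi]) * r * r * r;
          const float v = db[mi] * r;
          dx[i] = gamma[gi] * dy[i] * r + (u - v) * denom;
          dgamma[gi] += dy[i] * (x[i] - mu[mi]) * r;
          dbeta[gi] += dy[i];
        }
      }
    }
  }
}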
......@@ -126,7 +126,7 @@ void _GroupNormGrad(
/* ------------------- Launcher Separator ------------------- */
template <>
void GroupNormForward<float16, float, CPUContext>(
void GroupNorm<float16, float, CPUContext>(
const int N,
const int G,
const int D,
......@@ -145,7 +145,7 @@ void GroupNormForward<float16, float, CPUContext>(
}
template <>
void GroupNormBackward<float16, float, CPUContext>(
void GroupNormGrad<float16, float, CPUContext>(
const int N,
const int G,
const int D,
......@@ -158,78 +158,77 @@ void GroupNormBackward<float16, float, CPUContext>(
const float16* dy,
float* ds,
float* db,
float16* dx,
float* dgamma,
float* dbeta,
float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
} // GroupNormGrad
#define DEFINE_KERNEL_LAUNCHER(Tx, Tp) \
#define DEFINE_KERNEL_LAUNCHER(T, AccT) \
template <> \
void GroupNormForward<Tx, Tp, CPUContext>( \
void GroupNorm<T, AccT, CPUContext>( \
const int N, \
const int G, \
const int D, \
const int S, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tp* beta, \
Tp* scale, \
Tp* bias, \
Tx* y, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const AccT* beta, \
AccT* scale, \
AccT* bias, \
T* y, \
CPUContext* ctx) { \
const int C = G * D; \
_GroupNormFusedParams<Tp>(N, G, D, mu, rsig, gamma, beta, scale, bias); \
_GroupNormFusedParams(N, G, D, mu, rsig, gamma, beta, scale, bias); \
if (data_format == "NCHW") { \
_GroupNormForwardNCHW<Tx, Tp>(N, C, S, x, scale, bias, y); \
_GroupNormNCHW(N, C, S, x, scale, bias, y); \
} else if (data_format == "NHWC") { \
_GroupNormForwardNHWC<Tx, Tp>(N, C, S, x, scale, bias, y); \
_GroupNormNHWC(N, C, S, x, scale, bias, y); \
} \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(Tx, Tp) \
#define DEFINE_GRAD_KERNEL_LAUNCHER(T, AccT) \
template <> \
void GroupNormBackward<Tx, Tp, CPUContext>( \
void GroupNormGrad<T, AccT, CPUContext>( \
const int N, \
const int G, \
const int D, \
const int S, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tx* dy, \
Tp* ds, \
Tp* db, \
Tx* dx, \
Tp* dgamma, \
Tp* dbeta, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const T* dy, \
AccT* ds, \
AccT* db, \
AccT* dgamma, \
AccT* dbeta, \
T* dx, \
CPUContext* ctx) { \
math::Set(N* G, Tp(0), ds, ctx); \
math::Set(N* G, Tp(0), db, ctx); \
math::Set(G* D, Tp(0), dgamma, ctx); \
math::Set(G* D, Tp(0), dbeta, ctx); \
math::Set(N* G, AccT(0), ds, ctx); \
math::Set(N* G, AccT(0), db, ctx); \
math::Set(G* D, AccT(0), dgamma, ctx); \
math::Set(G* D, AccT(0), dbeta, ctx); \
if (data_format == "NCHW") { \
_GroupNormInternalGrad<Tx, Tp, StorageOrder::NCHW>( \
_GroupNormInternalGrad<T, AccT, StorageOrder::NCHW>( \
{N, G, D, S}, x, gamma, dy, ds, db); \
_GroupNormGrad<Tx, Tp, StorageOrder::NCHW>( \
{N, G, D, S}, x, mu, rsig, gamma, ds, db, dy, dx, dgamma, dbeta); \
_GroupNormGrad<T, AccT, StorageOrder::NCHW>( \
{N, G, D, S}, x, mu, rsig, gamma, ds, db, dy, dgamma, dbeta, dx); \
} else if (data_format == "NHWC") { \
_GroupNormInternalGrad<Tx, Tp, StorageOrder::NHWC>( \
_GroupNormInternalGrad<T, AccT, StorageOrder::NHWC>( \
{N, S, G, D}, x, gamma, dy, ds, db); \
_GroupNormGrad<Tx, Tp, StorageOrder::NHWC>( \
{N, S, G, D}, x, mu, rsig, gamma, ds, db, dy, dx, dgamma, dbeta); \
_GroupNormGrad<T, AccT, StorageOrder::NHWC>( \
{N, S, G, D}, x, mu, rsig, gamma, ds, db, dy, dgamma, dbeta, dx); \
} \
}
DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
......
......@@ -10,11 +10,11 @@ namespace dragon {
namespace kernel {
#if __CUDA_ARCH__ >= 350
#define LOAD(x, i) __ldg(x + i)
#define LOADF(x, i) __half2float(__ldg(x + i))
#define LDG(x, i) __ldg(x + i)
#define LDG2(x, i) convert::To<AccT>(__ldg(x + i))
#else
#define LOAD(x, i) x[i]
#define LOADF(x, i) __half2float(x[i])
#define LDG(x, i) x[i]
#define LDG2(x, i) convert::To<AccT>(x[i])
#endif
namespace {
......@@ -33,127 +33,89 @@ __global__ void _GroupNormFusedParams(
const int outer_dim = N * G;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const int g = i % G;
const T mu_val = LOAD(mu, i);
const T rsig_val = LOAD(rsig, i);
const T mu_val = LDG(mu, i);
const T rsig_val = LDG(rsig, i);
CUDA_2D_KERNEL_LOOP2(j, D) {
const int wi = i * D + j;
const int gi = g * D + j;
const T w = LOAD(gamma, gi) * rsig_val;
const T w = LDG(gamma, gi) * rsig_val;
scale[wi] = w;
bias[wi] = fma(-w, mu_val, LOAD(beta, gi));
bias[wi] = fma(-w, mu_val, LDG(beta, gi));
}
}
}
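Unlike batch norm, the folded parameters here are per (instance, group-channel) because mu and rsig are computed per (N, G): the kernel writes scale[i*D+j] = gamma[g*D+j] * rsig[i] and bias[i*D+j] = beta[g*D+j] - scale[i*D+j] * mu[i] (the fma above). A host-side sketch of that fold:

#include <vector>

// Fold GroupNorm statistics into per-(instance, channel) scale/bias.
// mu, rsig: N*G elements; gamma, beta: G*D; scale, bias: N*G*D (= N*C).
void GroupNormFusedParamsRef(int N, int G, int D,
                             const std::vector<float>& mu,
                             const std::vector<float>& rsig,
                             const std::vector<float>& gamma,
                             const std::vector<float>& beta,
                             std::vector<float>& scale,
                             std::vector<float>& bias) {
  for (int i = 0; i < N * G; ++i) {
    const int g = i % G;
    for (int j = 0; j < D; ++j) {
      const int wi = i * D + j;  // per-instance channel index
      const int gi = g * D + j;  // shared parameter index
      scale[wi] = gamma[gi] * rsig[i];
      bias[wi] = beta[gi] - scale[wi] * mu[i];
    }
  }
}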
template <typename Tx, typename Tp>
__global__ void _GroupNormForwardNCHW(
template <typename T, typename AccT>
__global__ void _GroupNormAffineNCHW(
const int N,
const int C,
const int S,
const Tx* x,
const Tp* scale,
const Tp* bias,
Tx* y) {
const T* x,
const AccT* scale,
const AccT* bias,
T* y) {
const int outer_dim = N * C;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const Tp w = LOAD(scale, i);
const Tp b = LOAD(bias, i);
const AccT w = LDG(scale, i);
const AccT b = LDG(bias, i);
CUDA_2D_KERNEL_LOOP2(j, S) {
const int xi = i * S + j;
y[xi] = fma(LOAD(x, xi), w, b);
      y[xi] = convert::To<T>(fma(LDG2(x, xi), w, b));
}
}
}
template <>
__global__ void _GroupNormForwardNCHW<half, float>(
template <typename T, typename AccT>
__global__ void _GroupNormAffineNHWC(
const int N,
const int C,
const int S,
const half* x,
const float* scale,
const float* bias,
half* y) {
const int outer_dim = N * C;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const float w = LOAD(scale, i);
const float b = LOAD(bias, i);
CUDA_2D_KERNEL_LOOP2(j, S) {
const int xi = i * S + j;
y[xi] = __float2half(fmaf(LOADF(x, xi), w, b));
}
}
}
template <typename Tx, typename Tp>
__global__ void _GroupNormForwardNHWC(
const int N,
const int C,
const int S,
const Tx* x,
const Tp* scale,
const Tp* bias,
Tx* y) {
const int outer_dim = N * S;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const int n = i / S;
CUDA_2D_KERNEL_LOOP2(j, C) {
const int xi = i * C + j;
const int wi = n * C + j;
y[xi] = fma(LOAD(x, xi), LOAD(scale, wi), LOAD(bias, wi));
}
}
}
template <>
__global__ void _GroupNormForwardNHWC<half, float>(
const int N,
const int C,
const int S,
const half* x,
const float* scale,
const float* bias,
half* y) {
const T* x,
const AccT* scale,
const AccT* bias,
T* y) {
const int outer_dim = N * S;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const int n = i / S;
CUDA_2D_KERNEL_LOOP2(j, C) {
const int xi = i * C + j;
const int wi = n * C + j;
y[xi] = __float2half(fmaf(LOADF(x, xi), LOAD(scale, wi), LOAD(bias, wi)));
y[xi] = convert::To<T>(fma(LDG2(x, xi), LDG(scale, wi), LDG(bias, wi)));
}
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _GroupNormWGrad(
const int N,
const int G,
const int D,
const int S,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tx* dy,
Tp* dgamma,
Tp* dbeta) {
const T* x,
const AccT* mu,
const AccT* rsig,
const T* dy,
AccT* dgamma,
AccT* dbeta) {
const int outer_dim = G * D;
const int inner_dim = N * S;
__shared__ typename BlockReduce<Tp>::TempStorage dg_storage;
__shared__ typename BlockReduce<Tp>::TempStorage db_storage;
__shared__ union {
typename BlockReduce<AccT>::TempStorage dg;
typename BlockReduce<AccT>::TempStorage db;
} storage;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
Tp dg_val = Tp(0), db_val = Tp(0);
AccT dg_val = AccT(0), db_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, inner_dim) {
const int n = j / S;
const int xi = kOrder == StorageOrder::NCHW
? (n * outer_dim + i) * S + j % S
: j * outer_dim + i;
const int mi = n * G + i / D;
dg_val += LOAD(dy, xi) * (LOAD(x, xi) - LOAD(mu, mi)) * LOAD(rsig, mi);
db_val += LOAD(dy, xi);
dg_val += LDG2(dy, xi) * (LDG2(x, xi) - LDG(mu, mi)) * LDG(rsig, mi);
db_val += LDG2(dy, xi);
}
dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
dg_val = BlockReduce<AccT>(storage.dg).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<AccT>(storage.db).Reduce(db_val, cub::Sum());
if (threadIdx.x == 0) {
dgamma[i] = dg_val;
dbeta[i] = db_val;
......@@ -161,69 +123,35 @@ __global__ void _GroupNormWGrad(
}
}
template <StorageOrder kOrder>
__global__ void _GroupNormWGradHalf(
const int N,
const int G,
const int D,
const int S,
const half* x,
const float* mu,
const float* rsig,
const half* dy,
float* dgamma,
float* dbeta) {
const int outer_dim = G * D;
const int inner_dim = N * S;
__shared__ typename BlockReduce<float>::TempStorage dg_storage;
__shared__ typename BlockReduce<float>::TempStorage db_storage;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
float dg_val = 0.f, db_val = 0.f;
CUDA_2D_KERNEL_LOOP2(j, inner_dim) {
const int n = j / S;
const int xi = kOrder == StorageOrder::NCHW
? (n * outer_dim + i) * S + j % S
: j * outer_dim + i;
const int mi = n * G + i / D;
dg_val += LOADF(dy, xi) * (LOADF(x, xi) - LOAD(mu, mi)) * LOAD(rsig, mi);
db_val += LOADF(dy, xi);
}
dg_val = BlockReduce<float>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<float>(db_storage).Reduce(db_val, cub::Sum());
if (threadIdx.x == 0) {
dgamma[i] = dg_val;
dbeta[i] = db_val;
}
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _GroupNormInternalGrad(
const int N,
const int G,
const int D,
const int S,
const Tx* x,
const Tp* gamma,
const Tx* dy,
Tp* ds,
Tp* db) {
const T* x,
const AccT* gamma,
const T* dy,
AccT* ds,
AccT* db) {
const int outer_dim = N * G;
const int inner_dim = D * S;
__shared__ typename BlockReduce<Tp>::TempStorage ds_storage;
__shared__ typename BlockReduce<Tp>::TempStorage db_storage;
__shared__ union {
typename BlockReduce<AccT>::TempStorage ds;
typename BlockReduce<AccT>::TempStorage db;
} storage;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
Tp ds_val = Tp(0), db_val = Tp(0);
AccT ds_val = AccT(0), db_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, inner_dim) {
const int gi = i % G * D + j / S;
const int xi = kOrder == StorageOrder::NCHW
? i * inner_dim + j
: (i / G * S + j % S) * G * D + gi;
ds_val += LOAD(gamma, gi) * LOAD(dy, xi) * LOAD(x, xi);
db_val += LOAD(gamma, gi) * LOAD(dy, xi);
ds_val += LDG(gamma, gi) * LDG2(dy, xi) * LDG2(x, xi);
db_val += LDG(gamma, gi) * LDG2(dy, xi);
}
ds_val = BlockReduce<Tp>(ds_storage).Reduce(ds_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
ds_val = BlockReduce<AccT>(storage.ds).Reduce(ds_val, cub::Sum());
db_val = BlockReduce<AccT>(storage.db).Reduce(db_val, cub::Sum());
if (threadIdx.x == 0) {
ds[i] = ds_val;
db[i] = db_val;
......@@ -231,323 +159,182 @@ __global__ void _GroupNormInternalGrad(
}
}
template <StorageOrder kOrder>
__global__ void _GroupNormInternalGradHalf(
const int N,
const int G,
const int D,
const int S,
const half* x,
const float* gamma,
const half* dy,
float* ds,
float* db) {
const int outer_dim = N * G;
const int inner_dim = D * S;
__shared__ typename BlockReduce<float>::TempStorage ds_storage;
__shared__ typename BlockReduce<float>::TempStorage db_storage;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
float ds_val = 0.f, db_val = 0.f;
CUDA_2D_KERNEL_LOOP2(j, inner_dim) {
const int gi = i % G * D + j / S;
const int xi = kOrder == StorageOrder::NCHW
? i * inner_dim + j
: (i / G * S + j % S) * G * D + gi;
ds_val += LOAD(gamma, gi) * LOADF(dy, xi) * LOADF(x, xi);
db_val += LOAD(gamma, gi) * LOADF(dy, xi);
}
ds_val = BlockReduce<float>(ds_storage).Reduce(ds_val, cub::Sum());
db_val = BlockReduce<float>(db_storage).Reduce(db_val, cub::Sum());
if (threadIdx.x == 0) {
ds[i] = ds_val;
db[i] = db_val;
}
}
}
template <typename Tx, typename Tp, StorageOrder kOrder>
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _GroupNormGrad(
const int nthreads,
const int G,
const int D,
const int S,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tp* ds,
const Tp* db,
const Tx* dy,
Tx* dx) {
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const AccT* ds,
const AccT* db,
const T* dy,
T* dx) {
const int C = G * D;
const Tp denom = Tp(1) / Tp(D * S);
const AccT denom = AccT(1) / AccT(D * S);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int mi = kOrder == StorageOrder::NCHW ? i / (D * S)
: i / (C * S) * G + (i / D % G);
const int gi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
const Tp u = fma(LOAD(db, mi), LOAD(mu, mi), -LOAD(ds, mi)) *
(LOAD(x, i) - LOAD(mu, mi)) * utils::math::Cube(LOAD(rsig, mi));
const Tp v = LOAD(db, mi) * LOAD(rsig, mi);
dx[i] = LOAD(gamma, gi) * LOAD(dy, i) * LOAD(rsig, mi) + (u - v) * denom;
const AccT u = fma(LDG(db, mi), LDG(mu, mi), -LDG(ds, mi)) *
(LDG2(x, i) - LDG(mu, mi)) * math::utils::Cube(LDG(rsig, mi));
const AccT v = LDG(db, mi) * LDG(rsig, mi);
dx[i] = convert::To<T>(
LDG(gamma, gi) * LDG2(dy, i) * LDG(rsig, mi) + (u - v) * denom);
}
}
template <StorageOrder kOrder>
__global__ void _GroupNormGradHalf(
const int nthreads,
const int G,
const int D,
const int S,
const half* x,
const float* mu,
const float* rsig,
const float* gamma,
const float* ds,
const float* db,
const half* dy,
half* dx) {
const int C = G * D;
const float denom = 1.f / float(D * S);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int mi = kOrder == StorageOrder::NCHW ? i / (D * S)
: i / (C * S) * G + (i / D % G);
const int gi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
const float u = fmaf(LOAD(db, mi), LOAD(mu, mi), -LOAD(ds, mi)) *
(LOADF(x, i) - LOAD(mu, mi)) * utils::math::Cube(LOAD(rsig, mi));
const float v = LOAD(db, mi) * LOAD(rsig, mi);
dx[i] = __float2half(
LOAD(gamma, gi) * LOADF(dy, i) * LOAD(rsig, mi) + (u - v) * denom);
}
}
#undef LOAD
#undef LOADF
#undef LDG
#undef LDG2
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void GroupNormForward<float16, float, CUDAContext>(
const int N,
const int G,
const int D,
const int S,
const string& data_format,
const float16* x,
const float* mu,
const float* rsig,
const float* gamma,
const float* beta,
float* scale,
float* bias,
float16* y,
CUDAContext* ctx) {
const int C = G * D;
_GroupNormFusedParams<float>
<<<CUDA_2D_BLOCKS(N * G), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, G, D, mu, rsig, gamma, beta, scale, bias);
if (data_format == "NCHW") {
_GroupNormForwardNCHW<half, float>
<<<CUDA_2D_BLOCKS(N * C), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N,
C,
S,
reinterpret_cast<const half*>(x),
scale,
bias,
reinterpret_cast<half*>(y));
} else if (data_format == "NHWC") {
_GroupNormForwardNHWC<half, float>
<<<CUDA_2D_BLOCKS(N * C), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N,
C,
S,
reinterpret_cast<const half*>(x),
scale,
bias,
reinterpret_cast<half*>(y));
}
}
template <>
void GroupNormBackward<float16, float, CUDAContext>(
const int N,
const int G,
const int D,
const int S,
const string& data_format,
const float16* x,
const float* mu,
const float* rsig,
const float* gamma,
const float16* dy,
float* ds,
float* db,
float16* dx,
float* dgamma,
float* dbeta,
CUDAContext* ctx) {
auto nthreads = N * G * D * S;
if (data_format == "NCHW") {
_GroupNormWGradHalf<StorageOrder::NCHW>
<<<CUDA_2D_BLOCKS(G * D), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N,
G,
D,
S,
reinterpret_cast<const half*>(x),
mu,
rsig,
reinterpret_cast<const half*>(dy),
dgamma,
dbeta);
_GroupNormInternalGradHalf<StorageOrder::NCHW>
<<<CUDA_2D_BLOCKS(N * G), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N,
G,
D,
S,
reinterpret_cast<const half*>(x),
gamma,
reinterpret_cast<const half*>(dy),
ds,
db);
_GroupNormGradHalf<StorageOrder::NCHW>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
G,
D,
S,
reinterpret_cast<const half*>(x),
mu,
rsig,
gamma,
ds,
db,
reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx));
} else if (data_format == "NHWC") {
_GroupNormWGradHalf<StorageOrder::NHWC>
<<<CUDA_2D_BLOCKS(G * D), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N,
G,
D,
S,
reinterpret_cast<const half*>(x),
mu,
rsig,
reinterpret_cast<const half*>(dy),
dgamma,
dbeta);
_GroupNormInternalGradHalf<StorageOrder::NHWC>
<<<CUDA_2D_BLOCKS(N * G), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N,
G,
D,
S,
reinterpret_cast<const half*>(x),
gamma,
reinterpret_cast<const half*>(dy),
ds,
db);
_GroupNormGradHalf<StorageOrder::NHWC>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
G,
D,
S,
reinterpret_cast<const half*>(x),
mu,
rsig,
gamma,
ds,
db,
reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx));
#define DISPATCH_GROUPNORM_KERNEL(name, T, AccT, nblocks, nthreads, ...) \
if (data_format == "NCHW") { \
name<T, AccT, StorageOrder::NCHW> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else if (data_format == "NHWC") { \
name<T, AccT, StorageOrder::NHWC> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else { \
LOG(FATAL) << "Unknown DataFormat: " << data_format; \
}
} // GroupNormBackward
#define DEFINE_KERNEL_LAUNCHER(Tx, Tp) \
#define DEFINE_KERNEL_LAUNCHER(T, ScalarT, AccT) \
template <> \
void GroupNormForward<Tx, Tp, CUDAContext>( \
void GroupNorm<T, AccT, CUDAContext>( \
const int N, \
const int G, \
const int D, \
const int S, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tp* beta, \
Tp* scale, \
Tp* bias, \
Tx* y, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const AccT* beta, \
AccT* scale, \
AccT* bias, \
T* y, \
CUDAContext* ctx) { \
const int C = G * D; \
_GroupNormFusedParams<Tp> \
<<<CUDA_2D_BLOCKS(N* G), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, G, D, mu, rsig, gamma, beta, scale, bias); \
_GroupNormFusedParams<<< \
CUDA_2D_BLOCKS(N* G), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(N, G, D, mu, rsig, gamma, beta, scale, bias); \
if (data_format == "NCHW") { \
_GroupNormForwardNCHW<Tx, Tp> \
<<<CUDA_2D_BLOCKS(N* C), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, C, S, x, scale, bias, y); \
_GroupNormAffineNCHW<<< \
CUDA_2D_BLOCKS(N* C), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
N, \
C, \
S, \
reinterpret_cast<const ScalarT*>(x), \
scale, \
bias, \
reinterpret_cast<ScalarT*>(y)); \
} else if (data_format == "NHWC") { \
_GroupNormForwardNHWC<Tx, Tp> \
<<<CUDA_2D_BLOCKS(N* C), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, C, S, x, scale, bias, y); \
_GroupNormAffineNHWC<<< \
CUDA_2D_BLOCKS(N* C), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
N, \
C, \
S, \
reinterpret_cast<const ScalarT*>(x), \
scale, \
bias, \
reinterpret_cast<ScalarT*>(y)); \
} \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(Tx, Tp) \
#define DEFINE_GRAD_KERNEL_LAUNCHER(T, ScalarT, AccT) \
template <> \
void GroupNormBackward<Tx, Tp, CUDAContext>( \
void GroupNormGrad<T, AccT, CUDAContext>( \
const int N, \
const int G, \
const int D, \
const int S, \
const string& data_format, \
const Tx* x, \
const Tp* mu, \
const Tp* rsig, \
const Tp* gamma, \
const Tx* dy, \
Tp* ds, \
Tp* db, \
Tx* dx, \
Tp* dgamma, \
Tp* dbeta, \
const T* x, \
const AccT* mu, \
const AccT* rsig, \
const AccT* gamma, \
const T* dy, \
AccT* ds, \
AccT* db, \
AccT* dgamma, \
AccT* dbeta, \
T* dx, \
CUDAContext* ctx) { \
auto nthreads = N * G * D * S; \
if (data_format == "NCHW") { \
_GroupNormWGrad<Tx, Tp, StorageOrder::NCHW> \
<<<CUDA_2D_BLOCKS(G* D), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, G, D, S, x, mu, rsig, dy, dgamma, dbeta); \
_GroupNormInternalGrad<Tx, Tp, StorageOrder::NCHW> \
<<<CUDA_2D_BLOCKS(N* G), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, G, D, S, x, gamma, dy, ds, db); \
_GroupNormGrad<Tx, Tp, StorageOrder::NCHW> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, G, D, S, x, mu, rsig, gamma, ds, db, dy, dx); \
} else if (data_format == "NHWC") { \
_GroupNormWGrad<Tx, Tp, StorageOrder::NHWC> \
<<<CUDA_2D_BLOCKS(G* D), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, G, D, S, x, mu, rsig, dy, dgamma, dbeta); \
_GroupNormInternalGrad<Tx, Tp, StorageOrder::NHWC> \
<<<CUDA_2D_BLOCKS(N* G), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, G, D, S, x, gamma, dy, ds, db); \
_GroupNormGrad<Tx, Tp, StorageOrder::NHWC> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, G, D, S, x, mu, rsig, gamma, ds, db, dy, dx); \
} \
DISPATCH_GROUPNORM_KERNEL( \
_GroupNormWGrad, \
ScalarT, \
AccT, \
CUDA_2D_BLOCKS(G* D), \
CUDA_THREADS, \
N, \
G, \
D, \
S, \
reinterpret_cast<const ScalarT*>(x), \
mu, \
rsig, \
reinterpret_cast<const ScalarT*>(dy), \
dgamma, \
dbeta); \
DISPATCH_GROUPNORM_KERNEL( \
_GroupNormInternalGrad, \
ScalarT, \
AccT, \
CUDA_2D_BLOCKS(N* G), \
CUDA_THREADS, \
N, \
G, \
D, \
S, \
reinterpret_cast<const ScalarT*>(x), \
gamma, \
reinterpret_cast<const ScalarT*>(dy), \
ds, \
db); \
DISPATCH_GROUPNORM_KERNEL( \
_GroupNormGrad, \
ScalarT, \
AccT, \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
nthreads, \
G, \
D, \
S, \
reinterpret_cast<const ScalarT*>(x), \
mu, \
rsig, \
gamma, \
ds, \
db, \
reinterpret_cast<const ScalarT*>(dy), \
reinterpret_cast<ScalarT*>(dx)); \
}
DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(float16, half, float);
DEFINE_KERNEL_LAUNCHER(float, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float16, half, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float, float);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
#undef DISPATCH_GROUPNORM_KERNEL
} // namespace kernel
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
......@@ -11,142 +11,86 @@ namespace kernel {
namespace {
template <typename T>
template <typename T, typename AccT>
__global__ void _L1Normalize(
const int nblocks,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const AccT scale,
const AccT eps,
const T* x,
T* y) {
__shared__ T norm;
__shared__ typename BlockReduce<T>::TempStorage storage;
__shared__ AccT norm;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, nblocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
T sum = T(0);
AccT sum = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
sum += abs(x[offset + j * inner_dim]);
sum += abs(convert::To<AccT>(x[offset + j * inner_dim]));
}
sum = BlockReduce<T>(storage).Sum(sum);
sum = BlockReduce<AccT>(storage).Sum(sum);
if (threadIdx.x == 0) {
norm = max(sum * scale, eps);
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
y[idx] = x[idx] / norm;
y[idx] = convert::To<T>(convert::To<AccT>(x[idx]) / norm);
}
}
}
__global__ void _L1Normalize(
const int nblocks,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const half* x,
half* y) {
__shared__ float norm;
__shared__ typename BlockReduce<float>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, nblocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
float sum = 0.f;
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
sum += abs(__half2float(x[offset + j * inner_dim]));
}
sum = BlockReduce<float>(storage).Sum(sum);
if (threadIdx.x == 0) {
norm = max(sum * scale, eps);
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
y[idx] = __float2half(__half2float(x[idx]) / norm);
}
}
}
template <typename T>
template <typename T, typename AccT>
__global__ void _L2Normalize(
const int nblocks,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const AccT scale,
const AccT eps,
const T* x,
T* y) {
__shared__ T norm;
__shared__ typename BlockReduce<T>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, nblocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
T sum = T(0);
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
sum += utils::math::Square(x[offset + j * inner_dim]);
}
sum = BlockReduce<T>(storage).Sum(sum);
if (threadIdx.x == 0) {
norm = max(sqrt(sum * scale), eps);
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
y[idx] = x[idx] / norm;
}
}
}
__global__ void _L2Normalize(
const int nblocks,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const half* x,
half* y) {
__shared__ float norm;
__shared__ typename BlockReduce<float>::TempStorage storage;
__shared__ AccT norm;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, nblocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
float sum = 0.f;
AccT sum = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
sum += utils::math::Square(__half2float(x[offset + j * inner_dim]));
sum += math::utils::Square(convert::To<AccT>(x[offset + j * inner_dim]));
}
sum = BlockReduce<float>(storage).Sum(sum);
sum = BlockReduce<AccT>(storage).Sum(sum);
if (threadIdx.x == 0) {
norm = max(sqrt(sum * scale), eps);
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
y[idx] = __float2half(__half2float(x[idx]) / norm);
y[idx] = convert::To<T>(convert::To<AccT>(x[idx]) / norm);
}
}
}
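Both normalize kernels reduce one slice of reduce_dim elements (strided by inner_dim), clamp the resulting norm with eps, and divide; only the norm differs (scale * sum|x| for L1, sqrt(scale * sum x^2) for L2), and the FP16 path now widens each element with convert::To<AccT> instead of a separate half specialization. A contiguous-slice FP32 sketch of the L2 case:

#include <algorithm>
#include <cmath>
#include <vector>

// L2-normalize a contiguous slice: y = x / max(sqrt(scale * sum(x^2)), eps).
void L2NormalizeRef(const std::vector<float>& x, float scale, float eps,
                    std::vector<float>& y) {
  float sum = 0.f;
  for (float v : x) sum += v * v;
  const float norm = std::max(std::sqrt(sum * scale), eps);
  y.resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] / norm;
}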
template <typename T>
template <typename T, typename AccT>
__global__ void _L1NormalizeGrad(
const int nblocks,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const AccT scale,
const AccT eps,
const T* dy,
const T* x,
T* dx) {
__shared__ T norm, norm2, sum;
__shared__ typename BlockReduce<T>::TempStorage storage;
__shared__ AccT norm, norm2, sum;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, nblocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
T val1 = T(0), val2 = T(0);
AccT val1 = AccT(0), val2 = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
val1 += abs(x[idx]);
val2 += dy[idx] * x[idx];
val1 += abs(convert::To<AccT>(x[idx]));
val2 += convert::To<AccT>(dy[idx]) * convert::To<AccT>(x[idx]);
}
val1 = BlockReduce<T>(storage).Sum(val1);
val2 = BlockReduce<T>(storage).Sum(val2);
val1 = BlockReduce<AccT>(storage).Sum(val1);
val2 = BlockReduce<AccT>(storage).Sum(val2);
if (threadIdx.x == 0) {
norm = max(val1 * scale, eps);
norm2 = pow(norm, 2);
......@@ -155,103 +99,35 @@ __global__ void _L1NormalizeGrad(
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
dx[idx] = (dy[idx] / norm) - ((utils::math::Sign(x[idx]) / norm2) * sum);
dx[idx] = convert::To<T>(
(convert::To<AccT>(dy[idx]) / norm) -
((math::utils::Sign(convert::To<AccT>(x[idx])) / norm2) * sum));
}
}
}
__global__ void _L1NormalizeGrad(
const int nblocks,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const half* dy,
const half* x,
half* dx) {
__shared__ float norm, norm2, sum;
__shared__ typename BlockReduce<float>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, nblocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
float val1 = 0.f, val2 = 0.f;
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
val1 += abs(__half2float(x[idx]));
val2 += __half2float(dy[idx]) * __half2float(x[idx]);
}
val1 = BlockReduce<float>(storage).Sum(val1);
val2 = BlockReduce<float>(storage).Sum(val2);
if (threadIdx.x == 0) {
norm = max(val1 * scale, eps);
norm2 = pow(norm, 2);
sum = val2 * scale;
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
dx[idx] = __float2half(
(__half2float(dy[idx]) / norm) -
((utils::math::Sign(__half2float(x[idx])) / norm2) * sum));
}
}
}
template <typename T>
template <typename T, typename AccT>
__global__ void _L2NormalizeGrad(
const int nblocks,
const int inner_dim,
const int reduce_dim,
const T scale,
const T eps,
const AccT scale,
const AccT eps,
const T* dy,
const T* x,
T* dx) {
__shared__ T norm, norm3, sum;
__shared__ typename BlockReduce<T>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, nblocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
T val1 = T(0), val2 = T(0);
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
val1 += utils::math::Square(x[idx]);
val2 += dy[idx] * x[idx];
}
val1 = BlockReduce<T>(storage).Sum(val1);
val2 = BlockReduce<T>(storage).Sum(val2);
if (threadIdx.x == 0) {
norm = max(sqrt(val1 * scale), eps);
norm3 = pow(norm, 3);
sum = val2 * scale;
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
dx[idx] = (dy[idx] / norm) - ((x[idx] / norm3) * sum);
}
}
}
__global__ void _L2NormalizeGrad(
const int nblocks,
const int inner_dim,
const int reduce_dim,
const float scale,
const float eps,
const half* dy,
const half* x,
half* dx) {
__shared__ float norm, norm3, sum;
__shared__ typename BlockReduce<float>::TempStorage storage;
__shared__ AccT norm, norm3, sum;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, nblocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
float val1 = 0.f, val2 = 0.f;
AccT val1 = AccT(0), val2 = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
val1 += utils::math::Square(__half2float(x[idx]));
val2 += __half2float(dy[idx]) * __half2float(x[idx]);
val1 += math::utils::Square(convert::To<AccT>(x[idx]));
val2 += convert::To<AccT>(dy[idx]) * convert::To<AccT>(x[idx]);
}
val1 = BlockReduce<float>(storage).Sum(val1);
val2 = BlockReduce<float>(storage).Sum(val2);
val1 = BlockReduce<AccT>(storage).Sum(val1);
val2 = BlockReduce<AccT>(storage).Sum(val2);
if (threadIdx.x == 0) {
norm = max(sqrt(val1 * scale), eps);
norm3 = pow(norm, 3);
......@@ -260,9 +136,9 @@ __global__ void _L2NormalizeGrad(
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
dx[idx] = __float2half(
(__half2float(dy[idx]) / norm) -
((__half2float(x[idx]) / norm3) * sum));
dx[idx] = convert::To<T>(
(convert::To<AccT>(dy[idx]) / norm) -
((convert::To<AccT>(x[idx]) / norm3) * sum));
}
}
}
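The gradient kernel needs two block reductions per slice (sum of x^2 and sum of dy*x) before the element-wise pass; the formula it applies is dx = dy/norm - (x/norm^3) * (scale * sum(dy*x)) with norm = max(sqrt(scale * sum(x^2)), eps). A contiguous-slice FP32 reference:

#include <algorithm>
#include <cmath>
#include <vector>

// Gradient of L2 normalization over a contiguous slice.
void L2NormalizeGradRef(const std::vector<float>& dy,
                        const std::vector<float>& x,
                        float scale, float eps,
                        std::vector<float>& dx) {
  float val1 = 0.f, val2 = 0.f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    val1 += x[i] * x[i];   // sum(x^2)
    val2 += dy[i] * x[i];  // sum(dy * x)
  }
  const float norm = std::max(std::sqrt(val1 * scale), eps);
  const float norm3 = norm * norm * norm;
  const float sum = val2 * scale;
  dx.resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    dx[i] = dy[i] / norm - (x[i] / norm3) * sum;
  }
}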
......@@ -271,7 +147,7 @@ __global__ void _L2NormalizeGrad(
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, T) \
#define DEFINE_KERNEL_LAUNCHER(name, T, ScalarT, AccT) \
template <> \
void name<T, CUDAContext>( \
const int outer_dim, \
......@@ -279,25 +155,30 @@ __global__ void _L2NormalizeGrad(
const int reduce_dim, \
const float scale, \
const float eps, \
const float16* x, \
float16* y, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
_##name<ScalarT, AccT> \
<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, \
inner_dim, \
reduce_dim, \
scale, \
eps, \
reinterpret_cast<const half*>(x), \
reinterpret_cast<half*>(y)); \
reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<ScalarT*>(y)); \
}
DEFINE_KERNEL_LAUNCHER(L1Normalize, float16);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float16);
DEFINE_KERNEL_LAUNCHER(L1Normalize, float16, half, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, float, float, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, double, double, double);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float16, half, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float, float, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, double, double, double);
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T) \
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T, ScalarT, AccT) \
template <> \
void name<T, CUDAContext>( \
const int outer_dim, \
......@@ -305,69 +186,29 @@ DEFINE_KERNEL_LAUNCHER(L2Normalize, float16);
const int reduce_dim, \
const float scale, \
const float eps, \
const T* dy, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, inner_dim, reduce_dim, (T)scale, (T)eps, x, y); \
}
DEFINE_KERNEL_LAUNCHER(L1Normalize, float);
DEFINE_KERNEL_LAUNCHER(L1Normalize, double);
DEFINE_KERNEL_LAUNCHER(L2Normalize, float);
DEFINE_KERNEL_LAUNCHER(L2Normalize, double);
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int reduce_dim, \
const float scale, \
const float eps, \
const float16* dy, \
const float16* x, \
T* dx, \
CUDAContext* ctx) { \
const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
_##name<ScalarT, AccT> \
<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, \
inner_dim, \
reduce_dim, \
scale, \
eps, \
reinterpret_cast<const half*>(dy), \
reinterpret_cast<const half*>(x), \
reinterpret_cast<half*>(dx)); \
}
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16);
#undef DEFINE_GRAD_KERNEL_LAUNCHER
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int reduce_dim, \
const float scale, \
const float eps, \
const T* dy, \
const T* x, \
T* dx, \
CUDAContext* ctx) { \
const auto nblocks = outer_dim * inner_dim; \
_##name<<<CUDA_2D_BLOCKS(nblocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nblocks, inner_dim, reduce_dim, (T)scale, (T)eps, dy, x, dx); \
reinterpret_cast<const ScalarT*>(dy), \
reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<ScalarT*>(dx)); \
}
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float16, half, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, float, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L1NormalizeGrad, double, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float16, half, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, float, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(L2NormalizeGrad, double, double, double);
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
......
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/omp_utils.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -47,7 +47,7 @@ void _AvgPool2dNCHW(
for (int w = wstart; w < wend; ++w)
val += offset_x[h * W + w];
y[i] = val / area;
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -89,7 +89,7 @@ void _AvgPool2dNHWC(
for (int w = wstart; w < wend; ++w)
val += offset_x[(h * W + w) * C];
y[i] = val / area;
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -130,7 +130,7 @@ void _AvgPool2dGradNCHW(
for (int h = hstart; h < hend; ++h)
for (int w = wstart; w < wend; ++w)
offset_dx[h * W + w] += dy[i] / area;
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -170,7 +170,7 @@ void _AvgPool2dGradNHWC(
for (int h = hstart; h < hend; ++h)
for (int w = wstart; w < wend; ++w)
offset_dx[(h * W + w) * C] += dy[i] / area;
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -253,7 +253,7 @@ void _AvgPool2dGradNHWC(
const T* dy, \
T* dx, \
CPUContext* ctx) { \
math::Set(N* C* H* W, cast::to<T>(0.f), dx, ctx); \
math::Set(N* C* H* W, convert::To<T>(0.f), dx, ctx); \
if (data_format == "NCHW") { \
_AvgPool2dGradNCHW( \
N, \
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/cast.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......
......@@ -117,6 +117,39 @@ void _DepthwiseConv2dNHWC(
/* ------------------- Launcher Separator ------------------- */
#define DISPATCH_DATA_KERNEL(name, ...) \
if (data_format == "NCHW") { \
name##NCHW(__VA_ARGS__); \
} else if (data_format == "NHWC") { \
name##NHWC(__VA_ARGS__); \
} else { \
LOG(FATAL) << "Unknown DataFormat: " << data_format; \
}
template <>
void DepthwiseConv2d<float16, CPUContext>(
const int N,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const string& data_format,
const float16* x,
const float16* w,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
template <>
void DepthwiseConv2d<float, CPUContext>(
const int N,
......@@ -138,27 +171,8 @@ void DepthwiseConv2d<float, CPUContext>(
const float* w,
float* y,
CPUContext* ctx) {
if (data_format == "NCHW") {
_DepthwiseConv2dNCHW(
N,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
x,
w,
y);
} else {
_DepthwiseConv2dNHWC(
DISPATCH_DATA_KERNEL(
_DepthwiseConv2d,
N,
C,
H,
......@@ -176,56 +190,59 @@ void DepthwiseConv2d<float, CPUContext>(
x,
w,
y);
}
}
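For the float path dispatched above, a depthwise convolution applies one kernel_h x kernel_w filter per channel with no cross-channel accumulation; the NCHW reference below spells out the index math the dispatched kernels implement (input position ih = oh*stride_h - pad_h + kh*dilation_h, and likewise for iw):

#include <vector>

// Depthwise 2D convolution reference, NCHW layout, one filter per channel.
// x: N*C*H*W, w: C*kernel_h*kernel_w, y: N*C*out_h*out_w.
void DepthwiseConv2dRef(int N, int C, int H, int W, int out_h, int out_w,
                        int kernel_h, int kernel_w, int stride_h, int stride_w,
                        int pad_h, int pad_w, int dilation_h, int dilation_w,
                        const std::vector<float>& x,
                        const std::vector<float>& w,
                        std::vector<float>& y) {
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      for (int oh = 0; oh < out_h; ++oh) {
        for (int ow = 0; ow < out_w; ++ow) {
          float val = 0.f;
          for (int kh = 0; kh < kernel_h; ++kh) {
            for (int kw = 0; kw < kernel_w; ++kw) {
              const int ih = oh * stride_h - pad_h + kh * dilation_h;
              const int iw = ow * stride_w - pad_w + kw * dilation_w;
              if (ih >= 0 && ih < H && iw >= 0 && iw < W) {
                val += x[((n * C + c) * H + ih) * W + iw] *
                       w[(c * kernel_h + kh) * kernel_w + kw];
              }
            }
          }
          y[((n * C + c) * out_h + oh) * out_w + ow] = val;
        }
      }
    }
  }
}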
template <>
void DepthwiseConv2dGrad<float, CPUContext>(
const int N,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const string& data_format,
const float* dy,
const float* w,
float* dx,
CPUContext* ctx) {
NOT_IMPLEMENTED;
} // DepthwiseConv2dGrad
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void DepthwiseConv2dGrad<T, CPUContext>( \
const int N, \
const int C, \
const int H, \
const int W, \
const int out_h, \
const int out_w, \
const int kernel_h, \
const int kernel_w, \
const int stride_h, \
const int stride_w, \
const int pad_h, \
const int pad_w, \
const int dilation_h, \
const int dilation_w, \
const string& data_format, \
const T* dy, \
const T* w, \
T* dx, \
CPUContext* ctx) { \
NOT_IMPLEMENTED; \
} \
template <> \
void DepthwiseConv2dWGrad<T, CPUContext>( \
const int N, \
const int C, \
const int H, \
const int W, \
const int out_h, \
const int out_w, \
const int kernel_h, \
const int kernel_w, \
const int stride_h, \
const int stride_w, \
const int pad_h, \
const int pad_w, \
const int dilation_h, \
const int dilation_w, \
const string& data_format, \
const T* dy, \
const T* x, \
T* dw, \
CPUContext* ctx) { \
NOT_IMPLEMENTED; \
}
template <>
void DepthwiseConv2dWGrad<float, CPUContext>(
const int N,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const string& data_format,
const float* dy,
const float* x,
float* dw,
CPUContext* ctx) {
NOT_IMPLEMENTED;
} // DepthwiseConv2dWGrad
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -2,6 +2,7 @@
#include "dragon/core/context_cuda.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -10,7 +11,13 @@ namespace kernel {
namespace {
template <typename T, int KKH, int KKW>
#if __CUDA_ARCH__ >= 350
#define LOAD(x, i) __ldg(x + i)
#else
#define LOAD(x, i) x[i]
#endif
template <typename T, typename AccT, int KKH, int KKW>
__global__ void _DepthwiseConv2dNCHW(
const int nthreads,
const int C,
......@@ -31,6 +38,7 @@ __global__ void _DepthwiseConv2dNCHW(
T* y) {
const int KH = KKH < 0 ? kernel_h : KKH;
const int KW = KKW < 0 ? kernel_w : KKW;
const auto Multiplies = math::MultipliesFunctor<T>();
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int ow = yi % out_w;
const int oh = (yi / out_w) % out_h;
......@@ -42,7 +50,8 @@ __global__ void _DepthwiseConv2dNCHW(
const int x_start = (n * C + c) * H * W;
int ih, iw, xi, wi = c * KH * KW;
T sum_val = T(0);
AccT sum_val = AccT(0);
#pragma unroll
for (int kh = 0; kh < KH; ++kh) {
#pragma unroll
......@@ -51,20 +60,16 @@ __global__ void _DepthwiseConv2dNCHW(
iw = iw_start + kw * dilation_w;
if (ih >= 0 && ih < H && iw >= 0 && iw < W) {
xi = x_start + ih * W + iw;
#if __CUDA_ARCH__ >= 350
sum_val += __ldg(x + xi) * __ldg(w + wi);
#else
sum_val += x[xi] * w[wi];
#endif
sum_val += convert::To<AccT>(Multiplies(LOAD(x, xi), LOAD(w, wi)));
}
++wi;
} // End kw
} // End kh
y[yi] = sum_val;
y[yi] = convert::To<T>(sum_val);
}
}
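KKH/KKW are compile-time kernel extents: the launcher instantiates 3x3, 5x5 and 7x7 specializations so the #pragma unroll loops fully unroll, and falls back to the generic <-1, -1> instantiation (which reads kernel_h/kernel_w at run time) otherwise; the multiplication goes through math::MultipliesFunctor<T> so the half path has a defined operator. The same specialize-or-fallback idea in a small host-side sketch:

// Inner loop with an optional compile-time extent: KK < 0 means "use the
// runtime value"; hot sizes can then be fully unrolled by the compiler.
template <int KK>
float DotTap(const float* x, const float* w, int runtime_k) {
  const int K = KK < 0 ? runtime_k : KK;
  float sum = 0.f;
#pragma unroll
  for (int k = 0; k < K; ++k) {
    sum += x[k] * w[k];
  }
  return sum;
}

// Runtime dispatch to the specialized instantiations, generic fallback last.
inline float DotTapDispatch(const float* x, const float* w, int k) {
  if (k == 3) return DotTap<3>(x, w, k);
  if (k == 5) return DotTap<5>(x, w, k);
  if (k == 7) return DotTap<7>(x, w, k);
  return DotTap<-1>(x, w, k);
}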
template <typename T, int KKH, int KKW>
template <typename T, typename AccT, int KKH, int KKW>
__global__ void _DepthwiseConv2dNHWC(
const int nthreads,
const int C,
......@@ -85,6 +90,7 @@ __global__ void _DepthwiseConv2dNHWC(
T* y) {
const int KH = KKH < 0 ? kernel_h : KKH;
const int KW = KKW < 0 ? kernel_w : KKW;
const auto Multiplies = math::MultipliesFunctor<T>();
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int c = yi % C;
const int ow = (yi / C) % out_w;
......@@ -96,7 +102,7 @@ __global__ void _DepthwiseConv2dNHWC(
const int x_start = n * H;
int ih, iw, xi, wi = c * KH * KW;
T sum_val = T(0);
AccT sum_val = AccT(0);
#pragma unroll
for (int kh = 0; kh < KH; ++kh) {
......@@ -106,20 +112,16 @@ __global__ void _DepthwiseConv2dNHWC(
iw = iw_start + kw * dilation_w;
if (ih >= 0 && ih < H && iw >= 0 && iw < W) {
xi = ((x_start + ih) * W + iw) * C + c;
#if __CUDA_ARCH__ >= 350
sum_val += __ldg(x + xi) * __ldg(w + wi);
#else
sum_val += x[xi] * w[wi];
#endif
sum_val += convert::To<AccT>(Multiplies(LOAD(x, xi), LOAD(w, wi)));
}
++wi;
} // End kw
} // End kh
y[yi] = sum_val;
y[yi] = convert::To<T>(sum_val);
}
}
template <typename T, int KKH, int KKW>
template <typename T, typename AccT, int KKH, int KKW>
__global__ void _DepthwiseConv2dGradNCHW(
const int nthreads,
const int C,
......@@ -140,6 +142,7 @@ __global__ void _DepthwiseConv2dGradNCHW(
T* dx) {
const int KH = KKH < 0 ? kernel_h : KKH;
const int KW = KKW < 0 ? kernel_w : KKW;
const auto Multiplies = math::MultipliesFunctor<T>();
CUDA_1D_KERNEL_LOOP(xi, nthreads) {
const int iw = xi % W;
const int ih = (xi / W) % H;
......@@ -148,7 +151,7 @@ __global__ void _DepthwiseConv2dGradNCHW(
int oh, ow, yi, wi = c * KH * KW;
const int y_start = (n * C + c) * out_h * out_w;
T sum_val = T(0);
AccT sum_val = AccT(0);
#pragma unroll
for (int kh = 0; kh < KH; ++kh) {
......@@ -161,21 +164,17 @@ __global__ void _DepthwiseConv2dGradNCHW(
ow = ow / stride_w;
if (oh >= 0 && oh < out_h && ow >= 0 && ow < out_w) {
yi = y_start + oh * out_w + ow;
#if __CUDA_ARCH__ >= 350
sum_val += __ldg(dy + yi) * __ldg(w + wi);
#else
sum_val += dy[yi] * w[wi];
#endif
sum_val += convert::To<AccT>(Multiplies(LOAD(dy, yi), LOAD(w, wi)));
}
}
++wi;
} // End kw
} // End kh
dx[xi] = sum_val;
dx[xi] = convert::To<T>(sum_val);
}
}
template <typename T, int KKH, int KKW>
template <typename T, typename AccT, int KKH, int KKW>
__global__ void _DepthwiseConv2dGradNHWC(
const int nthreads,
const int C,
......@@ -196,6 +195,7 @@ __global__ void _DepthwiseConv2dGradNHWC(
T* dx) {
const int KH = KKH < 0 ? kernel_h : KKH;
const int KW = KKW < 0 ? kernel_w : KKW;
const auto Multiplies = math::MultipliesFunctor<T>();
CUDA_1D_KERNEL_LOOP(xi, nthreads) {
const int c = xi % C;
const int iw = (xi / C) % W;
......@@ -204,7 +204,7 @@ __global__ void _DepthwiseConv2dGradNHWC(
int oh, ow, yi, wi = c * KH * KW;
const int y_start = n * out_h;
T sum_val = T(0);
AccT sum_val = AccT(0);
#pragma unroll
for (int kh = 0; kh < KH; ++kh) {
......@@ -217,11 +217,7 @@ __global__ void _DepthwiseConv2dGradNHWC(
ow = ow / stride_w;
if (oh >= 0 && oh < out_h && ow >= 0 && ow < out_w) {
yi = ((y_start + oh) * out_w + ow) * C + c;
#if __CUDA_ARCH__ >= 350
sum_val += __ldg(dy + yi) * __ldg(w + wi);
#else
sum_val += dy[yi] * w[wi];
#endif
sum_val += convert::To<AccT>(Multiplies(LOAD(dy, yi), LOAD(w, wi)));
}
}
++wi;
......@@ -231,7 +227,7 @@ __global__ void _DepthwiseConv2dGradNHWC(
}
}
template <typename T>
template <typename T, typename AccT>
__global__ void _DepthwiseConv2dWGradNCHW(
const int N,
const int C,
......@@ -250,6 +246,7 @@ __global__ void _DepthwiseConv2dWGradNCHW(
const T* dy,
const T* x,
T* dw) {
const auto Multiplies = math::MultipliesFunctor<T>();
const int block_idx = blockIdx.x;
const int kw = block_idx % kernel_w;
const int kh = (block_idx / kernel_w) % kernel_h;
......@@ -260,8 +257,8 @@ __global__ void _DepthwiseConv2dWGradNCHW(
const int lane_idx = threadIdx.x % 32;
const int ohw = out_h * out_w;
T grad = T(0);
int ih, iw, xi, yi;
AccT sum_val = AccT(0);
for (int i = n; i < N; i += nwarps) {
for (int j = lane_idx; j < ohw; j += 32) {
......@@ -270,21 +267,20 @@ __global__ void _DepthwiseConv2dWGradNCHW(
if (ih >= 0 && iw >= 0 && ih < H && iw < W) {
xi = ((i * C + c) * H + ih) * W + iw;
yi = (i * C + c) * out_h * out_w + j;
#if __CUDA_ARCH__ >= 350
grad += __ldg(dy + yi) * __ldg(x + xi);
#else
grad += dy[yi] * x[xi];
#endif
sum_val += convert::To<AccT>(Multiplies(LOAD(dy, yi), LOAD(x, xi)));
}
}
}
typedef cub::BlockReduce<T, 256> Reduce;
typedef cub::BlockReduce<AccT, 256> Reduce;
__shared__ typename Reduce::TempStorage storage;
grad = Reduce(storage).Sum(grad);
if (threadIdx.x == 0) dw[block_idx] = grad;
sum_val = Reduce(storage).Sum(sum_val);
if (threadIdx.x == 0) {
dw[block_idx] = convert::To<T>(sum_val);
}
}
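Each block of _DepthwiseConv2dWGradNCHW owns one (channel, kh, kw) filter tap; warps stride over the batch, lanes stride over output positions, and the FP32 partials meet in one cub::BlockReduce before a single narrowed store. Mathematically the tap gradient is a sum of dy * x over every output position the tap touches, as in this per-tap FP32 reference (index mapping assumed to be the usual ih = oh*stride_h - pad_h + kh*dilation_h):

#include <vector>

// Gradient of one depthwise filter tap (c, kh, kw), NCHW layout:
// dw[c][kh][kw] = sum over (n, oh, ow) of dy[n][c][oh][ow] * x[n][c][ih][iw].
float DepthwiseConv2dWGradTapRef(int N, int C, int H, int W,
                                 int out_h, int out_w,
                                 int stride_h, int stride_w,
                                 int pad_h, int pad_w,
                                 int dilation_h, int dilation_w,
                                 int c, int kh, int kw,
                                 const std::vector<float>& dy,
                                 const std::vector<float>& x) {
  float grad = 0.f;  // FP32 accumulator, as in the AccT path above
  for (int n = 0; n < N; ++n) {
    for (int oh = 0; oh < out_h; ++oh) {
      for (int ow = 0; ow < out_w; ++ow) {
        const int ih = oh * stride_h - pad_h + kh * dilation_h;
        const int iw = ow * stride_w - pad_w + kw * dilation_w;
        if (ih >= 0 && ih < H && iw >= 0 && iw < W) {
          grad += dy[((n * C + c) * out_h + oh) * out_w + ow] *
                  x[((n * C + c) * H + ih) * W + iw];
        }
      }
    }
  }
  return grad;
}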
template <typename T>
template <typename T, typename AccT>
__global__ void _DepthwiseConv2dWGradNHWC(
const int N,
const int C,
......@@ -303,6 +299,7 @@ __global__ void _DepthwiseConv2dWGradNHWC(
const T* dy,
const T* x,
T* dw) {
const auto Multiplies = math::MultipliesFunctor<T>();
const int block_idx = blockIdx.x;
const int kw = block_idx % kernel_w;
const int kh = (block_idx / kernel_w) % kernel_h;
......@@ -313,8 +310,8 @@ __global__ void _DepthwiseConv2dWGradNHWC(
const int lane_idx = threadIdx.x % 32;
const int ohw = out_h * out_w;
T grad = T(0);
int ih, iw, xi, yi;
AccT sum_val = AccT(0);
for (int i = n; i < N; i += nwarps) {
for (int j = lane_idx; j < ohw; j += 32) {
......@@ -323,471 +320,222 @@ __global__ void _DepthwiseConv2dWGradNHWC(
if (ih >= 0 && iw >= 0 && ih < H && iw < W) {
xi = ((i * H + ih) * W + iw) * C + c;
yi = (i * ohw + j) * C + c;
#if __CUDA_ARCH__ >= 350
grad += __ldg(dy + yi) * __ldg(x + xi);
#else
grad += dy[yi] * x[xi];
#endif
sum_val += convert::To<AccT>(Multiplies(LOAD(dy, yi), LOAD(x, xi)));
}
}
}
typedef cub::BlockReduce<T, 256> Reduce;
typedef cub::BlockReduce<AccT, 256> Reduce;
__shared__ typename Reduce::TempStorage storage;
grad = Reduce(storage).Sum(grad);
if (threadIdx.x == 0) dw[block_idx] = grad;
sum_val = Reduce(storage).Sum(sum_val);
if (threadIdx.x == 0) {
dw[block_idx] = convert::To<T>(sum_val);
}
}
#undef LOAD
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void DepthwiseConv2d<float, CUDAContext>(
const int N,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const string& data_format,
const float* x,
const float* w,
float* y,
CUDAContext* ctx) {
const auto nthreads = N * C * out_h * out_w;
if (data_format == "NCHW") {
if (kernel_h == 3 && kernel_w == 3) {
_DepthwiseConv2dNCHW<float, 3, 3>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
x,
w,
y);
} else if (kernel_h == 5 && kernel_w == 5) {
_DepthwiseConv2dNCHW<float, 5, 5>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
x,
w,
y);
} else if (kernel_h == 7 && kernel_w == 7) {
_DepthwiseConv2dNCHW<float, 7, 7>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
x,
w,
y);
} else {
_DepthwiseConv2dNCHW<float, -1, -1>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
x,
w,
y);
}
} else if (data_format == "NHWC") {
if (kernel_h == 3 && kernel_w == 3) {
_DepthwiseConv2dNHWC<float, 3, 3>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
x,
w,
y);
} else if (kernel_h == 5 && kernel_w == 5) {
_DepthwiseConv2dNHWC<float, 5, 5>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
x,
w,
y);
} else if (kernel_h == 7 && kernel_w == 7) {
_DepthwiseConv2dNHWC<float, 7, 7>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
x,
w,
y);
} else {
_DepthwiseConv2dNHWC<float, -1, -1>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
x,
w,
y);
}
} else {
LOG(FATAL) << "Unknown DataFormat: " << data_format;
#define DISPATCH_DATA_KERNEL(name, T, AccT, nblocks, nthreads, ...) \
if (data_format == "NCHW") { \
if (kernel_h == 3 && kernel_w == 3) { \
name##NCHW<T, AccT, 3, 3> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else if (kernel_h == 5 && kernel_w == 5) { \
name##NCHW<T, AccT, 5, 5> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else if (kernel_h == 7 && kernel_w == 7) { \
name##NCHW<T, AccT, 7, 7> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else { \
name##NCHW<T, AccT, -1, -1> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} \
} else if (data_format == "NHWC") { \
if (kernel_h == 3 && kernel_w == 3) { \
name##NHWC<T, AccT, 3, 3> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else if (kernel_h == 5 && kernel_w == 5) { \
name##NHWC<T, AccT, 5, 5> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else if (kernel_h == 7 && kernel_w == 7) { \
name##NHWC<T, AccT, 7, 7> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else { \
name##NHWC<T, AccT, -1, -1> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} \
} else { \
LOG(FATAL) << "Unknown DataFormat: " << data_format; \
}
}
template <>
void DepthwiseConv2dGrad<float, CUDAContext>(
const int N,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const string& data_format,
const float* dy,
const float* w,
float* dx,
CUDAContext* ctx) {
auto nthreads = N * C * H * W;
if (data_format == "NCHW") {
if (kernel_h == 3 && kernel_w == 3) {
_DepthwiseConv2dGradNCHW<float, 3, 3>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dy,
w,
dx);
} else if (kernel_h == 5 && kernel_w == 5) {
_DepthwiseConv2dGradNCHW<float, 5, 5>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dy,
w,
dx);
} else if (kernel_h == 7 && kernel_w == 7) {
_DepthwiseConv2dGradNCHW<float, 7, 7>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dy,
w,
dx);
} else {
_DepthwiseConv2dGradNCHW<float, -1, -1>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dy,
w,
dx);
#define DISPATCH_WEIGHT_KERNEL(name, T, AccT, nblocks, nthreads, ...) \
if (data_format == "NCHW") { \
name##NCHW<T, AccT> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else if (data_format == "NHWC") { \
name##NHWC<T, AccT> \
<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(__VA_ARGS__); \
} else { \
LOG(FATAL) << "Unknown DataFormat: " << data_format; \
}
} else if (data_format == "NHWC") {
if (kernel_h == 3 && kernel_w == 3) {
_DepthwiseConv2dGradNHWC<float, 3, 3>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dy,
w,
dx);
} else if (kernel_h == 5 && kernel_w == 5) {
_DepthwiseConv2dGradNHWC<float, 5, 5>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dy,
w,
dx);
} else if (kernel_h == 7 && kernel_w == 7) {
_DepthwiseConv2dGradNHWC<float, 7, 7>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dy,
w,
dx);
} else {
_DepthwiseConv2dGradNHWC<float, -1, -1>
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dy,
w,
dx);
}
} else {
LOG(FATAL) << "Unknown DataFormat: " << data_format;
#define DEFINE_KERNEL_LAUNCHER(T, ScalarT, AccT) \
template <> \
void DepthwiseConv2d<T, CUDAContext>( \
const int N, \
const int C, \
const int H, \
const int W, \
const int out_h, \
const int out_w, \
const int kernel_h, \
const int kernel_w, \
const int stride_h, \
const int stride_w, \
const int pad_h, \
const int pad_w, \
const int dilation_h, \
const int dilation_w, \
const string& data_format, \
const T* x, \
const T* w, \
T* y, \
CUDAContext* ctx) { \
const auto nthreads = N * C * out_h * out_w; \
DISPATCH_DATA_KERNEL( \
_DepthwiseConv2d, \
ScalarT, \
AccT, \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
nthreads, \
C, \
H, \
W, \
out_h, \
out_w, \
kernel_h, \
kernel_w, \
stride_h, \
stride_w, \
pad_h, \
pad_w, \
dilation_h, \
dilation_w, \
reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<const ScalarT*>(w), \
reinterpret_cast<ScalarT*>(y)); \
}
} // DepthwiseConv2dGrad
template <>
void DepthwiseConv2dWGrad<float, CUDAContext>(
const int N,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const string& data_format,
const float* dy,
const float* x,
float* dw,
CUDAContext* ctx) {
int nthreads = 256;
auto nblocks = C * kernel_h * kernel_w;
if (data_format == "NCHW") {
_DepthwiseConv2dWGradNCHW<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(
N,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dy,
x,
dw);
} else if (data_format == "NHWC") {
_DepthwiseConv2dWGradNHWC<<<nblocks, nthreads, 0, ctx->cuda_stream()>>>(
N,
C,
H,
W,
out_h,
out_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dy,
x,
dw);
} else {
LOG(FATAL) << "Unknown DataFormat: " << data_format;
#define DEFINE_GRAD_KERNEL_LAUNCHER(T, ScalarT, AccT) \
template <> \
void DepthwiseConv2dGrad<T, CUDAContext>( \
const int N, \
const int C, \
const int H, \
const int W, \
const int out_h, \
const int out_w, \
const int kernel_h, \
const int kernel_w, \
const int stride_h, \
const int stride_w, \
const int pad_h, \
const int pad_w, \
const int dilation_h, \
const int dilation_w, \
const string& data_format, \
const T* dy, \
const T* w, \
T* dx, \
CUDAContext* ctx) { \
auto nthreads = N * C * H * W; \
DISPATCH_DATA_KERNEL( \
_DepthwiseConv2dGrad, \
ScalarT, \
AccT, \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
nthreads, \
C, \
H, \
W, \
out_h, \
out_w, \
kernel_h, \
kernel_w, \
stride_h, \
stride_w, \
pad_h, \
pad_w, \
dilation_h, \
dilation_w, \
reinterpret_cast<const ScalarT*>(dy), \
reinterpret_cast<const ScalarT*>(w), \
reinterpret_cast<ScalarT*>(dx)); \
} \
template <> \
void DepthwiseConv2dWGrad<T, CUDAContext>( \
const int N, \
const int C, \
const int H, \
const int W, \
const int out_h, \
const int out_w, \
const int kernel_h, \
const int kernel_w, \
const int stride_h, \
const int stride_w, \
const int pad_h, \
const int pad_w, \
const int dilation_h, \
const int dilation_w, \
const string& data_format, \
const T* dy, \
const T* x, \
T* dw, \
CUDAContext* ctx) { \
const auto nblocks = C * kernel_h * kernel_w; \
const auto nthreads = 256; \
DISPATCH_WEIGHT_KERNEL( \
_DepthwiseConv2dWGrad, \
ScalarT, \
AccT, \
nblocks, \
nthreads, \
N, \
C, \
H, \
W, \
out_h, \
out_w, \
kernel_h, \
kernel_w, \
stride_h, \
stride_w, \
pad_h, \
pad_w, \
dilation_h, \
dilation_w, \
reinterpret_cast<const ScalarT*>(dy), \
reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<ScalarT*>(dw)); \
}
} // DepthwiseConv2dWGrad
DEFINE_KERNEL_LAUNCHER(float16, half, float);
DEFINE_KERNEL_LAUNCHER(float, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float16, half, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float, float);
#undef DISPATCH_DATA_KERNEL
#undef DISPATCH_WEIGHT_KERNEL
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
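The FP16 launchers above all follow the same pseudo-FP16 pattern: tensors are stored as half (ScalarT) while every product is widened to float (AccT) before accumulation, and the sum is narrowed back once on store. A minimal standalone sketch of that pattern, with purely illustrative names (not the depthwise kernels themselves):
#include <cuda_fp16.h>
// Each thread produces one output: FP16 loads, FP32 accumulation,
// and a single FP16 store at the end.
__global__ void HalfRowDotSketch(
    const int rows, const int cols, const half* a, const half* b, half* y) {
  const int r = blockIdx.x * blockDim.x + threadIdx.x;
  if (r >= rows) return;
  float acc = 0.f; // AccT: avoids FP16 rounding drift and overflow
  for (int j = 0; j < cols; ++j) {
    acc += __half2float(a[r * cols + j]) * __half2float(b[j]);
  }
  y[r] = __float2half(acc); // back to the ScalarT storage type
}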
......
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -52,7 +52,7 @@ void _MaxPool2dNCHW(
}
y[i] = val;
mask[i] = mxi;
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -99,7 +99,7 @@ void _MaxPool2dNHWC(
}
y[i] = val;
mask[i] = mxi;
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -129,7 +129,7 @@ void _MaxPool2dGradNCHW(
if (mask[i] != -1) {
dx[idx[0] * CHW + idx[1] * HW + mask[i]] += dy[i];
}
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
......@@ -158,7 +158,7 @@ void _MaxPool2dGradNHWC(
if (mask[i] != -1) {
dx[idx[0] * HWC + mask[i]] += dy[i];
}
utils::math::IncreaseIndexInDims(2, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(2, dims.data(), idx.data());
}
}
......@@ -245,7 +245,7 @@ void _MaxPool2dGradNHWC(
const int* mask, \
T* dx, \
CPUContext* ctx) { \
math::Set(N* C* H* W, cast::to<T>(0.f), dx, ctx); \
math::Set(N* C* H* W, convert::To<T>(0.f), dx, ctx); \
if (data_format == "NCHW") { \
_MaxPool2dGradNCHW( \
N, \
......
......@@ -62,7 +62,7 @@ void _ResizeLinearNCHW(
t = tl + (tr - tl) * u;
b = bl + (br - bl) * u;
y[i] = static_cast<T>(t + (b - t) * v);
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -99,7 +99,7 @@ void _ResizeLinearNHWC(
t = tl + (tr - tl) * u;
b = bl + (br - bl) * u;
y[i] = static_cast<T>(t + (b - t) * v);
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -135,7 +135,7 @@ void _ResizeLinearGradNCHW(
dx[(offset + ti) * W + ri] += u * dt; // tr
dx[(offset + bi) * W + li] += (1.f - u) * db; // bl
dx[(offset + bi) * W + ri] += u * db; // br
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -171,7 +171,7 @@ void _ResizeLinearGradNHWC(
dx[((offset + ti) * W + ri) * C + idx[3]] += u * dt; // tr
dx[((offset + bi) * W + li) * C + idx[3]] += (1.f - u) * db; // bl
dx[((offset + bi) * W + ri) * C + idx[3]] += u * db; // br
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......
......@@ -27,7 +27,7 @@ void _ResizeNearestNCHW(
h_in = std::min(int(idx[2] * scale_h), h_max);
w_in = std::min(int(idx[3] * scale_w), w_max);
y[i] = x[(((idx[0] * C) + idx[1]) * H + h_in) * W + w_in];
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -52,7 +52,7 @@ void _ResizeNearestNHWC(
w_in = std::min(int(idx[2] * scale_w), w_max);
memcpy(
y + i * C, x + (((idx[0] * H) + h_in) * W + w_in) * C, C * sizeof(T));
utils::math::IncreaseIndexInDims(3, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(3, dims.data(), idx.data());
}
}
......@@ -76,7 +76,7 @@ void _ResizeNearestGradNCHW(
h_in = std::min(int(idx[2] * scale_h), h_max);
w_in = std::min(int(idx[3] * scale_w), w_max);
dx[(((idx[0] * C) + idx[1]) * H + h_in) * W + w_in] += (float)dy[i];
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......@@ -100,7 +100,7 @@ void _ResizeNearestGradNHWC(
h_in = std::min(int(idx[1] * scale_h), h_max);
w_in = std::min(int(idx[2] * scale_w), w_max);
dx[(((idx[0] * H) + h_in) * W + w_in) * C + idx[3]] += (float)dy[i];
utils::math::IncreaseIndexInDims(4, dims.data(), idx.data());
math::utils::IncreaseIndexInDims(4, dims.data(), idx.data());
}
}
......
......@@ -14,6 +14,7 @@
#define DRAGON_MODULES_PYTHON_CONFIG_H_
#include "dragon/modules/python/common.h"
#include "dragon/utils/device/common_eigen.h"
namespace dragon {
......@@ -22,9 +23,16 @@ namespace python {
namespace config {
void RegisterModule(py::module& m) {
/*! \brief Set the logging severity */
m.def("SetLoggingLevel", [](const string& severity) {
SetLogDestination(severity);
});
/*! \brief Set the number of threads for cpu parallelism */
m.def("SetNumThreads", [](int num) { Eigen::setNbThreads(num); });
/*! \brief Return the number of threads for cpu parallelism */
m.def("GetNumThreads", []() { return Eigen::nbThreads(); });
}
} // namespace config
......
......@@ -14,7 +14,6 @@
#define DRAGON_MODULES_PYTHON_OPERATOR_H_
#include "dragon/modules/python/common.h"
#include "dragon/utils/eigen_utils.h"
namespace dragon {
......
......@@ -19,7 +19,7 @@ void ExpandOp<Context>::DoRunWithType() {
// Store for the gradient calculation
STORE_INPUT_SPEC(0);
if (utils::math::IsBinaryBroadcast(X.dims(), X_dims, Y_dims)) {
if (math::utils::IsBinaryBroadcast(X.dims(), X_dims, Y_dims)) {
math::Set(
X.ndim(),
X.dims().data(),
......@@ -47,7 +47,7 @@ void ExpandGradientOp<Context>::DoRunWithType() {
vec32_t X_broadcast_axes, _;
vec32_t Y_dims(dY.dims().begin(), dY.dims().end());
utils::math::ComputeBinaryBroadcastAxes(
math::utils::ComputeBinaryBroadcastAxes(
dX->dims(), dY.dims(), dY.dims(), X_broadcast_axes, _);
if (X_broadcast_axes.empty()) {
......
......@@ -62,7 +62,7 @@ void IndexSelectGradientOp<Context>::DoRunWithType() {
// Reset the accumulating gradient
math::Set(
dX->count(),
cast::to<T>(0.f),
convert::To<T>(0.f),
dX->template mutable_data<T, Context>(),
ctx());
......
......@@ -46,7 +46,7 @@ template <class Context>
template <typename T>
void FillOp<Context>::DoRunWithType() {
auto* y = Output(0)->template mutable_data<T, Context>();
math::Set(Output(0)->count(), cast::to<T>(value_), y, ctx());
math::Set(Output(0)->count(), convert::To<T>(value_), y, ctx());
}
template <class Context>
......
......@@ -15,7 +15,7 @@ void OneHotOp<Context>::DoRunWithType() {
// Brush the off-value over all
math::Set(
X.count() * depth_,
cast::to<T>((float)off_value_),
convert::To<T>((float)off_value_),
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
......
......@@ -87,7 +87,7 @@ void SliceGradientOp<Context>::DoRunWithType() {
// Zero the redundant gradients
auto* dx = dX->template mutable_data<T, Context>();
math::Set(dX->count(), cast::to<T>(0.f), dx, ctx());
math::Set(dX->count(), convert::To<T>(0.f), dx, ctx());
// Copy the dY to the right positions
kernel::SliceGrad(
......
......@@ -75,7 +75,7 @@ void SplitGradientOp<Context>::DoRunWithType() {
if (!Input(i).has_name()) {
math::Set(
dX->count(),
cast::to<T>(0.f),
convert::To<T>(0.f),
dX->template mutable_data<T, Context>(),
ctx());
break;
......
......@@ -14,8 +14,8 @@ void WhereOp<Context>::DoRunWithType() {
<< "\nExcepted bool or uint8 condition tensor.";
vec64_t AB_dims, Y_dims;
if (utils::math::IsBinaryBroadcast(A.dims(), B.dims(), AB_dims) &&
utils::math::IsBinaryBroadcast(AB_dims, C.dims(), Y_dims)) {
if (math::utils::IsBinaryBroadcast(A.dims(), B.dims(), AB_dims) &&
math::utils::IsBinaryBroadcast(AB_dims, C.dims(), Y_dims)) {
math::Where(
A.ndim(),
A.dims().data(),
......@@ -50,7 +50,7 @@ void WhereGradientOp<Context>::DoRunWithType() {
vec32_t A_broadcast_axes, B_broadcast_axes;
vec32_t Y_dims(dY.dims().begin(), dY.dims().end());
utils::math::ComputeBinaryBroadcastAxes(
math::utils::ComputeBinaryBroadcastAxes(
A.dims(), B.dims(), dY.dims(), A_broadcast_axes, B_broadcast_axes);
// Temporary space to store the intermediate gradient and zeros
......@@ -68,7 +68,7 @@ void WhereGradientOp<Context>::DoRunWithType() {
if (scratch_size > 0) {
scratch = ctx()->workspace()->template data<T, Context>({scratch_size})[0];
zeros = scratch + (scratch_size - 1);
math::Set(1, cast::to<T>(0.f), zeros, ctx());
math::Set(1, convert::To<T>(0.f), zeros, ctx());
}
if (dA->has_name()) {
......
......@@ -43,11 +43,11 @@ void AssignOp<Context>::DoRunWithType() {
if (X.dims() != X_dims) {
vec64_t dims1, dims2;
if (utils::math::IsBinaryBroadcast(X.dims(), X_dims, dims1)) {
if (math::utils::IsBinaryBroadcast(X.dims(), X_dims, dims1)) {
CHECK(X_dims == dims1)
<< "\nCould not assign with shapes " << X.DimString() << " "
<< Tensor::DimString(X_dims);
utils::math::ComputeBinaryBroadcastDims(X.dims(), X_dims, dims1, dims2);
math::utils::ComputeBinaryBroadcastDims(X.dims(), X_dims, dims1, dims2);
if (dims1 != dims2) {
auto* scratch = ctx()->workspace()->template data<T, Context>(
{X_broadcast.count()})[0];
......
......@@ -14,8 +14,8 @@ void MaskedAssignOp<Context>::DoRunWithType() {
<< "\nExcepted bool or uint8 mask.";
vec64_t X_dims, Y_dims;
if (utils::math::IsBinaryBroadcast(X.dims(), X_mask.dims(), X_dims) &&
utils::math::IsBinaryBroadcast(X_dims, Y->dims(), Y_dims) &&
if (math::utils::IsBinaryBroadcast(X.dims(), X_mask.dims(), X_dims) &&
math::utils::IsBinaryBroadcast(X_dims, Y->dims(), Y_dims) &&
Y_dims == Y->dims()) {
math::Where(
X.ndim(),
......
......@@ -13,7 +13,7 @@ void GradientGenerateOp<Context>::DoRunWithType() {
Y->ReshapeLike(Input(i));
math::Set(
Y->count(),
cast::to<T>(defaults_[i]),
convert::To<T>(defaults_[i]),
Y->template mutable_data<T, Context>(),
ctx());
}
......
......@@ -105,7 +105,7 @@ void NLLLossGradientOp<Context>::DoRunWithType() {
auto* dx = dX->template mutable_data<LogitType, Context>();
auto* mask =
ctx()->workspace()->template data<LogitType, Context>({num_preds + 1})[0];
math::Set(dX->count(), cast::to<LogitType>(0.f), dx, ctx());
math::Set(dX->count(), convert::To<LogitType>(0.f), dx, ctx());
kernel::NLLLossGrad(
outer_dim,
......
......@@ -21,7 +21,7 @@ void AddOp<Context>::DoRunWithType() {
B.template data<T, Context>(),
Output(0, {0, 1})->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else if (utils::math::IsBinaryBroadcast(A.dims(), B.dims(), Y_dims)) {
} else if (math::utils::IsBinaryBroadcast(A.dims(), B.dims(), Y_dims)) {
auto* Y = Output(0, CheckOutputAliases(A, B, Output(0), Y_dims));
math::Add(
A.ndim(),
......@@ -51,7 +51,7 @@ void AddGradientOp<Context>::DoRunWithType() {
vec32_t A_broadcast_axes, B_broadcast_axes;
vec32_t Y_dims(dY.dims().begin(), dY.dims().end());
utils::math::ComputeBinaryBroadcastAxes(
math::utils::ComputeBinaryBroadcastAxes(
A.dims(), B.dims(), dY.dims(), A_broadcast_axes, B_broadcast_axes);
if (dA->has_name()) {
......
......@@ -21,7 +21,7 @@ void DivOp<Context>::DoRunWithType() {
B.template data<T, Context>(),
Output(0, {0, 1})->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else if (utils::math::IsBinaryBroadcast(A.dims(), B.dims(), Y_dims)) {
} else if (math::utils::IsBinaryBroadcast(A.dims(), B.dims(), Y_dims)) {
auto* Y = Output(0, CheckOutputAliases(A, B, Output(0), Y_dims));
math::Div(
A.ndim(),
......@@ -52,7 +52,7 @@ void DivGradientOp<Context>::DoRunWithType() {
vec32_t A_broadcast_axes, B_broadcast_axes;
vec32_t Y_dims(dY.dims().begin(), dY.dims().end());
utils::math::ComputeBinaryBroadcastAxes(
math::utils::ComputeBinaryBroadcastAxes(
A_ref.dims(),
B_ref.dims(),
dY.dims(),
......
......@@ -93,7 +93,7 @@ DEFINE_INPLACE_UNARY_OP_IMPL(Invert, T);
B.template data<T, Context>(), \
Y->Reshape(Y_dims)->template mutable_data<TOut, Context>(), \
ctx()); \
} else if (utils::math::IsBinaryBroadcast(A.dims(), B.dims(), Y_dims)) { \
} else if (math::utils::IsBinaryBroadcast(A.dims(), B.dims(), Y_dims)) { \
math::name( \
A.ndim(), \
A.dims().data(), \
......
......@@ -13,7 +13,7 @@ void MaximumGradientOp<Context>::DoRunWithType() {
vec32_t A_broadcast_axes, B_broadcast_axes;
vec32_t Y_dims(dY.dims().begin(), dY.dims().end());
utils::math::ComputeBinaryBroadcastAxes(
math::utils::ComputeBinaryBroadcastAxes(
A.dims(), B.dims(), dY.dims(), A_broadcast_axes, B_broadcast_axes);
// Temporary space to store the intermediate gradient
......
......@@ -13,7 +13,7 @@ void MinimumGradientOp<Context>::DoRunWithType() {
vec32_t A_broadcast_axes, B_broadcast_axes;
vec32_t Y_dims(dY.dims().begin(), dY.dims().end());
utils::math::ComputeBinaryBroadcastAxes(
math::utils::ComputeBinaryBroadcastAxes(
A.dims(), B.dims(), dY.dims(), A_broadcast_axes, B_broadcast_axes);
// Temporary space to store the intermediate gradient
......
......@@ -40,7 +40,7 @@ void MomentsOp<Context>::DoRunWithType() {
ctx());
math::Set(
1,
cast::to<Ty>(0.f),
convert::To<Ty>(0.f),
Y2->Reshape(Y_shape)->template mutable_data<Ty, Context>(),
ctx());
} else {
......
......@@ -21,7 +21,7 @@ void MulOp<Context>::DoRunWithType() {
B.template data<T, Context>(),
Output(0, {0, 1})->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else if (utils::math::IsBinaryBroadcast(A.dims(), B.dims(), Y_dims)) {
} else if (math::utils::IsBinaryBroadcast(A.dims(), B.dims(), Y_dims)) {
auto* Y = Output(0, CheckOutputAliases(A, B, Output(0), Y_dims));
math::Mul(
A.ndim(),
......@@ -52,7 +52,7 @@ void MulGradientOp<Context>::DoRunWithType() {
vec32_t A_broadcast_axes, B_broadcast_axes;
vec32_t Y_dims(dY.dims().begin(), dY.dims().end());
utils::math::ComputeBinaryBroadcastAxes(
math::utils::ComputeBinaryBroadcastAxes(
A_ref.dims(),
B_ref.dims(),
dY.dims(),
......
......@@ -12,7 +12,7 @@ void PowGradientOp<Context>::DoRunWithType() {
vec32_t A_broadcast_axes, B_broadcast_axes;
vec32_t Y_dims(dY.dims().begin(), dY.dims().end());
utils::math::ComputeBinaryBroadcastAxes(
math::utils::ComputeBinaryBroadcastAxes(
A.dims(), B.dims(), dY.dims(), A_broadcast_axes, B_broadcast_axes);
// Temporary space to store the intermediate gradient
......@@ -99,7 +99,7 @@ void PowGradientOp<Context>::DoRunWithType() {
ctx());
math::ReplaceNaN(
A.count(),
cast::to<T>(0.f),
convert::To<T>(0.f),
dA->template data<T, Context>(),
dA->template mutable_data<T, Context>(),
ctx());
......@@ -141,7 +141,7 @@ void PowGradientOp<Context>::DoRunWithType() {
A.template data<T, Context>(),
scratch,
ctx());
math::ReplaceNaN(Y.count(), cast::to<T>(0.f), scratch, scratch, ctx());
math::ReplaceNaN(Y.count(), convert::To<T>(0.f), scratch, scratch, ctx());
if (B_broadcast_axes.empty()) {
math::Mul(
Y.count(), scratch, B.template data<T, Context>(), scratch, ctx());
......
......@@ -9,7 +9,7 @@ void SignGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0);
math::Set(
dY.count(),
cast::to<T>(0.f),
convert::To<T>(0.f),
dX->ReshapeLike(dY)->template mutable_data<T, Context>(),
ctx());
}
......
......@@ -21,7 +21,7 @@ void SubOp<Context>::DoRunWithType() {
B.template data<T, Context>(),
Output(0, {0, 1})->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else if (utils::math::IsBinaryBroadcast(A.dims(), B.dims(), Y_dims)) {
} else if (math::utils::IsBinaryBroadcast(A.dims(), B.dims(), Y_dims)) {
auto* Y = Output(0, CheckOutputAliases(A, B, Output(0), Y_dims));
math::Sub(
A.ndim(),
......@@ -51,7 +51,7 @@ void SubGradientOp<Context>::DoRunWithType() {
vec32_t A_broadcast_axes, B_broadcast_axes;
vec32_t Y_dims(dY.dims().begin(), dY.dims().end());
utils::math::ComputeBinaryBroadcastAxes(
math::utils::ComputeBinaryBroadcastAxes(
A.dims(), B.dims(), dY.dims(), A_broadcast_axes, B_broadcast_axes);
if (dA->has_name()) {
......
......@@ -19,17 +19,57 @@ void BatchNormOp<Context>::TrainingImpl() {
auto* X_bias = Buffer("X_bias")->Reshape({C_});
auto* x = Input(0).template data<InputType, Context>();
auto* gamma = Input(1).template data<ParamType, Context>();
auto* beta = Input(2).template data<ParamType, Context>();
auto* rm = Input(3).template mutable_data<ParamType, Context>();
auto* rv = Input(4).template mutable_data<ParamType, Context>();
auto* mu = X_mu->template mutable_data<ParamType, Context>();
auto* rsig = X_rsig->template mutable_data<ParamType, Context>();
auto* scale = X_scale->template mutable_data<ParamType, Context>();
auto* bias = X_bias->template mutable_data<ParamType, Context>();
auto* y = Output(0)->template mutable_data<InputType, Context>();
// Compute moments
if (sync_stats_ > 0) {
#ifdef USE_MPI
// Compute E(X) and E(X^2)
kernel::BatchNormExpectation(
N_,
C_,
S_,
ParamType(1) / (N_ * comm_size_ * S_),
data_format(),
x,
mu,
rsig,
ctx());
// Compute D(X) = E(X^2) - E(X)^2
ctx()->FinishDeviceComputation();
if (enable_nccl_) {
#ifdef USE_NCCL
auto nccl_comm_ = this->nccl_comm();
auto nccl_dtype_ = this->template nccl_dtype<ParamType>();
NCCL_CHECK(ncclAllReduce(
(void*)mu,
(void*)mu,
C_,
nccl_dtype_,
ncclSum,
nccl_comm_,
((CUDAContext*)ctx())->cuda_stream()));
NCCL_CHECK(ncclAllReduce(
(void*)rsig,
(void*)rsig,
C_,
nccl_dtype_,
ncclSum,
nccl_comm_,
((CUDAContext*)ctx())->cuda_stream()));
#endif // USE_NCCL
} else {
AllReduce(mu, mu, C_);
AllReduce(rsig, rsig, C_);
}
math::Square(C_, mu, scale, ctx());
math::Sub(C_, rsig, scale, rsig, ctx());
#endif // USE_MPI
} else {
if (data_format() == "NCHW") {
vec32_t dims = {(int)N_, (int)C_, (int)S_};
vec32_t axes = {0, 2};
......@@ -39,27 +79,32 @@ void BatchNormOp<Context>::TrainingImpl() {
vec32_t axes = {0};
kernel::Moments(2, dims.data(), 1, axes.data(), x, mu, rsig, ctx());
}
}
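// Note: in the synchronized branch the per-rank moments are pre-scaled by
// 1 / (N * comm_size * S), so the all-reduced sums are already the global
// expectations, and the variance follows from
//   Var(X) = D(X) = E(X^2) - E(X)^2,
// which is what the Square/Sub pair above recovers into rsig.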
// Compute running statistics
if (is_recomputing_ == 0) {
// Running(X) = (1 - momentum) * Cur(X) + momentum * Running(X)
math::Axpby(C_, 1.f - momentum_, mu, momentum_, rm, ctx());
math::Axpby(C_, 1.f - momentum_, rsig, momentum_, rv, ctx());
}
// Fuse parameters along channel axis
// [mu, rsig, alpha, beta] => [scale, bias]
// Inverse stddev from variance
math::InvStd(C_, epsilon_, rsig, rsig, ctx());
math::Mul(C_, gamma, rsig, scale, ctx());
math::Mul(C_, scale, mu, bias, ctx());
math::Sub(C_, beta, bias, bias, ctx());
// Compute affine transformation
if (data_format() == "NCHW") {
kernel::ChannelAffine(N_, S_, C_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") {
kernel::ChannelAffine(N_ * S_, 1, C_, x, scale, bias, y, ctx());
}
// Fuse parameters to compute affine transformation
kernel::BatchNorm(
N_,
C_,
S_,
data_format(),
x,
mu,
rsig,
Input(1).template data<ParamType, Context>(), // gamma
Input(2).template data<ParamType, Context>(), // beta
scale,
X_bias->template mutable_data<ParamType, Context>(),
Output(0)->template mutable_data<InputType, Context>(),
ctx());
}
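// kernel::BatchNorm folds the normalization into a per-channel affine
// transform, using the scale/bias buffers as fused parameters:
//   rsig  = 1 / sqrt(Var(X) + epsilon)
//   scale = gamma * rsig
//   bias  = beta - scale * mu
//   y     = scale * x + bias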
template <class Context>
......@@ -70,31 +115,30 @@ void BatchNormOp<Context>::InferenceImpl() {
TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(4), vec64_t({C_}), ParamType);
auto* X_rsig = Buffer("X_rsig")->Reshape({C_});
auto* X_scale = Buffer("X_scale")->Reshape({C_});
auto* X_bias = Buffer("X_bias")->Reshape({C_});
auto* x = Input(0).template data<InputType, Context>();
auto* gamma = Input(1).template data<ParamType, Context>();
auto* beta = Input(2).template data<ParamType, Context>();
auto* rm = Input(3).template data<ParamType, Context>();
auto* rv = Input(4).template data<ParamType, Context>();
auto* scale = X_scale->template mutable_data<ParamType, Context>();
auto* bias = X_bias->template mutable_data<ParamType, Context>();
auto* y = Output(0)->template mutable_data<InputType, Context>();
auto* rsig = X_rsig->template mutable_data<ParamType, Context>();
// Fuse parameters along channel axis
// [mu, rsig, alpha, beta] => [scale, bias]
math::InvStd(C_, epsilon_, rv, bias, ctx());
math::Mul(C_, gamma, bias, scale, ctx());
math::Mul(C_, scale, rm, bias, ctx());
math::Sub(C_, beta, bias, bias, ctx());
// Inverse stddev from variance
math::InvStd(C_, epsilon_, rv, rsig, ctx());
// Compute affine transformation
if (data_format() == "NCHW") {
kernel::ChannelAffine(N_, S_, C_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") {
kernel::ChannelAffine(N_ * S_, 1, C_, x, scale, bias, y, ctx());
}
// Fuse parameters to compute affine transformation
kernel::BatchNorm(
N_,
C_,
S_,
data_format(),
Input(0).template data<InputType, Context>(),
Input(3).template data<ParamType, Context>(),
rsig,
Input(1).template data<ParamType, Context>(), // gamma
Input(2).template data<ParamType, Context>(), // beta
X_scale->template mutable_data<ParamType, Context>(),
X_bias->template mutable_data<ParamType, Context>(),
Output(0)->template mutable_data<InputType, Context>(),
ctx());
}
template <class Context>
......@@ -113,9 +157,15 @@ void BatchNormOp<Context>::RunOnDevice() {
} else {
InferenceImpl<float, float>();
}
} else if (Input(0).template IsType<float16>()) {
if (is_training_) {
TrainingImpl<float16, float>();
} else {
InferenceImpl<float16, float>();
}
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
types::to_string(Input(0).meta()), {"float16", "float32"});
}
}
......@@ -124,21 +174,71 @@ template <typename InputType, typename ParamType>
void BatchNormGradientOp<Context>::TrainingImpl() {
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig");
auto *X_scale = Buffer("X_scale"), *X_bias = Buffer("X_bias");
// Gradient w.r.t. gamma, beta and input
kernel::BatchNormBackwardTraining(
auto* x = Input(0).template data<InputType, Context>();
auto* gamma = Input(1).template data<ParamType, Context>();
auto* dy = Input(4).template data<InputType, Context>();
auto* mu = X_mu->template data<ParamType, Context>();
auto* rsig = X_rsig->template data<ParamType, Context>();
auto* scale = X_scale->template mutable_data<ParamType, Context>();
auto* bias = X_bias->template mutable_data<ParamType, Context>();
auto* dgamma = dW->Reshape({C_})->template mutable_data<ParamType, Context>();
auto* dbeta = dB->Reshape({C_})->template mutable_data<ParamType, Context>();
// Gradient w.r.t. gamma and beta
kernel::BatchNormInternalGrad(
N_, C_, S_, data_format(), x, mu, rsig, gamma, dy, dgamma, dbeta, ctx());
if (sync_stats_ > 0) {
#ifdef USE_MPI
ctx()->FinishDeviceComputation();
if (enable_nccl_) {
#ifdef USE_NCCL
auto nccl_comm_ = this->nccl_comm();
auto nccl_dtype_ = this->template nccl_dtype<ParamType>();
NCCL_CHECK(ncclAllReduce(
(void*)dgamma,
(void*)scale,
C_,
nccl_dtype_,
ncclSum,
nccl_comm_,
((CUDAContext*)ctx())->cuda_stream()));
NCCL_CHECK(ncclAllReduce(
(void*)dbeta,
(void*)bias,
C_,
nccl_dtype_,
ncclSum,
nccl_comm_,
((CUDAContext*)ctx())->cuda_stream()));
#endif // USE_NCCL
} else {
AllReduce(dgamma, scale, C_);
AllReduce(dbeta, bias, C_);
}
math::Scale(C_, ParamType(1) / comm_size_, scale, scale, ctx());
math::Scale(C_, ParamType(1) / comm_size_, bias, bias, ctx());
#endif // USE_MPI
} else {
scale = dgamma, bias = dbeta;
}
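// Note: with sync_stats_, scale and bias now hold the cross-replica averages
//   scale = (1 / comm_size) * sum_r dgamma_r
//   bias  = (1 / comm_size) * sum_r dbeta_r
// while the local dgamma/dbeta outputs are left untouched; otherwise they
// simply alias the local gradients.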
// Gradient w.r.t. input
kernel::BatchNormTrainingGrad(
N_,
C_,
S_,
data_format(),
Input(0).template data<InputType, Context>(), // x
X_mu->template data<ParamType, Context>(), // mu
X_rsig->template data<ParamType, Context>(), // rsig
Input(1).template data<ParamType, Context>(), // gamma
Input(4).template data<InputType, Context>(), // dy
Output(0)->template mutable_data<InputType, Context>(), // dx
dW->Reshape({C_})->template mutable_data<ParamType, Context>(), // dgamma
dB->Reshape({C_})->template mutable_data<ParamType, Context>(), // dbeta
x,
mu,
rsig,
gamma,
scale,
bias,
dy,
Output(0)->template mutable_data<InputType, Context>(),
ctx());
}
......@@ -158,11 +258,11 @@ void BatchNormGradientOp<Context>::InferenceImpl() {
dbeta = dB->Reshape({C_})->template mutable_data<ParamType, Context>();
}
// Restore inverse stddev from variance
// Inverse stddev from variance
math::InvStd(C_, epsilon_, rv, rsig, ctx());
// Gradient w.r.t. gamma, beta and input
kernel::BatchNormBackwardInference(
kernel::BatchNormInferenceGrad(
N_,
C_,
S_,
......@@ -172,9 +272,9 @@ void BatchNormGradientOp<Context>::InferenceImpl() {
rsig,
Input(1).template data<ParamType, Context>(), // gamma
Input(4).template data<InputType, Context>(), // dy
dX->template mutable_data<InputType, Context>(),
dgamma,
dbeta,
dX->template mutable_data<InputType, Context>(),
ctx());
}
......@@ -190,9 +290,15 @@ void BatchNormGradientOp<Context>::RunOnDevice() {
} else {
InferenceImpl<float, float>();
}
} else if (Input(0).template IsType<float16>()) {
if (is_training_ > 0) {
TrainingImpl<float16, float>();
} else {
InferenceImpl<float16, float>();
}
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
types::to_string(Input(0).meta()), {"float16", "float32"});
}
}
......
......@@ -35,7 +35,8 @@ class BatchNormOpBase : public GenericOpBase<Context> {
: GenericOpBase<Context>(def, ws),
momentum_(OP_SINGLE_ARG(float, "momentum", 0.9f)),
epsilon_(OP_SINGLE_ARG(double, "epsilon", 1e-5)),
use_stats_(OP_SINGLE_ARG(int64_t, "use_stats", -1)) {}
use_stats_(OP_SINGLE_ARG(int64_t, "use_stats", -1)),
sync_stats_(OP_SINGLE_ARG(int64_t, "comm", 0) > 0 ? 1 : 0) {}
USE_OPERATOR_FUNCTIONS;
void DetermineBaseArguments() {
......@@ -58,7 +59,8 @@ class BatchNormOpBase : public GenericOpBase<Context> {
protected:
float momentum_;
double epsilon_;
int64_t use_stats_, N_, C_, S_;
int64_t N_, C_, S_;
int64_t use_stats_, sync_stats_;
int64_t is_training_, is_recomputing_;
};
......@@ -69,6 +71,7 @@ class BatchNormOpBase : public GenericOpBase<Context> {
using BatchNormOpBase<Context>::momentum_; \
using BatchNormOpBase<Context>::epsilon_; \
using BatchNormOpBase<Context>::use_stats_; \
using BatchNormOpBase<Context>::sync_stats_; \
using BatchNormOpBase<Context>::N_; \
using BatchNormOpBase<Context>::C_; \
using BatchNormOpBase<Context>::S_; \
......@@ -82,6 +85,9 @@ class BatchNormOp : public BatchNormOpBase<Context> {
: BatchNormOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_BATCHNORM_FUNCTIONS;
#ifdef USE_MPI
USE_COLLECTIVE_FUNCTIONS;
#endif
void RunOnDevice() override;
......@@ -99,50 +105,19 @@ class BatchNormGradientOp : public BatchNormOpBase<Context> {
: BatchNormOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_BATCHNORM_FUNCTIONS;
void RunOnDevice() override;
template <typename InputType, typename ParamType>
void TrainingImpl();
template <typename InputType, typename ParamType>
void InferenceImpl();
};
#ifdef USE_MPI
template <class Context>
class SyncBatchNormOp : public BatchNormOp<Context> {
public:
SyncBatchNormOp(const OperatorDef& def, Workspace* ws)
: BatchNormOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_BATCHNORM_FUNCTIONS;
USE_COLLECTIVE_FUNCTIONS;
#endif
void RunOnDevice() override;
template <typename InputType, typename ParamType>
void TrainingImpl();
};
template <class Context>
class SyncBatchNormGradientOp : public BatchNormGradientOp<Context> {
public:
SyncBatchNormGradientOp(const OperatorDef& def, Workspace* ws)
: BatchNormGradientOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_BATCHNORM_FUNCTIONS;
USE_COLLECTIVE_FUNCTIONS;
void RunOnDevice() override;
template <typename InputType, typename ParamType>
void TrainingImpl();
void InferenceImpl();
};
#endif // USE_MPI
#ifdef USE_CUDNN
template <class Context>
......
#ifdef USE_MPI
#include "dragon/core/workspace.h"
#include "dragon/operators/normalization/batch_norm_op.h"
#include "dragon/utils/filler.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename InputType, typename ParamType>
void SyncBatchNormOp<Context>::TrainingImpl() {
TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(4), vec64_t({C_}), ParamType);
auto* X_mu = Buffer("X_mu")->Reshape({C_});
auto* X_rsig = Buffer("X_rsig")->Reshape({C_});
auto* X_scale = Buffer("X_scale")->Reshape({C_});
auto* X_bias = Buffer("X_bias")->Reshape({C_});
auto* x = Input(0).template data<InputType, Context>();
auto* gamma = Input(1).template data<ParamType, Context>();
auto* beta = Input(2).template data<ParamType, Context>();
auto* rm = Input(3).template mutable_data<ParamType, Context>();
auto* rv = Input(4).template mutable_data<ParamType, Context>();
auto* mu = X_mu->template mutable_data<ParamType, Context>();
auto* rsig = X_rsig->template mutable_data<ParamType, Context>();
auto* scale = X_scale->template mutable_data<ParamType, Context>();
auto* bias = X_bias->template mutable_data<ParamType, Context>();
auto* y = Output(0)->template mutable_data<InputType, Context>();
// Compute E(X) and E(X^2)
kernel::BatchNormExpectation(
N_,
C_,
S_,
ParamType(1) / (N_ * comm_size_ * S_),
data_format(),
x,
mu,
rsig,
ctx());
// Compute D(X) = E(X^2) - E(X)^2
ctx()->FinishDeviceComputation();
if (enable_nccl_) {
#ifdef USE_NCCL
auto nccl_comm_ = this->nccl_comm();
auto nccl_dtype_ = this->template nccl_dtype<ParamType>();
NCCL_CHECK(ncclAllReduce(
(void*)mu,
(void*)mu,
C_,
nccl_dtype_,
ncclSum,
nccl_comm_,
((CUDAContext*)ctx())->cuda_stream()));
NCCL_CHECK(ncclAllReduce(
(void*)rsig,
(void*)rsig,
C_,
nccl_dtype_,
ncclSum,
nccl_comm_,
((CUDAContext*)ctx())->cuda_stream()));
#endif
} else {
AllReduce(mu, mu, C_);
AllReduce(rsig, rsig, C_);
}
math::Square(C_, mu, y, ctx());
math::Sub(C_, rsig, y, rsig, ctx());
// Compute running statistics
if (is_recomputing_ == 0) {
// Running(X) = (1 - momentum) * Cur(X) + momentum * Running(X)
math::Axpby(C_, 1.f - momentum_, mu, momentum_, rm, ctx());
math::Axpby(C_, 1.f - momentum_, rsig, momentum_, rv, ctx());
}
// Fuse parameters along channel axis
// [mu, rsig, alpha, beta] => [scale, bias]
math::InvStd(C_, epsilon_, rsig, rsig, ctx());
math::Mul(C_, gamma, rsig, scale, ctx());
math::Mul(C_, scale, mu, bias, ctx());
math::Sub(C_, beta, bias, bias, ctx());
// Compute affine transformation
if (data_format() == "NCHW") {
kernel::ChannelAffine(N_, S_, C_, x, scale, bias, y, ctx());
} else if (data_format() == "NHWC") {
kernel::ChannelAffine(N_ * S_, 1, C_, x, scale, bias, y, ctx());
}
}
template <class Context>
void SyncBatchNormOp<Context>::RunOnDevice() {
DetermineBaseArguments();
// Get the recomputing flag
auto* flag = workspace()->GetTensor("/share/flag/recomputing");
is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0;
// Dispatch the training or inference impl
Output(0)->ReshapeLike(Input(0));
if (Input(0).template IsType<float>()) {
if (is_training_ > 0) {
TrainingImpl<float, float>();
} else {
this->template InferenceImpl<float, float>();
}
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
}
}
template <class Context>
template <typename InputType, typename ParamType>
void SyncBatchNormGradientOp<Context>::TrainingImpl() {
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig");
auto *X_scale = Buffer("X_scale"), *X_bias = Buffer("X_bias");
auto* x = Input(0).template data<InputType, Context>();
auto* gamma = Input(1).template data<ParamType, Context>();
auto* dy = Input(4).template data<InputType, Context>();
auto* mu = X_mu->template data<ParamType, Context>();
auto* rsig = X_rsig->template data<ParamType, Context>();
auto* scale = X_scale->template mutable_data<ParamType, Context>();
auto* bias = X_bias->template mutable_data<ParamType, Context>();
auto* dgamma = dW->Reshape({C_})->template mutable_data<ParamType, Context>();
auto* dbeta = dB->Reshape({C_})->template mutable_data<ParamType, Context>();
// Gradient w.r.t. gamma and beta of local batch
kernel::BatchNormInternalGrad(
N_, C_, S_, data_format(), x, mu, rsig, gamma, dy, dgamma, dbeta, ctx());
// Gradient w.r.t. gamma and beta of global batch
ctx()->FinishDeviceComputation();
if (enable_nccl_) {
#ifdef USE_NCCL
auto nccl_comm_ = this->nccl_comm();
auto nccl_dtype_ = this->template nccl_dtype<ParamType>();
NCCL_CHECK(ncclAllReduce(
(void*)dgamma,
(void*)scale,
C_,
nccl_dtype_,
ncclSum,
nccl_comm_,
((CUDAContext*)ctx())->cuda_stream()));
NCCL_CHECK(ncclAllReduce(
(void*)dbeta,
(void*)bias,
C_,
nccl_dtype_,
ncclSum,
nccl_comm_,
((CUDAContext*)ctx())->cuda_stream()));
#endif
} else {
AllReduce(dgamma, scale, C_);
AllReduce(dbeta, bias, C_);
}
math::Scale(C_, ParamType(1) / comm_size_, scale, scale, ctx());
math::Scale(C_, ParamType(1) / comm_size_, bias, bias, ctx());
// Gradient w.r.t. input
kernel::BatchNormTrainingGrad(
N_,
C_,
S_,
data_format(),
x,
mu,
rsig,
gamma,
scale,
bias,
dy,
Output(0)->template mutable_data<InputType, Context>(),
ctx());
}
template <class Context>
void SyncBatchNormGradientOp<Context>::RunOnDevice() {
DetermineBaseArguments();
// Dispatch the training or inference impl
Output(0)->ReshapeLike(Input(0));
if (Input(0).template IsType<float>()) {
if (is_training_ > 0) {
TrainingImpl<float, float>();
} else {
this->template InferenceImpl<float, float>();
}
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
}
}
DEPLOY_CPU_OPERATOR(SyncBatchNorm);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(SyncBatchNorm);
#endif
DEPLOY_CPU_OPERATOR(SyncBatchNormGradient);
REGISTER_CPU_OPERATOR(SyncBatchNorm, BatchNormOp<CPUContext>);
REGISTER_CPU_OPERATOR(SyncBatchNormGradient, BatchNormGradientOp<CPUContext>);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(SyncBatchNormGradient);
REGISTER_CUDA_OPERATOR(SyncBatchNorm, BatchNormOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(SyncBatchNormGradient, BatchNormGradientOp<CUDAContext>);
#endif
OPERATOR_SCHEMA(SyncBatchNorm)
......
......@@ -31,9 +31,11 @@ void GroupNormOp<Context>::DoRunWithType() {
kernel::Moments(4, dims.data(), 2, axes.data(), x, mu, rsig, ctx());
}
// Inverse stddev from variance
math::InvStd(N_ * G_, epsilon_, rsig, rsig, ctx());
kernel::GroupNormForward(
// Fuse parameters to compute affine transformation
kernel::GroupNorm(
N_,
G_,
D_,
......@@ -73,7 +75,8 @@ void GroupNormGradientOp<Context>::DoRunWithType() {
auto* X_scale = Buffer("X_scale")->Reshape({N_, G_});
auto* X_bias = Buffer("X_bias")->Reshape({N_, G_});
kernel::GroupNormBackward(
// Gradient w.r.t. gamma, beta and input
kernel::GroupNormGrad(
N_,
G_,
D_,
......@@ -86,9 +89,9 @@ void GroupNormGradientOp<Context>::DoRunWithType() {
Input(2).template data<InputType, Context>(), // dy
X_scale->template mutable_data<ParamType, Context>(),
X_bias->template mutable_data<ParamType, Context>(),
dX->template mutable_data<InputType, Context>(),
dW->Reshape({C_})->template mutable_data<ParamType, Context>(),
dB->Reshape({C_})->template mutable_data<ParamType, Context>(),
dX->template mutable_data<InputType, Context>(),
ctx());
}
......
......@@ -34,7 +34,7 @@ void LSTMCellGradientOp<Context>::DoRunWithType() {
auto* dhx = Output(1)->template mutable_data<T, Context>();
if (!Input(-1).has_name()) {
math::Set(Input(-1).count(), cast::to<T>(0.f), dc, ctx());
math::Set(Input(-1).count(), convert::To<T>(0.f), dc, ctx());
}
kernel::LSTMCellGrad(
......
......@@ -315,7 +315,7 @@ void CuDNNRecurrentGradientOp<Context>::DoRunWithType() {
if (Output(1)->has_name()) {
// CuDNN accumulates the gradient of weights
// We should reset them before backward computing
math::Set(Output(1)->count(), cast::to<T>(0.f), yAt(1), ctx());
math::Set(Output(1)->count(), convert::To<T>(0.f), yAt(1), ctx());
CUDNN_CHECK(cudnnRNNBackwardWeights(
ctx()->cudnn_handle(),
rnn_desc_,
......
......@@ -366,11 +366,15 @@ void ConvOpBase<Context>::Reshape(bool backward) {
template class ConvOpBase<CPUContext>;
INSTANTIATE_API(CPUContext, float);
INSTANTIATE_API(CPUContext, double);
template void ConvOpBase<CPUContext>::Pb(const float16*, float16*);
template void ConvOpBase<CPUContext>::Db(const float16*, float16*);
#ifdef USE_CUDA
template class ConvOpBase<CUDAContext>;
INSTANTIATE_API(CUDAContext, float);
INSTANTIATE_API(CUDAContext, double);
template void ConvOpBase<CUDAContext>::Pb(const float16*, float16*);
template void ConvOpBase<CUDAContext>::Db(const float16*, float16*);
#endif
#undef INSTANTIATE_API
......
......@@ -46,7 +46,7 @@ void DepthwiseConv2dOp<Context>::DoRunWithType() {
template <class Context>
void DepthwiseConv2dOp<Context>::RunOnDevice() {
DispatchHelper<TensorTypes<float>>::Call(this, Input(0));
DispatchHelper<TensorTypes<float16, float>>::Call(this, Input(0));
}
template <class Context>
......@@ -111,7 +111,7 @@ void DepthwiseConv2dGradientOp<Context>::DoRunWithType() {
template <class Context>
void DepthwiseConv2dGradientOp<Context>::RunOnDevice() {
DispatchHelper<TensorTypes<float>>::Call(this, Input(0));
DispatchHelper<TensorTypes<float16, float>>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(DepthwiseConv2d);
......
......@@ -56,7 +56,7 @@ void CuDNNDepthwiseConv2dOp<Context>::DoRunWithType() {
template <class Context>
void CuDNNDepthwiseConv2dOp<Context>::RunOnDevice() {
DispatchHelper<TensorTypes<float>>::Call(this, Input(0));
DispatchHelper<TensorTypes<float16, float>>::Call(this, Input(0));
}
template <class Context>
......@@ -130,7 +130,7 @@ void CuDNNDepthwiseConv2dGradientOp<Context>::DoRunWithType() {
template <class Context>
void CuDNNDepthwiseConv2dGradientOp<Context>::RunOnDevice() {
DispatchHelper<TensorTypes<float>>::Call(this, Input(0));
DispatchHelper<TensorTypes<float16, float>>::Call(this, Input(0));
}
DEPLOY_CUDNN_OPERATOR(DepthwiseConv2d);
......
......@@ -42,7 +42,7 @@ void RoiAlignGradientOp<Context>::DoRunWithType() {
math::Set(
dX->count(),
cast::to<T>(0.f),
convert::To<T>(0.f),
dX->template mutable_data<T, Context>(),
ctx());
......
......@@ -43,7 +43,7 @@ void RoiPoolGradientOp<Context>::DoRunWithType() {
math::Set(
dX->count(),
cast::to<T>(0.f),
convert::To<T>(0.f),
dX->template mutable_data<T, Context>(),
ctx());
......
......@@ -47,6 +47,8 @@ from dragon.core.autograph.function_lib import create_function
from dragon.core.autograph.grad_impl import gradients
from dragon.core.eager.context import eager_mode
from dragon.core.eager.context import graph_mode
from dragon.core.framework.config import get_num_threads
from dragon.core.framework.config import set_num_threads
from dragon.core.framework.context import device
from dragon.core.framework.context import eager_scope
from dragon.core.framework.context import name_scope
......
......@@ -16,6 +16,8 @@ from __future__ import print_function
import threading
from dragon import backend
class Config(object):
"""Store the common configurations for frontend."""
......@@ -36,7 +38,7 @@ class Config(object):
# The graph verbosity level.
self.graph_verbosity = 0
# The execution mode for graph.
self.graph_execution = 'GRAPH_MODE'
self.graph_execution = 'EAGER_MODE'
# The directory to store logging files.
self.log_dir = None
......@@ -56,6 +58,30 @@ def config():
return _config
def get_num_threads():
"""Return the number of threads for cpu parallelism.
Returns
-------
num : int
The number of threads to use.
"""
return backend.GetNumThreads()
def set_num_threads(num):
"""Set the number of threads for cpu parallelism.
Parameters
----------
num : int
The number of threads to use.
"""
backend.SetNumThreads(num)
def set_random_seed(seed):
"""Set the global random seed.
......
......@@ -163,14 +163,11 @@ class DragonFrontend(object):
helper.make_tensor_value_info(
name=name,
elem_type=value_info[name][0],
shape=value_info[name][1],
)
])
shape=value_info[name][1])])
except KeyError:
raise ValueError(
'Info of tensor `{}` is missing, '
'specify it in <value_info>.'.format(name)
)
'specify it in <value_info>.'.format(name))
# Add outputs.
onnx_graph.output.extend(
......@@ -238,8 +235,7 @@ class DragonFrontend(object):
raise RuntimeError(
'OpSet {} requires ONNX version >= {}. '
'({} currently installed.)'
.format(opset_version, onnx_version, onnx.__version__)
)
.format(opset_version, onnx_version, onnx.__version__))
return opset_version
@staticmethod
......
......@@ -10,26 +10,29 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_CAST_H_
#define DRAGON_UTILS_CAST_H_
#ifndef DRAGON_UTILS_CONVERSIONS_H_
#define DRAGON_UTILS_CONVERSIONS_H_
#include "dragon/core/types.h"
#include "dragon/utils/device/common_cuda.h"
namespace dragon {
#if defined(__CUDACC__)
#define CONVERSIONS_DECL inline __host__ __device__
#else
#define CONVERSIONS_DECL inline
#endif
#define HFLT_MAX 65504.F
#define HFLT_MIN 6.10e-5F
namespace dragon {
namespace cast {
namespace convert {
template <typename DType, typename SType>
DType to(SType val) {
return static_cast<DType>(val);
template <typename DestType, typename SrcType>
CONVERSIONS_DECL DestType To(SrcType val) {
return static_cast<DestType>(val);
}
template <>
inline float16 to<float16, float>(float val) {
inline float16 To<float16, float>(float val) {
float16 ret;
unsigned* xp = reinterpret_cast<unsigned int*>(&val);
unsigned x = *xp;
......@@ -78,7 +81,7 @@ inline float16 to<float16, float>(float val) {
}
template <>
inline float to<float, float16>(float16 val) {
inline float To<float, float16>(float16 val) {
unsigned sign = ((val.x >> 15) & 1);
unsigned exponent = ((val.x >> 10) & 0x1f);
unsigned mantissa = ((val.x & 0x3ff) << 13);
......@@ -108,41 +111,41 @@ inline float to<float, float16>(float16 val) {
}
template <>
inline float16 to<float16, double>(double val) {
return to<float16>(static_cast<float>(val));
inline float16 To<float16, double>(double val) {
return To<float16>(static_cast<float>(val));
}
#ifdef USE_CUDA
template <>
inline float16 to<float16, half>(half val) {
CONVERSIONS_DECL float16 To<float16, half>(half val) {
return float16{__half_raw(val).x};
}
template <>
inline half to<half, float>(float val) {
CONVERSIONS_DECL half To<half, float>(float val) {
return __float2half(val);
}
template <>
inline half to<half, float16>(float16 val) {
CONVERSIONS_DECL half To<half, float16>(float16 val) {
return __half_raw{val.x};
}
template <>
inline half2 to<half2, float>(float val) {
CONVERSIONS_DECL half2 To<half2, float>(float val) {
return __float2half2_rn(val);
}
template <>
inline half2 to<half2, float16>(float16 val) {
CONVERSIONS_DECL half2 To<half2, float16>(float16 val) {
return half2(__half2_raw{val.x, val.x});
}
#endif // USE_CUDA
} // namespace cast
} // namespace convert
} // namespace dragon
#endif // DRAGON_UTILS_CAST_H_
#endif // DRAGON_UTILS_CONVERSIONS_H_
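A minimal usage sketch of the renamed helpers, assuming float16 remains the plain-old-data storage type declared in dragon/core/types.h (illustrative only):
#include "dragon/utils/conversions.h"
void ConversionSketch() {
  const float x = 1.5f; // exactly representable in FP16
  const dragon::float16 h = dragon::convert::To<dragon::float16>(x);
  const float y = dragon::convert::To<float>(h); // y == 1.5f
  (void)y;
}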
......@@ -10,8 +10,8 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_EIGEN_UTILS_H_
#define DRAGON_UTILS_EIGEN_UTILS_H_
#ifndef DRAGON_UTILS_DEVICE_COMMON_EIGEN_H_
#define DRAGON_UTILS_DEVICE_COMMON_EIGEN_H_
#include <Eigen/Core>
......@@ -64,4 +64,4 @@ using ConstEigenArrayMap =
} // namespace dragon
#endif // DRAGON_UTILS_EIGEN_UTILS_H_
#endif // DRAGON_UTILS_DEVICE_COMMON_EIGEN_H_
......@@ -10,25 +10,27 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_OMP_UTILS_H_
#define DRAGON_UTILS_OMP_UTILS_H_
#ifndef DRAGON_UTILS_DEVICE_COMMON_OPENMP_H_
#define DRAGON_UTILS_DEVICE_COMMON_OPENMP_H_
#ifdef USE_OPENMP
#include <omp.h>
#include <algorithm>
#include "dragon/utils/device/common_eigen.h"
namespace dragon {
#define OMP_MIN_ITERATORS_PER_CORE 200000
inline int OMP_THREADS(const int N) {
int threads = std::max(N / OMP_MIN_ITERATORS_PER_CORE, 1);
return std::min(threads, omp_get_num_procs());
int nthreads = std::max(N / OMP_MIN_ITERATORS_PER_CORE, 1);
return std::min(nthreads, Eigen::nbThreads());
}
} // namespace dragon
#endif // USE_OPENMP
#endif // DRAGON_UTILS_OMP_UTILS_H_
#endif // DRAGON_UTILS_DEVICE_COMMON_OPENMP_H_
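A minimal sketch of how OMP_THREADS is typically consumed, gating an OpenMP loop on the heuristic thread count (the function below is illustrative only):
#include "dragon/utils/device/common_openmp.h"
void ScaleInPlace(const int N, const float alpha, float* x) {
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(dragon::OMP_THREADS(N))
#endif
  for (int i = 0; i < N; ++i) {
    x[i] *= alpha;
  }
}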
......@@ -42,7 +42,7 @@ class ConstantFiller final : public Filler<T, Context> {
void Fill(Tensor* X, Context* ctx) override {
math::Set(
X->count(),
cast::to<T>(info().value()),
convert::To<T>(info().value()),
X->mutable_data<T, Context>(),
ctx);
}
......
#include "dragon/utils/math/blas.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
namespace dragon {
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/math/blas.h"
namespace dragon {
......@@ -357,16 +357,16 @@ DRAGON_API void Axpby<float16, CUDAContext>(
if ((n & 1) == 0) {
_Axpby<<<CUDA_BLOCKS(n >> 1), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n >> 1,
cast::to<half2>(alpha),
convert::To<half2>(alpha),
reinterpret_cast<const half2*>(x),
cast::to<half2>(beta),
convert::To<half2>(beta),
reinterpret_cast<half2*>(y));
} else {
_Axpby<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n,
cast::to<half>(alpha),
convert::To<half>(alpha),
reinterpret_cast<const half*>(x),
cast::to<half>(beta),
convert::To<half>(beta),
reinterpret_cast<half*>(y));
}
}
......@@ -531,8 +531,8 @@ DRAGON_API void Gemv<float16, CUDAContext>(
LDC));
#endif
} else if (math_type == "float16") {
const half alpha_val = cast::to<half>(alpha);
const half beta_val = cast::to<half>(beta);
const half alpha_val = convert::To<half>(alpha);
const half beta_val = convert::To<half>(beta);
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
// GEMV + MATH16 + TENSOR-CORE
......@@ -733,8 +733,8 @@ DRAGON_API void Gemm<float16, CUDAContext>(
N));
#endif
} else if (math_type == "float16") {
const half alpha_val = cast::to<half>(alpha);
const half beta_val = cast::to<half>(beta);
const half alpha_val = convert::To<half>(alpha);
const half beta_val = convert::To<half>(beta);
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
// GEMM + MATH16 + TENSOR-CORE
......
#include "dragon/utils/math/broadcast.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/math/blas.h"
#include "dragon/utils/math/elementwise.h"
#include "dragon/utils/math/utils.h"
......@@ -342,7 +342,7 @@ DEFINE_ROWWISE_COLWISE_BIANRY_FUNC(Maximum, double, max);
bi += idx[d] * b_strides[d]; \
} \
y[yi] = a[ai] expr b[bi]; \
utils::math::IncreaseIndexInDims(num_dims, y_dims, idx.data()); \
math::utils::IncreaseIndexInDims(num_dims, y_dims, idx.data()); \
} \
}
......@@ -379,7 +379,7 @@ DEFINE_BROADCAST_BINARY_FUNC(GreaterEqual, bool, >=);
bi += idx[d] * b_strides[d]; \
} \
y[yi] = func(a[ai], b[bi]); \
utils::math::IncreaseIndexInDims(num_dims, y_dims, idx.data()); \
math::utils::IncreaseIndexInDims(num_dims, y_dims, idx.data()); \
} \
}
......@@ -406,7 +406,7 @@ DEFINE_BROADCAST_BINARY_FUNC(Maximum, T, std::max);
vec64_t X_dims(x_dims, x_dims + x_ndim); \
vec64_t Y_dims(y_dims, y_dims + y_ndim); \
vec64_t X_broadcast_dims, Y_broadcast_dims; \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastDims( \
X_dims, Y_dims, X_broadcast_dims, Y_broadcast_dims); \
if (X_broadcast_dims == Y_broadcast_dims) { \
auto count = std::accumulate( \
......@@ -414,18 +414,18 @@ DEFINE_BROADCAST_BINARY_FUNC(Maximum, T, std::max);
Copy(count, x, y, ctx); \
return; \
} \
if (utils::math::IsRowwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \
if (math::utils::IsRowwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \
EigenArrayMap<T>(y, cols, rows).colwise() = \
ConstEigenVectorArrayMap<T>(x, cols); \
return; \
} \
if (utils::math::IsColwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \
if (math::utils::IsColwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \
EigenArrayMap<T>(y, cols, rows).rowwise() = \
ConstEigenVectorArrayMap<T>(x, rows).transpose(); \
return; \
} \
vec64_t X_broadcast_strides, _; \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
X_dims, Y_dims, X_broadcast_strides, _, _); \
const int num_dims = Y_dims.size(); \
const auto count = std::accumulate( \
......@@ -438,7 +438,7 @@ DEFINE_BROADCAST_BINARY_FUNC(Maximum, T, std::max);
xi += idx[d] * X_broadcast_strides[d]; \
} \
y[yi] = x[xi]; \
utils::math::IncreaseIndexInDims(num_dims, Y_dims.data(), idx.data()); \
math::utils::IncreaseIndexInDims(num_dims, Y_dims.data(), idx.data()); \
} \
}
......@@ -467,7 +467,7 @@ DEFINE_SET_FUNC(double);
vec64_t A_dims(a_dims, a_dims + a_ndim); \
vec64_t B_dims(b_dims, b_dims + b_ndim); \
vec64_t A_broadcast_dims, B_broadcast_dims; \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastDims( \
A_dims, B_dims, A_broadcast_dims, B_broadcast_dims); \
if (A_broadcast_dims == B_broadcast_dims) { \
auto count = std::accumulate( \
......@@ -475,7 +475,7 @@ DEFINE_SET_FUNC(double);
name(count, a, b, y, ctx); \
return; \
} \
if (utils::math::IsRowwiseBroadcast( \
if (math::utils::IsRowwiseBroadcast( \
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
if (broadcast_1st > 0) { \
_Rowwise##name<TIn, true>(rows, cols, a, b, y); \
......@@ -484,7 +484,7 @@ DEFINE_SET_FUNC(double);
} \
return; \
} \
if (utils::math::IsColwiseBroadcast( \
if (math::utils::IsColwiseBroadcast( \
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
if (broadcast_1st > 0) { \
_Colwise##name<TIn, true>(rows, cols, a, b, y); \
......@@ -494,7 +494,7 @@ DEFINE_SET_FUNC(double);
return; \
} \
vec64_t A_broadcast_strides, B_broadcast_strides, Y_dims; \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
A_dims, B_dims, A_broadcast_strides, B_broadcast_strides, Y_dims); \
_Broadcast##name( \
Y_dims.size(), \
......@@ -658,13 +658,13 @@ DEFINE_BINARY_FUNC(GreaterEqual, bool);
vec64_t A_broadcast_dims, B_broadcast_dims, C_broadcast_dims; \
vec64_t A_broadcast_strides, B_broadcast_strides, C_broadcast_strides; \
vec64_t Y_dims, _, __; \
utils::math::ComputeBinaryBroadcastStrides(A_dims, B_dims, _, _, __); \
utils::math::ComputeBinaryBroadcastStrides(C_dims, __, _, _, Y_dims); \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastStrides(A_dims, B_dims, _, _, __); \
math::utils::ComputeBinaryBroadcastStrides(C_dims, __, _, _, Y_dims); \
math::utils::ComputeBinaryBroadcastDims( \
A_dims, Y_dims, A_broadcast_dims, _); \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastDims( \
B_dims, Y_dims, B_broadcast_dims, _); \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastDims( \
C_dims, Y_dims, C_broadcast_dims, _); \
if (A_broadcast_dims == B_broadcast_dims && \
B_broadcast_dims == C_broadcast_dims) { \
......@@ -673,11 +673,11 @@ DEFINE_BINARY_FUNC(GreaterEqual, bool);
Where(count, a, b, c, y, ctx); \
return; \
} \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
A_dims, Y_dims, A_broadcast_strides, _, _); \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
B_dims, Y_dims, B_broadcast_strides, _, _); \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
C_dims, Y_dims, C_broadcast_strides, _, _); \
const int num_dims = Y_dims.size(); \
const auto count = std::accumulate( \
......@@ -692,7 +692,7 @@ DEFINE_BINARY_FUNC(GreaterEqual, bool);
ci += idx[d] * C_broadcast_strides[d]; \
} \
y[yi] = c[ci] ? a[ai] : b[bi]; \
utils::math::IncreaseIndexInDims(num_dims, Y_dims.data(), idx.data()); \
math::utils::IncreaseIndexInDims(num_dims, Y_dims.data(), idx.data()); \
} \
}
......
......@@ -172,7 +172,7 @@ __global__ void _BroadcastWhere(
vec64_t X_dims(x_dims, x_dims + x_ndim); \
vec64_t Y_dims(y_dims, y_dims + y_ndim); \
vec64_t X_broadcast_dims, Y_broadcast_dims; \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastDims( \
X_dims, Y_dims, X_broadcast_dims, Y_broadcast_dims); \
if (X_broadcast_dims == Y_broadcast_dims) { \
auto count = std::accumulate( \
......@@ -180,7 +180,7 @@ __global__ void _BroadcastWhere(
Copy(count, x, y, ctx); \
return; \
} \
if (utils::math::IsRowwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \
if (math::utils::IsRowwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \
const auto nthreads = rows * cols; \
_RowwiseSet<<< \
CUDA_BLOCKS(nthreads), \
......@@ -193,7 +193,7 @@ __global__ void _BroadcastWhere(
reinterpret_cast<T2*>(y)); \
return; \
} \
if (utils::math::IsColwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \
if (math::utils::IsColwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \
const auto nthreads = rows * cols; \
_ColwiseSet<<< \
CUDA_BLOCKS(nthreads), \
......@@ -208,7 +208,7 @@ __global__ void _BroadcastWhere(
} \
vec64_t X_broadcast_strides, _; \
CUDA_TENSOR_DIMS_CHECK((int)Y_dims.size()); \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
X_dims, Y_dims, X_broadcast_strides, _, _); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> strides, dims; \
const auto nthreads = std::accumulate( \
......@@ -255,7 +255,7 @@ DEFINE_SET_FUNC(double, double);
vec64_t A_dims(a_dims, a_dims + a_ndim); \
vec64_t B_dims(b_dims, b_dims + b_ndim); \
vec64_t A_broadcast_dims, B_broadcast_dims; \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastDims( \
A_dims, B_dims, A_broadcast_dims, B_broadcast_dims); \
if (A_broadcast_dims == B_broadcast_dims) { \
auto count = std::accumulate( \
......@@ -263,7 +263,7 @@ DEFINE_SET_FUNC(double, double);
name(count, a, b, y, ctx); \
return; \
} \
if (utils::math::IsRowwiseBroadcast( \
if (math::utils::IsRowwiseBroadcast( \
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
const auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
......@@ -277,7 +277,7 @@ DEFINE_SET_FUNC(double, double);
} \
return; \
} \
if (utils::math::IsColwiseBroadcast( \
if (math::utils::IsColwiseBroadcast( \
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
const auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
......@@ -292,7 +292,7 @@ DEFINE_SET_FUNC(double, double);
return; \
} \
vec64_t A_broadcast_strides, B_broadcast_strides, Y_dims; \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
A_dims, B_dims, A_broadcast_strides, B_broadcast_strides, Y_dims); \
CUDA_TENSOR_DIMS_CHECK((int)Y_dims.size()); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> a_strides, b_strides, y_dims; \
......@@ -434,7 +434,7 @@ DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
vec64_t A_dims(a_dims, a_dims + a_ndim); \
vec64_t B_dims(b_dims, b_dims + b_ndim); \
vec64_t A_broadcast_dims, B_broadcast_dims; \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastDims( \
A_dims, B_dims, A_broadcast_dims, B_broadcast_dims); \
if (A_broadcast_dims == B_broadcast_dims) { \
auto count = std::accumulate( \
......@@ -442,7 +442,7 @@ DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
name(count, a, b, y, ctx); \
return; \
} \
if (utils::math::IsRowwiseBroadcast( \
if (math::utils::IsRowwiseBroadcast( \
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
......@@ -466,7 +466,7 @@ DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
} \
return; \
} \
if (utils::math::IsColwiseBroadcast( \
if (math::utils::IsColwiseBroadcast( \
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
......@@ -491,7 +491,7 @@ DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
return; \
} \
vec64_t A_broadcast_strides, B_broadcast_strides, Y_dims; \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
A_dims, B_dims, A_broadcast_strides, B_broadcast_strides, Y_dims); \
CUDA_TENSOR_DIMS_CHECK((int)Y_dims.size()); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> a_strides, b_strides, y_dims; \
......@@ -550,13 +550,13 @@ DEFINE_BINARY_FUNC(GreaterEqual, bool, bool, math::GreaterEqualFunctor);
vec64_t A_broadcast_dims, B_broadcast_dims, C_broadcast_dims; \
vec64_t A_broadcast_strides, B_broadcast_strides, C_broadcast_strides; \
vec64_t Y_dims, _, __; \
utils::math::ComputeBinaryBroadcastStrides(A_dims, B_dims, _, _, __); \
utils::math::ComputeBinaryBroadcastStrides(C_dims, __, _, _, Y_dims); \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastStrides(A_dims, B_dims, _, _, __); \
math::utils::ComputeBinaryBroadcastStrides(C_dims, __, _, _, Y_dims); \
math::utils::ComputeBinaryBroadcastDims( \
A_dims, Y_dims, A_broadcast_dims, _); \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastDims( \
B_dims, Y_dims, B_broadcast_dims, _); \
utils::math::ComputeBinaryBroadcastDims( \
math::utils::ComputeBinaryBroadcastDims( \
C_dims, Y_dims, C_broadcast_dims, _); \
if (A_broadcast_dims == B_broadcast_dims && \
B_broadcast_dims == C_broadcast_dims) { \
......@@ -566,11 +566,11 @@ DEFINE_BINARY_FUNC(GreaterEqual, bool, bool, math::GreaterEqualFunctor);
return; \
} \
CUDA_TENSOR_DIMS_CHECK((int)Y_dims.size()); \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
A_dims, Y_dims, A_broadcast_strides, _, _); \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
B_dims, Y_dims, B_broadcast_strides, _, _); \
utils::math::ComputeBinaryBroadcastStrides( \
math::utils::ComputeBinaryBroadcastStrides( \
C_dims, Y_dims, C_broadcast_strides, _, _); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> a_strides, b_strides, c_strides; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> y_dims; \
......
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/math/elementwise.h"
#include "dragon/utils/omp_utils.h"
namespace dragon {
......@@ -14,7 +14,7 @@ void _Cast(const int n, const Tx* x, Ty* y) {
#pragma omp parallel for num_threads(OMP_THREADS(n))
#endif
for (int i = 0; i < n; ++i) {
y[i] = cast::to<Ty>(x[i]);
y[i] = convert::To<Ty>(x[i]);
}
}
......
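The cast::to -> convert::To rename above is what the pseudo-FP16 kernels lean on: compute in FP32, then narrow back. A minimal sketch of that pattern, mirroring the CPU float16 functors later in this diff (the function name is illustrative):

  inline float16 AddFP16(const float16 a, const float16 b) {
    // widen both operands to float, accumulate, then narrow back to float16
    const float acc = convert::To<float>(a) + convert::To<float>(b);
    return convert::To<float16>(acc);
  }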
#include "dragon/utils/math/elementwise.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/math_functions.h"
namespace dragon {
......@@ -254,7 +254,7 @@ DRAGON_API void IsInf<float16, CPUContext>(
bool* y,
CPUContext* ctx) {
for (int i = 0; i < n; ++i) {
y[i] = utils::math::IsInf(x[i]);
y[i] = math::utils::IsInf(x[i]);
}
}
......@@ -279,7 +279,7 @@ DRAGON_API void IsNaN<float16, CPUContext>(
bool* y,
CPUContext* ctx) {
for (int i = 0; i < n; ++i) {
y[i] = utils::math::IsNaN(x[i]);
y[i] = math::utils::IsNaN(x[i]);
}
}
......@@ -306,7 +306,7 @@ DRAGON_API void ReplaceNaN<float16, CPUContext>(
CPUContext* ctx) {
EigenVectorArrayMap<float16>(y, n) =
ConstEigenVectorArrayMap<float16>(x, n).unaryExpr(
[&](float16 x) { return utils::math::IsNaN(x) ? value : x; });
[&](float16 x) { return math::utils::IsNaN(x) ? value : x; });
}
DEFINE_REPLACE_NAN_FUNC(float);
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/math/elementwise.h"
#include "dragon/utils/math/functional.h"
#include "dragon/utils/math/utils.h"
......@@ -233,7 +232,7 @@ __global__ void _Set(const int n, const T alpha, T* x) {
template <typename T>
__global__ void _Sign(const int n, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = utils::math::Sign(x[i]);
y[i] = math::utils::Sign(x[i]);
}
}
......@@ -248,7 +247,7 @@ template <>
__global__ void _Sign<half>(const int n, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
const float val = __half2float(x[i]);
y[i] = __float2half(utils::math::Sign(val));
y[i] = __float2half(math::utils::Sign(val));
}
}
......@@ -257,14 +256,14 @@ __global__ void _Sign<half2>(const int n, const half2* x, half2* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
const float2 val = __half22float2(x[i]);
y[i] =
__floats2half2_rn(utils::math::Sign(val.x), utils::math::Sign(val.y));
__floats2half2_rn(math::utils::Sign(val.x), math::utils::Sign(val.y));
}
}
template <typename T>
__global__ void _Square(const int n, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = utils::math::Square(x[i]);
y[i] = math::utils::Square(x[i]);
}
}
......@@ -289,28 +288,28 @@ __global__ void _NotZero<half>(const int nthreads, const half* x, bool* y) {
template <typename T>
__global__ void _IsInf(const int n, const T* x, bool* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = utils::math::IsInf(x[i]);
y[i] = math::utils::IsInf(x[i]);
}
}
template <>
__global__ void _IsInf<half>(const int n, const half* x, bool* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = utils::math::IsInf(x[i]);
y[i] = math::utils::IsInf(x[i]);
}
}
template <typename T>
__global__ void _IsNaN(const int n, const T* x, bool* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = utils::math::IsNaN(x[i]);
y[i] = math::utils::IsNaN(x[i]);
}
}
template <>
__global__ void _IsNaN<half>(const int n, const half* x, bool* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = utils::math::IsNaN(x[i]);
y[i] = math::utils::IsNaN(x[i]);
}
}
......@@ -318,9 +317,9 @@ template <typename T>
__global__ void _ReplaceNaN(const int n, const T value, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
#if __CUDA_ARCH__ >= 350
y[i] = utils::math::IsNaN(__ldg(x + i)) ? value : __ldg(x + i);
y[i] = math::utils::IsNaN(__ldg(x + i)) ? value : __ldg(x + i);
#else
y[i] = utils::math::IsNaN(x[i]) ? value : x[i];
y[i] = math::utils::IsNaN(x[i]) ? value : x[i];
#endif
}
}
......@@ -330,9 +329,9 @@ __global__ void
_ReplaceNaN<half>(const int n, const half value, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
#if __CUDA_ARCH__ >= 350
y[i] = utils::math::IsNaN(__ldg(x + i)) ? value : __ldg(x + i);
y[i] = math::utils::IsNaN(__ldg(x + i)) ? value : __ldg(x + i);
#else
y[i] = utils::math::IsNaN(x[i]) ? value : x[i];
y[i] = math::utils::IsNaN(x[i]) ? value : x[i];
#endif
}
}
......@@ -526,7 +525,7 @@ DRAGON_API void Set<float16, CUDAContext>(
}
if ((n & 1) == 0) {
_Set<<<CUDA_BLOCKS(n >> 1), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n >> 1, cast::to<half2>(value), reinterpret_cast<half2*>(y));
n >> 1, convert::To<half2>(value), reinterpret_cast<half2*>(y));
} else {
_Set<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(n, value, y);
}
......@@ -561,13 +560,13 @@ DRAGON_API void InvStd<float16, CUDAContext>(
if ((n & 1) == 0) {
_InvStd<<<CUDA_BLOCKS(n >> 1), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n >> 1,
cast::to<half2>(eps),
convert::To<half2>(eps),
reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y));
} else {
_InvStd<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n,
cast::to<half>(eps),
convert::To<half>(eps),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
......@@ -707,7 +706,7 @@ DRAGON_API void ReplaceNaN<float16, CUDAContext>(
CUDAContext* ctx) {
_ReplaceNaN<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n,
cast::to<half>(value),
convert::To<half>(value),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
......@@ -738,14 +737,14 @@ DRAGON_API void Bias<float16, CUDAContext>(
if ((n & 1) == 0) {
_Bias<<<CUDA_BLOCKS(n >> 1), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n >> 1,
cast::to<half2>(beta),
convert::To<half2>(beta),
math::PlusFunctor<half2>(),
reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y));
} else {
_Bias<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n,
cast::to<half>(beta),
convert::To<half>(beta),
math::PlusFunctor<half>(),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
......
......@@ -14,7 +14,7 @@
#define DRAGON_UTILS_MATH_FUNCTIONAL_H_
#include "dragon/core/types.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
namespace dragon {
......@@ -27,7 +27,7 @@ namespace math {
template <typename T>
struct MaxFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs < rhs ? rhs : lhs;
}
#else
......@@ -40,7 +40,7 @@ struct MaxFunctor {
template <>
struct MaxFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hlt(
......@@ -57,7 +57,7 @@ struct MaxFunctor<float16> {
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) < cast::to<float>(rhs) ? rhs : lhs;
return convert::To<float>(lhs) < convert::To<float>(rhs) ? rhs : lhs;
}
#endif
};
......@@ -65,8 +65,7 @@ struct MaxFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct MaxFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
inline __device__ half operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hlt(lhs, rhs) ? rhs : lhs;
#else
......@@ -77,8 +76,7 @@ struct MaxFunctor<half> {
template <>
struct MaxFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
inline __device__ half2 operator()(const half2& lhs, const half2& rhs) const {
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(
......@@ -90,7 +88,7 @@ struct MaxFunctor<half2> {
template <typename T>
struct MinFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs < rhs ? lhs : rhs;
}
#else
......@@ -103,7 +101,7 @@ struct MinFunctor {
template <>
struct MinFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hlt(
......@@ -120,7 +118,7 @@ struct MinFunctor<float16> {
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) < cast::to<float>(rhs) ? lhs : rhs;
return convert::To<float>(lhs) < convert::To<float>(rhs) ? lhs : rhs;
}
#endif
};
......@@ -128,8 +126,7 @@ struct MinFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct MinFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
inline __device__ half operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hlt(lhs, rhs) ? lhs : rhs;
#else
......@@ -140,8 +137,7 @@ struct MinFunctor<half> {
template <>
struct MinFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
inline __device__ half2 operator()(const half2& lhs, const half2& rhs) const {
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(
......@@ -153,7 +149,7 @@ struct MinFunctor<half2> {
template <typename T>
struct PlusFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs + rhs;
}
#else
......@@ -166,7 +162,7 @@ struct PlusFunctor {
template <>
struct PlusFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
half ret = __hadd(
......@@ -181,7 +177,8 @@ struct PlusFunctor<float16> {
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float16>(cast::to<float>(lhs) + cast::to<float>(rhs));
return convert::To<float16>(
convert::To<float>(lhs) + convert::To<float>(rhs));
}
#endif
};
......@@ -189,8 +186,7 @@ struct PlusFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct PlusFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
inline __device__ half operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hadd(lhs, rhs);
#else
......@@ -201,8 +197,7 @@ struct PlusFunctor<half> {
template <>
struct PlusFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
inline __device__ half2 operator()(const half2& lhs, const half2& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hadd2(lhs, rhs);
#else
......@@ -217,7 +212,7 @@ struct PlusFunctor<half2> {
template <typename T>
struct MinusFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs - rhs;
}
#else
......@@ -230,7 +225,7 @@ struct MinusFunctor {
template <>
struct MinusFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
half ret = __hsub(
......@@ -245,7 +240,8 @@ struct MinusFunctor<float16> {
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float16>(cast::to<float>(lhs) - cast::to<float>(rhs));
return convert::To<float16>(
convert::To<float>(lhs) - convert::To<float>(rhs));
}
#endif
};
......@@ -253,8 +249,7 @@ struct MinusFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct MinusFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
inline __device__ half operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hsub(lhs, rhs);
#else
......@@ -265,8 +260,7 @@ struct MinusFunctor<half> {
template <>
struct MinusFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
inline __device__ half2 operator()(const half2& lhs, const half2& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hsub2(lhs, rhs);
#else
......@@ -281,7 +275,7 @@ struct MinusFunctor<half2> {
template <typename T>
struct MultipliesFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs * rhs;
}
#else
......@@ -294,7 +288,7 @@ struct MultipliesFunctor {
template <>
struct MultipliesFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
half ret = __hmul(
......@@ -309,7 +303,8 @@ struct MultipliesFunctor<float16> {
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float16>(cast::to<float>(lhs) * cast::to<float>(rhs));
return convert::To<float16>(
convert::To<float>(lhs) * convert::To<float>(rhs));
}
#endif
};
......@@ -317,8 +312,7 @@ struct MultipliesFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct MultipliesFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
inline __device__ half operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hmul(lhs, rhs);
#else
......@@ -329,8 +323,7 @@ struct MultipliesFunctor<half> {
template <>
struct MultipliesFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
inline __device__ half2 operator()(const half2& lhs, const half2& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hmul2(lhs, rhs);
#else
......@@ -345,7 +338,7 @@ struct MultipliesFunctor<half2> {
template <typename T>
struct DividesFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
inline __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs / rhs;
}
#else
......@@ -358,7 +351,7 @@ struct DividesFunctor {
template <>
struct DividesFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
half ret = __hdiv(
......@@ -373,7 +366,8 @@ struct DividesFunctor<float16> {
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float16>(cast::to<float>(lhs) / cast::to<float>(rhs));
return convert::To<float16>(
convert::To<float>(lhs) / convert::To<float>(rhs));
}
#endif
};
......@@ -381,8 +375,7 @@ struct DividesFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct DividesFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
inline __device__ half operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hdiv(lhs, rhs);
#else
......@@ -393,8 +386,7 @@ struct DividesFunctor<half> {
template <>
struct DividesFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
inline __device__ half2 operator()(const half2& lhs, const half2& rhs) const {
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(v1.x / v2.x, v1.y / v2.y);
......@@ -405,7 +397,7 @@ struct DividesFunctor<half2> {
template <typename T>
struct PowFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
inline __device__ T operator()(const T& lhs, const T& rhs) const {
return pow(lhs, rhs);
}
#else
......@@ -418,7 +410,7 @@ struct PowFunctor {
template <>
struct PowFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
inline __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
half ret = __float2half(
pow(__half2float(*reinterpret_cast<const half*>(&lhs)),
......@@ -427,8 +419,8 @@ struct PowFunctor<float16> {
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float16>(
std::pow(cast::to<float>(lhs), cast::to<float>(rhs)));
return convert::To<float16>(
std::pow(convert::To<float>(lhs), convert::To<float>(rhs)));
}
#endif
};
......@@ -436,16 +428,14 @@ struct PowFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct PowFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
inline __device__ half operator()(const half& lhs, const half& rhs) const {
return __float2half(pow(__half2float(lhs), __half2float(rhs)));
}
};
template <>
struct PowFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
inline __device__ half2 operator()(const half2& lhs, const half2& rhs) const {
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(pow(v1.x, v2.x), pow(v1.y, v2.y));
......@@ -460,7 +450,7 @@ struct PowFunctor<half2> {
template <typename T>
struct EqualFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs == rhs;
}
#else
......@@ -473,9 +463,8 @@ struct EqualFunctor {
template <>
struct EqualFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __heq(
*reinterpret_cast<const half*>(&lhs),
......@@ -487,7 +476,7 @@ struct EqualFunctor<float16> {
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) == cast::to<float>(rhs);
return convert::To<float>(lhs) == convert::To<float>(rhs);
}
#endif
};
......@@ -495,8 +484,7 @@ struct EqualFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct EqualFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
inline __device__ bool operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __heq(lhs, rhs);
#else
......@@ -509,7 +497,7 @@ struct EqualFunctor<half> {
template <typename T>
struct NotEqualFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs != rhs;
}
#else
......@@ -522,9 +510,8 @@ struct NotEqualFunctor {
template <>
struct NotEqualFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __hne(
*reinterpret_cast<const half*>(&lhs),
......@@ -536,7 +523,7 @@ struct NotEqualFunctor<float16> {
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) != cast::to<float>(rhs);
return convert::To<float>(lhs) != convert::To<float>(rhs);
}
#endif
};
......@@ -544,8 +531,7 @@ struct NotEqualFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct NotEqualFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
inline __device__ bool operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hne(lhs, rhs);
#else
......@@ -558,7 +544,7 @@ struct NotEqualFunctor<half> {
template <typename T>
struct GreaterFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs > rhs;
}
#else
......@@ -571,9 +557,8 @@ struct GreaterFunctor {
template <>
struct GreaterFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __hgt(
*reinterpret_cast<const half*>(&lhs),
......@@ -585,7 +570,7 @@ struct GreaterFunctor<float16> {
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) > cast::to<float>(rhs);
return convert::To<float>(lhs) > convert::To<float>(rhs);
}
#endif
};
......@@ -593,8 +578,7 @@ struct GreaterFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct GreaterFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
inline __device__ bool operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hgt(lhs, rhs);
#else
......@@ -607,7 +591,7 @@ struct GreaterFunctor<half> {
template <typename T>
struct LessFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs < rhs;
}
#else
......@@ -620,9 +604,8 @@ struct LessFunctor {
template <>
struct LessFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __hlt(
*reinterpret_cast<const half*>(&lhs),
......@@ -634,7 +617,7 @@ struct LessFunctor<float16> {
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) < cast::to<float>(rhs);
return convert::To<float>(lhs) < convert::To<float>(rhs);
}
#endif
};
......@@ -642,8 +625,7 @@ struct LessFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct LessFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
inline __device__ bool operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hlt(lhs, rhs);
#else
......@@ -656,7 +638,7 @@ struct LessFunctor<half> {
template <typename T>
struct GreaterEqualFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs >= rhs;
}
#else
......@@ -669,9 +651,8 @@ struct GreaterEqualFunctor {
template <>
struct GreaterEqualFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __hge(
*reinterpret_cast<const half*>(&lhs),
......@@ -683,7 +664,7 @@ struct GreaterEqualFunctor<float16> {
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) >= cast::to<float>(rhs);
return convert::To<float>(lhs) >= convert::To<float>(rhs);
}
#endif
};
......@@ -691,8 +672,7 @@ struct GreaterEqualFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct GreaterEqualFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
inline __device__ bool operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hge(lhs, rhs);
#else
......@@ -705,7 +685,7 @@ struct GreaterEqualFunctor<half> {
template <typename T>
struct LessEqualFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
inline __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs <= rhs;
}
#else
......@@ -718,9 +698,8 @@ struct LessEqualFunctor {
template <>
struct LessEqualFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
inline __device__ bool operator()(const float16& lhs, const float16& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __hle(
*reinterpret_cast<const half*>(&lhs),
......@@ -732,7 +711,7 @@ struct LessEqualFunctor<float16> {
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) <= cast::to<float>(rhs);
return convert::To<float>(lhs) <= convert::To<float>(rhs);
}
#endif
};
......@@ -740,8 +719,7 @@ struct LessEqualFunctor<float16> {
#if defined(__CUDACC__)
template <>
struct LessEqualFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
inline __device__ bool operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hle(lhs, rhs);
#else
......
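These functors are passed to kernels as plain callables (e.g. math::PlusFunctor<half2>() in the _Bias launcher above). A hedged sketch of the CPU-side usage pattern, with names chosen purely for illustration:

  template <typename T, class Functor>
  void ApplyBinary(const int n, const T* a, const T* b, T* y, Functor op) {
    for (int i = 0; i < n; ++i) {
      y[i] = op(a[i], b[i]);  // operator()(lhs, rhs) as defined above
    }
  }
  // e.g. ApplyBinary(n, a, b, y, math::PlusFunctor<float16>());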
#include "dragon/utils/math/reduce.h"
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/device/common_eigen.h"
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/math/utils.h"
#include "dragon/utils/omp_utils.h"
namespace dragon {
......@@ -169,22 +169,22 @@ void _GenericReduceSum(
for (int i = 0; i < num_axes; ++i) \
y_dims[axes[i]] = 1; \
/* Case #1: Rowwise Reduce */ \
if (utils::math::IsRowwiseReduce( \
if (math::utils::IsRowwiseReduce( \
num_dims, dims, y_dims.data(), &rows, &cols)) { \
_RowwiseReduce##name(rows, cols, scale, x, y); \
return; \
} \
/* Case #2: Colwise Reduce */ \
if (utils::math::IsColwiseReduce( \
if (math::utils::IsColwiseReduce( \
num_dims, dims, y_dims.data(), &rows, &cols)) { \
_ColwiseReduce##name(rows, cols, scale, x, y); \
return; \
} \
/* Case #3: Generic Reduce */ \
vec32_t axesT(num_dims), stridesT(num_dims), dimsT(num_dims); \
utils::math::TransposeAxesForReduce( \
math::utils::TransposeAxesForReduce( \
num_dims, num_axes, axes, axesT.data()); \
utils::math::ComputeTransposeStrides( \
math::utils::ComputeTransposeStrides( \
num_dims, dims, axesT.data(), stridesT.data()); \
rows = cols = 1; \
const int pivot = num_dims - num_axes; \
......
......@@ -188,7 +188,7 @@ __global__ void _GenericReduce(
y, \
count, \
reducer, \
cast::to<T>(init), \
convert::To<T>(init), \
ctx->cuda_stream()); \
cub::DeviceReduce::Reduce( \
ctx->workspace()->data<CUDAContext>({ws_nbytes}, "data:1")[0], \
......@@ -197,7 +197,7 @@ __global__ void _GenericReduce(
y, \
count, \
reducer, \
cast::to<T>(init), \
convert::To<T>(init), \
ctx->cuda_stream()); \
return 0; \
} \
......@@ -206,33 +206,33 @@ __global__ void _GenericReduce(
for (int i = 0; i < num_axes; ++i) { \
out_dims[axes[i]] = 1; \
} \
if (utils::math::IsRowwiseReduce( \
if (math::utils::IsRowwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_RowwiseReduce<<< \
CUDA_2D_BLOCKS(cols), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
rows, cols, reducer, init, cast::to<T>(scale), x, y); \
rows, cols, reducer, init, convert::To<T>(scale), x, y); \
return 1; \
} \
if (utils::math::IsColwiseReduce( \
if (math::utils::IsColwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_ColwiseReduce<<< \
CUDA_2D_BLOCKS(rows), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
rows, cols, reducer, init, cast::to<T>(scale), x, y); \
rows, cols, reducer, init, convert::To<T>(scale), x, y); \
return 2; \
} \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_axes; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_strides; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_dims; \
utils::math::TransposeAxesForReduce( \
math::utils::TransposeAxesForReduce( \
num_dims, num_axes, axes, transpose_axes.data); \
utils::math::ComputeTransposeStrides( \
math::utils::ComputeTransposeStrides( \
num_dims, dims, transpose_axes.data, transpose_strides.data); \
rows = cols = 1; \
const int pivot = num_dims - num_axes; \
......@@ -257,7 +257,7 @@ __global__ void _GenericReduce(
transpose_strides, \
reducer, \
init, \
cast::to<T>(scale), \
convert::To<T>(scale), \
x, \
y); \
return 3; \
......@@ -293,7 +293,7 @@ void ReduceSum<float16, CUDAContext>(
for (int i = 0; i < num_axes; ++i) {
out_dims[axes[i]] = 1;
}
if (utils::math::IsRowwiseReduce(
if (math::utils::IsRowwiseReduce(
num_dims, dims, out_dims.data(), &rows, &cols)) {
_RowwiseReduce<<<
CUDA_2D_BLOCKS(cols),
......@@ -309,7 +309,7 @@ void ReduceSum<float16, CUDAContext>(
reinterpret_cast<half*>(y));
return;
}
if (utils::math::IsColwiseReduce(
if (math::utils::IsColwiseReduce(
num_dims, dims, out_dims.data(), &rows, &cols)) {
_ColwiseReduce<<<
CUDA_2D_BLOCKS(rows),
......@@ -329,9 +329,9 @@ void ReduceSum<float16, CUDAContext>(
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_axes;
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_strides;
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_dims;
utils::math::TransposeAxesForReduce(
math::utils::TransposeAxesForReduce(
num_dims, num_axes, axes, transpose_axes.data);
utils::math::ComputeTransposeStrides(
math::utils::ComputeTransposeStrides(
num_dims, dims, transpose_axes.data, transpose_strides.data);
rows = cols = 1;
const int pivot = num_dims - num_axes;
......@@ -408,7 +408,7 @@ DEFINE_KERNEL_LAUNCHER(
Max,
float16,
math::MaxFunctor,
cast::to<float16>(cub::Traits<half>::Lowest()));
convert::To<float16>(cub::Traits<half>::Lowest()));
DEFINE_KERNEL_LAUNCHER(
Max,
float,
......@@ -443,7 +443,7 @@ DEFINE_KERNEL_LAUNCHER(
Min,
float16,
math::MinFunctor,
cast::to<float16>(cub::Traits<half>::Max()));
convert::To<float16>(cub::Traits<half>::Max()));
DEFINE_KERNEL_LAUNCHER(
Min,
float,
......
......@@ -7,10 +7,10 @@
namespace dragon {
namespace utils {
namespace math {
namespace utils {
bool IsRowwiseBroadcast(
const vec64_t& A_dims,
const vec64_t& B_dims,
......@@ -279,8 +279,8 @@ void ComputeTransposeStrides(
}
}
} // namespace math
} // namespace utils
} // namespace math
} // namespace dragon
......@@ -13,11 +13,9 @@
#ifndef DRAGON_UTILS_MATH_UTILS_H_
#define DRAGON_UTILS_MATH_UTILS_H_
#include "dragon/core/context.h"
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/conversions.h"
#ifdef USE_CUDA
#if defined(__CUDACC__)
#define MATH_UTILS_DECL inline __host__ __device__
#else
#define MATH_UTILS_DECL inline
......@@ -32,10 +30,10 @@
namespace dragon {
namespace utils {
namespace math {
namespace utils {
template <
typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
......@@ -73,11 +71,11 @@ MATH_UTILS_DECL bool IsNaN(T x) {
}
inline bool IsInf(float16 x) {
return std::abs(cast::to<float>(x)) > HFLT_MAX;
return std::abs(convert::To<float>(x)) > HFLT_MAX;
}
inline bool IsNaN(float16 x) {
return IsNaN(cast::to<float>(x));
return IsNaN(convert::To<float>(x));
}
template <typename T>
......@@ -249,10 +247,10 @@ inline void IncreaseIndexInDims(const int num_dims, const T* dims, T* index) {
}
}
} // namespace math
} // namespace utils
} // namespace math
} // namespace dragon
#endif // DRAGON_UTILS_MATH_UTILS_H_
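The broadcast and reduce kernels above walk a row-major index with math::utils::IncreaseIndexInDims; a minimal sketch of that usage (the shape is illustrative):

  vec64_t dims = {2, 3, 4};     // illustrative shape
  vec64_t idx(dims.size(), 0);  // starts at (0, 0, 0)
  for (int i = 0; i < 2 * 3 * 4; ++i) {
    // ... compute strided offsets from idx, as the kernels above do ...
    math::utils::IncreaseIndexInDims((int)dims.size(), dims.data(), idx.data());
  }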
......@@ -844,117 +844,117 @@ void SinGrad(const int count, const T* dy, const T* x, T* dx, Context* ctx);
/* normalization.batch_norm */
template <typename Tx, typename Tp, class Context>
void BatchNormExpectation(
template <typename T, typename AccT, class Context>
void BatchNorm(
const int N,
const int C,
const int S,
const Tp denorm,
const string& data_format,
const Tx* x,
Tp* ex,
Tp* ex2,
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const AccT* beta,
AccT* scale,
AccT* bias,
T* y,
Context* ctx);
template <typename Tx, typename Tp, class Context>
void BatchNormInternalGrad(
template <typename T, typename AccT, class Context>
void BatchNormExpectation(
const int N,
const int C,
const int S,
const AccT denorm,
const string& data_format,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tx* dy,
Tp* dgamma,
Tp* dbeta,
const T* x,
AccT* ex,
AccT* ex2,
Context* ctx);
template <typename Tx, typename Tp, class Context>
void BatchNormTrainingGrad(
template <typename T, typename AccT, class Context>
void BatchNormInternalGrad(
const int N,
const int C,
const int S,
const string& data_format,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tp* dgamma,
const Tp* dbeta,
const Tx* dy,
Tx* dx,
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const T* dy,
AccT* dgamma,
AccT* dbeta,
Context* ctx);
template <typename Tx, typename Tp, class Context>
void BatchNormBackwardTraining(
template <typename T, typename AccT, class Context>
void BatchNormTrainingGrad(
const int N,
const int C,
const int S,
const string& data_format,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tx* dy,
Tx* dx,
Tp* dgamma,
Tp* dbeta,
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const AccT* dgamma,
const AccT* dbeta,
const T* dy,
T* dx,
Context* ctx);
template <typename Tx, typename Tp, class Context>
void BatchNormBackwardInference(
template <typename T, typename AccT, class Context>
void BatchNormInferenceGrad(
const int N,
const int C,
const int S,
const string& data_format,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tx* dy,
Tx* dx,
Tp* dgamma,
Tp* dbeta,
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const T* dy,
AccT* dgamma,
AccT* dbeta,
T* dx,
Context* ctx);
/* normalization.group_norm */
template <typename Tx, typename Tp, class Context>
void GroupNormForward(
template <typename T, typename AccT, class Context>
void GroupNorm(
const int N,
const int G,
const int D,
const int S,
const string& data_format,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tp* beta,
Tp* scale,
Tp* bias,
Tx* y,
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const AccT* beta,
AccT* scale,
AccT* bias,
T* y,
Context* ctx);
template <typename Tx, typename Tp, class Context>
void GroupNormBackward(
template <typename T, typename AccT, class Context>
void GroupNormGrad(
const int N,
const int G,
const int D,
const int S,
const string& data_format,
const Tx* x,
const Tp* mu,
const Tp* rsig,
const Tp* gamma,
const Tx* dy,
Tp* ds,
Tp* db,
Tx* dx,
Tp* dgamma,
Tp* dbeta,
const T* x,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const T* dy,
AccT* ds,
AccT* db,
AccT* dgamma,
AccT* dbeta,
T* dx,
Context* ctx);
/* normalization.lp_norm */
......
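The renamed declarations above split the element type T from the accumulator type AccT, which is exactly what the pseudo-FP16 path needs: tensors in float16, statistics in float. A hedged sketch of an invocation under that assumption (pointer variables are placeholders, and the enclosing kernel namespace is omitted):

  // T = float16 data, AccT = float statistics; all pointers are placeholders.
  BatchNorm<float16, float, CUDAContext>(
      N, C, S, "NCHW",
      x, mu, rsig, gamma, beta,  // inputs: data plus precomputed statistics
      scale, bias, y,            // outputs: fused scale/bias and normalized data
      ctx);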
......@@ -49,13 +49,11 @@ def load_weights_from_pickle(f, layer, verbose=False):
if verbose:
logging.info(
'Weight({}) loaded, Size: ({})'
.format(name, ', '.join([str(d) for d in value_shape]))
)
.format(name, ', '.join([str(d) for d in value_shape])))
else:
logging.warning(
'Weight({}) is not created '
'in current workspace. Skip.'.format(name)
)
'in current workspace. Skip.'.format(name))
def save_weights_to_pickle(f, layer):
......
......@@ -49,8 +49,7 @@ def assign_weights(value_list, module):
if len(weight_list) != len(value_list):
raise ValueError(
'Expected %d values for weights, got %d.'
% (len(weight_list), len(value_list))
)
% (len(weight_list), len(value_list)))
for weight, value in zip(weight_list, value_list):
_set_value(weight, value)
matched_info.append((weight.name, weight.shape))
......@@ -143,8 +142,7 @@ def load_hdf5_to_weights(filepath, module, skip=False):
except Exception:
raise NameError(
"The loaded hdf5 file needs to have 'layer_names' as attributes.\n"
"Please check whether this hdf5 file is saved from TL."
)
"Please check whether this hdf5 file is saved from TL.")
matched_info = _load_weights_from_hdf5_group(f, module.modules, skip)
f.close()
return matched_info
......
......@@ -25,6 +25,14 @@ from dragon.core.testing.unittest.common_utils import run_tests
from dragon.core.testing.unittest.common_utils import TEST_CUDA
class TestContext(unittest.TestCase):
"""Test the framework context."""
def test_properties(self):
dragon.random.set_seed(1337)
dragon.set_num_threads(dragon.get_num_threads())
class TestGradientTape(unittest.TestCase):
"""Test the gradient tape."""
......
......@@ -133,8 +133,7 @@ def export(
if output_names is not None:
raise ValueError(
'Expected the output names from <outputs>.\n'
'You should set the <output_names> to None.'
)
'You should set the <output_names> to None.')
outputs, output_names = list(outputs.values()), list(outputs.keys())
else:
outputs = nest.flatten(outputs)
......