Commit 1bd78a3c by Ting PAN

Fix a numerical issue by breaking the union of CUB storages

Summary:
We found it unstable when defining CUB storages in a union.
More surveys should be taken to understand this issue.
1 parent ad83f4e4
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
# ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX # ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
# NAME: Kepler Maxwell Kepler+Tesla Maxwell+Tegra Pascal Volta Turing Ampere # NAME: Kepler Maxwell Kepler+Tesla Maxwell+Tegra Pascal Volta Turing Ampere
# NUM: Any number. Only those pairs are currently accepted by NVCC though: # NUM: Any number. Only those pairs are currently accepted by NVCC though:
# 3.5 3.7 5.0 5.2 5.3 6.0 6.1 6.2 7.0 7.2 7.5 8.0 # 3.5 3.7 5.0 5.2 5.3 6.0 6.1 6.2 7.0 7.2 7.5 8.0 8.6
# Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable} # Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
# Additionally, sets ${out_variable}_readable to the resulting numeric list # Additionally, sets ${out_variable}_readable to the resulting numeric list
# Example: # Example:
...@@ -84,10 +84,21 @@ endif() ...@@ -84,10 +84,21 @@ endif()
if(CUDA_VERSION VERSION_GREATER "10.5") if(CUDA_VERSION VERSION_GREATER "10.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere") list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0" "8.0+PTX") list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0") list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0")
if(CUDA_VERSION VERSION_LESS "11.1")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
endif()
endif()
if(NOT CUDA_VERSION VERSION_LESS "11.1")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6")
if(CUDA_VERSION VERSION_LESS "12.0") if(CUDA_VERSION VERSION_LESS "12.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0") set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0")
endif() endif()
endif() endif()
...@@ -150,7 +161,7 @@ function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE) ...@@ -150,7 +161,7 @@ function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
set(CUDA_GPU_DETECT_OUTPUT_FILTERED "") set(CUDA_GPU_DETECT_OUTPUT_FILTERED "")
separate_arguments(CUDA_GPU_DETECT_OUTPUT) separate_arguments(CUDA_GPU_DETECT_OUTPUT)
foreach(ITEM IN ITEMS ${CUDA_GPU_DETECT_OUTPUT}) foreach(ITEM IN ITEMS ${CUDA_GPU_DETECT_OUTPUT})
if(CUDA_LIMIT_GPU_ARCHITECTURE AND (ITEM VERSION_GREATER CUDA_LIMIT_GPU_ARCHITECTURE OR if(CUDA_LIMIT_GPU_ARCHITECTURE AND (ITEM VERSION_GREATER CUDA_LIMIT_GPU_ARCHITECTURE OR
ITEM VERSION_EQUAL CUDA_LIMIT_GPU_ARCHITECTURE)) ITEM VERSION_EQUAL CUDA_LIMIT_GPU_ARCHITECTURE))
list(GET CUDA_COMMON_GPU_ARCHITECTURES -1 NEWITEM) list(GET CUDA_COMMON_GPU_ARCHITECTURES -1 NEWITEM)
string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${NEWITEM}") string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${NEWITEM}")
...@@ -224,8 +235,8 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) ...@@ -224,8 +235,8 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
set(arch_bin 7.5) set(arch_bin 7.5)
set(arch_ptx 7.5) set(arch_ptx 7.5)
elseif(${arch_name} STREQUAL "Ampere") elseif(${arch_name} STREQUAL "Ampere")
set(arch_bin 8.0) set(arch_bin 8.0 8.6)
set(arch_ptx 8.0) set(arch_ptx 8.6)
else() else()
message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS") message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS")
endif() endif()
......
...@@ -15,14 +15,13 @@ endif() ...@@ -15,14 +15,13 @@ endif()
# ---[ Merge CUDA kernels to speed up compiling # ---[ Merge CUDA kernels to speed up compiling
if (USE_CUDA) if (USE_CUDA)
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/op_kernels.cu "") set(_gen_file ${CMAKE_CURRENT_BINARY_DIR}/../codegen/op_kernels.cu)
file(WRITE ${_gen_file} "")
foreach(_file ${KERNEL_CUDA_SOURCES}) foreach(_file ${KERNEL_CUDA_SOURCES})
file(STRINGS ${_file} tmp NEWLINE_CONSUME) file(STRINGS ${_file} tmp NEWLINE_CONSUME)
file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/op_kernels.cu ${tmp} "\n") file(APPEND ${_gen_file} ${tmp} "\n")
endforeach() endforeach()
set(MODULE_CUDA_SOURCES set(MODULE_CUDA_SOURCES ${MODULE_CUDA_SOURCES} ${_gen_file})
${MODULE_CUDA_SOURCES}
${CMAKE_CURRENT_BINARY_DIR}/op_kernels.cu)
endif() endif()
# ---[ Submit to the parent scope # ---[ Submit to the parent scope
......
...@@ -8,20 +8,20 @@ namespace kernel { ...@@ -8,20 +8,20 @@ namespace kernel {
namespace { namespace {
template <typename Tx, typename Ty> template <typename T, typename AccT>
void _RowwiseMoments( void _RowwiseMoments(
const int rows, const int rows,
const int cols, const int cols,
const Tx* x, const T* x,
Ty* mean, AccT* mean,
Ty* var) { AccT* var) {
const Ty scale = Ty(1) / (Ty)rows; const AccT scale = AccT(1) / AccT(rows);
#ifdef USE_OPENMP #ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(cols)) #pragma omp parallel for num_threads(OMP_THREADS(cols))
#endif #endif
for (int i = 0; i < cols; ++i) { for (int i = 0; i < cols; ++i) {
Tx x_val; T x_val;
Ty m_val = 0, v_val = 0, mu; AccT m_val = AccT(0), v_val = AccT(0), mu;
for (int j = 0; j < rows; ++j) { for (int j = 0; j < rows; ++j) {
x_val = x[j * cols + i]; x_val = x[j * cols + i];
m_val += x_val; m_val += x_val;
...@@ -32,20 +32,20 @@ void _RowwiseMoments( ...@@ -32,20 +32,20 @@ void _RowwiseMoments(
} }
} }
template <typename Tx, typename Ty> template <typename T, typename AccT>
void _ColwiseMoments( void _ColwiseMoments(
const int rows, const int rows,
const int cols, const int cols,
const Tx* x, const T* x,
Ty* mean, AccT* mean,
Ty* var) { AccT* var) {
const Ty scale = Ty(1) / (Ty)cols; const AccT scale = AccT(1) / AccT(cols);
#ifdef USE_OPENMP #ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(rows)) #pragma omp parallel for num_threads(OMP_THREADS(rows))
#endif #endif
for (int i = 0; i < rows; ++i) { for (int i = 0; i < rows; ++i) {
Tx x_val; T x_val;
Ty m_val = 0, v_val = 0, mu; AccT m_val = AccT(0), v_val = AccT(0), mu;
for (int j = 0; j < cols; ++j) { for (int j = 0; j < cols; ++j) {
x_val = x[i * cols + j]; x_val = x[i * cols + j];
m_val += x_val; m_val += x_val;
...@@ -56,23 +56,23 @@ void _ColwiseMoments( ...@@ -56,23 +56,23 @@ void _ColwiseMoments(
} }
} }
template <typename Tx, typename Ty> template <typename T, typename AccT>
void _GenericMoments( void _GenericMoments(
const int rows, const int rows,
const int cols, const int cols,
const int num_dims, const int num_dims,
const int* x_dims, const int* x_dims,
const int* x_strides, const int* x_strides,
const Tx* x, const T* x,
Ty* mean, AccT* mean,
Ty* var) { AccT* var) {
const Ty scale = Ty(1) / (Ty)cols; const AccT scale = AccT(1) / AccT(cols);
#ifdef USE_OPENMP #ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(rows)) #pragma omp parallel for num_threads(OMP_THREADS(rows))
#endif #endif
for (int i = 0; i < rows; ++i) { for (int i = 0; i < rows; ++i) {
Tx x_val; T x_val;
Ty m_val = 0, v_val = 0, mu; AccT m_val = AccT(0), v_val = AccT(0), mu;
int xi, c, r; int xi, c, r;
for (int j = 0; j < cols; ++j) { for (int j = 0; j < cols; ++j) {
xi = 0; xi = 0;
...@@ -90,52 +90,58 @@ void _GenericMoments( ...@@ -90,52 +90,58 @@ void _GenericMoments(
} }
} }
template <typename Tx, typename Ty> template <typename T, typename AccT>
void _Moments( void _Moments(
const int num_dims, const int num_dims,
const int* dims, const int* dims,
const int num_axes, const int num_axes,
const int* axes, const int* axes,
const Tx* x, const T* x,
Ty* mean, AccT* mean,
Ty* var, AccT* var,
CPUContext* ctx) { CPUContext* ctx) {
int rows, cols; int rows, cols;
vec32_t y_dims(dims, dims + num_dims); vec32_t out_dims(dims, dims + num_dims);
for (int i = 0; i < num_axes; ++i) for (int i = 0; i < num_axes; ++i) {
y_dims[axes[i]] = 1; out_dims[axes[i]] = 1;
}
// Case #1: Rowwise Reduce
if (math::utils::IsRowwiseReduce( if (math::utils::IsRowwiseReduce(
num_dims, dims, y_dims.data(), &rows, &cols)) { num_dims, dims, out_dims.data(), &rows, &cols)) {
_RowwiseMoments(rows, cols, x, mean, var); _RowwiseMoments(rows, cols, x, mean, var);
return; return;
} }
// Case #2: Colwise Reduce
if (math::utils::IsColwiseReduce( if (math::utils::IsColwiseReduce(
num_dims, dims, y_dims.data(), &rows, &cols)) { num_dims, dims, out_dims.data(), &rows, &cols)) {
_ColwiseMoments(rows, cols, x, mean, var); _ColwiseMoments(rows, cols, x, mean, var);
return; return;
} }
vec32_t transpose_axes(num_dims);
// Case #3: Generic Reduce vec32_t transpose_strides(num_dims);
vec32_t axesT(num_dims), stridesT(num_dims), dimsT(num_dims); vec32_t transpose_dims(num_dims);
math::utils::TransposeAxesForReduce(num_dims, num_axes, axes, axesT.data()); math::utils::TransposeAxesForReduce(
num_dims, num_axes, axes, transpose_axes.data());
math::utils::ComputeTransposeStrides( math::utils::ComputeTransposeStrides(
num_dims, dims, axesT.data(), stridesT.data()); num_dims, dims, transpose_axes.data(), transpose_strides.data());
rows = cols = 1; rows = cols = 1;
const int pivot = num_dims - num_axes; const int pivot = num_dims - num_axes;
for (int i = 0; i < pivot; ++i) for (int i = 0; i < pivot; ++i) {
rows *= dims[axesT[i]]; rows *= dims[transpose_axes[i]];
for (int i = pivot; i < num_dims; ++i) }
cols *= dims[axesT[i]]; for (int i = pivot; i < num_dims; ++i) {
for (int i = 0; i < num_dims; ++i) cols *= dims[transpose_axes[i]];
dimsT[i] = dims[axesT[i]]; }
for (int i = 0; i < num_dims; ++i) {
transpose_dims[i] = dims[transpose_axes[i]];
}
_GenericMoments( _GenericMoments(
rows, cols, num_dims, dimsT.data(), stridesT.data(), x, mean, var); rows,
cols,
num_dims,
transpose_dims.data(),
transpose_strides.data(),
x,
mean,
var);
} }
} // namespace } // namespace
...@@ -155,16 +161,16 @@ void Moments<float16, float, CPUContext>( ...@@ -155,16 +161,16 @@ void Moments<float16, float, CPUContext>(
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
#define DEFINE_KERNEL_LAUNCHER(Tx, Ty) \ #define DEFINE_KERNEL_LAUNCHER(T, AccT) \
template <> \ template <> \
void Moments<Tx, Ty, CPUContext>( \ void Moments<T, AccT, CPUContext>( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
const int* axes, \ const int* axes, \
const Tx* x, \ const T* x, \
Ty* mean, \ AccT* mean, \
Ty* var, \ AccT* var, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_Moments(num_dims, dims, num_axes, axes, x, mean, var, ctx); \ _Moments(num_dims, dims, num_axes, axes, x, mean, var, ctx); \
} }
......
...@@ -9,6 +9,8 @@ namespace dragon { ...@@ -9,6 +9,8 @@ namespace dragon {
namespace kernel { namespace kernel {
namespace {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
#define LDG(x, i) __ldg(x + i) #define LDG(x, i) __ldg(x + i)
#define LDG2(x, i) convert::To<AccT>(__ldg(x + i)) #define LDG2(x, i) convert::To<AccT>(__ldg(x + i))
...@@ -17,22 +19,18 @@ namespace kernel { ...@@ -17,22 +19,18 @@ namespace kernel {
#define LDG2(x, i) convert::To<AccT>(x[i]) #define LDG2(x, i) convert::To<AccT>(x[i])
#endif #endif
namespace {
template <typename T, typename AccT, StorageOrder kOrder> template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _BatchNormExpectation( __global__ void _BatchNormExpectation(
const int N, const int N,
const int C, const int C,
const int S, const int S,
const AccT denorm, const AccT normalizer,
const T* x, const T* x,
AccT* ex, AccT* ex,
AccT* ex2) { AccT* ex2) {
const int outer_dim = N * S; const int outer_dim = N * S;
__shared__ union { __shared__ typename BlockReduce<AccT>::TempStorage ex_storage;
typename BlockReduce<AccT>::TempStorage ex; __shared__ typename BlockReduce<AccT>::TempStorage ex2_storage;
typename BlockReduce<AccT>::TempStorage ex2;
} storage;
CUDA_2D_KERNEL_LOOP1(i, C) { CUDA_2D_KERNEL_LOOP1(i, C) {
AccT ex_val = AccT(0), ex2_val = AccT(0); AccT ex_val = AccT(0), ex2_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, outer_dim) { CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
...@@ -41,11 +39,11 @@ __global__ void _BatchNormExpectation( ...@@ -41,11 +39,11 @@ __global__ void _BatchNormExpectation(
ex_val += LDG2(x, xi); ex_val += LDG2(x, xi);
ex2_val += math::utils::Square(LDG2(x, xi)); ex2_val += math::utils::Square(LDG2(x, xi));
} }
ex_val = BlockReduce<AccT>(storage.ex).Reduce(ex_val, cub::Sum()); ex_val = BlockReduce<AccT>(ex_storage).Reduce(ex_val, cub::Sum());
ex2_val = BlockReduce<AccT>(storage.ex2).Reduce(ex2_val, cub::Sum()); ex2_val = BlockReduce<AccT>(ex2_storage).Reduce(ex2_val, cub::Sum());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
ex[i] = ex_val * denorm; ex[i] = ex_val / normalizer;
ex2[i] = ex2_val * denorm; ex2[i] = ex2_val / normalizer;
} }
} }
} }
...@@ -82,22 +80,19 @@ __global__ void _BatchNormAffine( ...@@ -82,22 +80,19 @@ __global__ void _BatchNormAffine(
} }
template <typename T, typename AccT, StorageOrder kOrder> template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _BatchNormInternalGrad( __global__ void _BatchNormWGrad(
const int N, const int N,
const int C, const int C,
const int S, const int S,
const T* x, const T* x,
const AccT* mu, const AccT* mu,
const AccT* rsig, const AccT* rsig,
const AccT* gamma,
const T* dy, const T* dy,
AccT* dgamma, AccT* dgamma,
AccT* dbeta) { AccT* dbeta) {
const int outer_dim = N * S; const int outer_dim = N * S;
__shared__ union { __shared__ typename BlockReduce<AccT>::TempStorage dg_storage;
typename BlockReduce<AccT>::TempStorage dg; __shared__ typename BlockReduce<AccT>::TempStorage db_storage;
typename BlockReduce<AccT>::TempStorage db;
} storage;
CUDA_2D_KERNEL_LOOP1(i, C) { CUDA_2D_KERNEL_LOOP1(i, C) {
AccT dg_val = AccT(0), db_val = AccT(0); AccT dg_val = AccT(0), db_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, outer_dim) { CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
...@@ -106,8 +101,8 @@ __global__ void _BatchNormInternalGrad( ...@@ -106,8 +101,8 @@ __global__ void _BatchNormInternalGrad(
dg_val += LDG2(dy, xi) * (LDG2(x, xi) - LDG(mu, i)) * LDG(rsig, i); dg_val += LDG2(dy, xi) * (LDG2(x, xi) - LDG(mu, i)) * LDG(rsig, i);
db_val += LDG2(dy, xi); db_val += LDG2(dy, xi);
} }
dg_val = BlockReduce<AccT>(storage.dg).Reduce(dg_val, cub::Sum()); dg_val = BlockReduce<AccT>(dg_storage).Sum(dg_val);
db_val = BlockReduce<AccT>(storage.db).Reduce(db_val, cub::Sum()); db_val = BlockReduce<AccT>(db_storage).Sum(db_val);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
dgamma[i] = dg_val; dgamma[i] = dg_val;
dbeta[i] = db_val; dbeta[i] = db_val;
...@@ -121,6 +116,7 @@ __global__ void _BatchNormTrainingGrad( ...@@ -121,6 +116,7 @@ __global__ void _BatchNormTrainingGrad(
const int N, const int N,
const int C, const int C,
const int S, const int S,
const AccT normalizer,
const T* x, const T* x,
const AccT* mu, const AccT* mu,
const AccT* rsig, const AccT* rsig,
...@@ -129,46 +125,15 @@ __global__ void _BatchNormTrainingGrad( ...@@ -129,46 +125,15 @@ __global__ void _BatchNormTrainingGrad(
const AccT* dbeta, const AccT* dbeta,
const T* dy, const T* dy,
T* dx) { T* dx) {
const AccT denom = AccT(1) / AccT(N * S);
CUDA_1D_KERNEL_LOOP(i, nthreads) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C; const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
const AccT xnorm = (LDG2(x, i) - LDG(mu, pi)) * LDG(rsig, pi);
dx[i] = convert::To<T>( dx[i] = convert::To<T>(
LDG(gamma, pi) * LDG(rsig, pi) * LDG(gamma, pi) * LDG(rsig, pi) *
(LDG2(dy, i) - fma(xnorm, LDG(dgamma, pi), LDG(dbeta, pi)) * denom)); (LDG2(dy, i) -
} fma((LDG2(x, i) - LDG(mu, pi)) * LDG(rsig, pi),
} LDG(dgamma, pi),
LDG(dbeta, pi)) /
template <typename T, typename AccT, StorageOrder kOrder> normalizer));
__global__ void _BatchNormWGrad(
const int N,
const int C,
const int S,
const T* x,
const AccT* mu,
const AccT* rsig,
const T* dy,
AccT* dgamma,
AccT* dbeta) {
const int outer_dim = N * S;
__shared__ union {
typename BlockReduce<AccT>::TempStorage dg;
typename BlockReduce<AccT>::TempStorage db;
} storage;
CUDA_2D_KERNEL_LOOP1(i, C) {
AccT dg_val = AccT(0), db_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S
: j * C + i;
dg_val += LDG2(dy, xi) * (LDG2(x, xi) - LDG(mu, i)) * LDG(rsig, i);
db_val += LDG2(dy, xi);
}
dg_val = BlockReduce<AccT>(storage.db).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<AccT>(storage.db).Reduce(db_val, cub::Sum());
if (threadIdx.x == 0) {
dgamma[i] = dg_val;
dbeta[i] = db_val;
}
} }
} }
...@@ -248,7 +213,7 @@ __global__ void _BatchNormInferenceGrad( ...@@ -248,7 +213,7 @@ __global__ void _BatchNormInferenceGrad(
const int N, \ const int N, \
const int C, \ const int C, \
const int S, \ const int S, \
const AccT denorm, \ const float normalizer, \
const string& data_format, \ const string& data_format, \
const T* x, \ const T* x, \
AccT* ex, \ AccT* ex, \
...@@ -263,13 +228,13 @@ __global__ void _BatchNormInferenceGrad( ...@@ -263,13 +228,13 @@ __global__ void _BatchNormInferenceGrad(
N, \ N, \
C, \ C, \
S, \ S, \
denorm, \ normalizer, \
reinterpret_cast<const ScalarT*>(x), \ reinterpret_cast<const ScalarT*>(x), \
ex, \ ex, \
ex2); \ ex2); \
} \ } \
template <> \ template <> \
void BatchNormInternalGrad<T, AccT, CUDAContext>( \ void BatchNormWGrad<T, AccT, CUDAContext>( \
const int N, \ const int N, \
const int C, \ const int C, \
const int S, \ const int S, \
...@@ -277,13 +242,12 @@ __global__ void _BatchNormInferenceGrad( ...@@ -277,13 +242,12 @@ __global__ void _BatchNormInferenceGrad(
const T* x, \ const T* x, \
const AccT* mu, \ const AccT* mu, \
const AccT* rsig, \ const AccT* rsig, \
const AccT* gamma, \
const T* dy, \ const T* dy, \
AccT* dgamma, \ AccT* dgamma, \
AccT* dbeta, \ AccT* dbeta, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
DISPATCH_BATCHNORM_KERNEL( \ DISPATCH_BATCHNORM_KERNEL( \
_BatchNormInternalGrad, \ _BatchNormWGrad, \
ScalarT, \ ScalarT, \
AccT, \ AccT, \
CUDA_2D_BLOCKS(C), \ CUDA_2D_BLOCKS(C), \
...@@ -294,7 +258,6 @@ __global__ void _BatchNormInferenceGrad( ...@@ -294,7 +258,6 @@ __global__ void _BatchNormInferenceGrad(
reinterpret_cast<const ScalarT*>(x), \ reinterpret_cast<const ScalarT*>(x), \
mu, \ mu, \
rsig, \ rsig, \
gamma, \
reinterpret_cast<const ScalarT*>(dy), \ reinterpret_cast<const ScalarT*>(dy), \
dgamma, \ dgamma, \
dbeta); \ dbeta); \
...@@ -304,6 +267,7 @@ __global__ void _BatchNormInferenceGrad( ...@@ -304,6 +267,7 @@ __global__ void _BatchNormInferenceGrad(
const int N, \ const int N, \
const int C, \ const int C, \
const int S, \ const int S, \
const float normalizer, \
const string& data_format, \ const string& data_format, \
const T* x, \ const T* x, \
const AccT* mu, \ const AccT* mu, \
...@@ -325,6 +289,7 @@ __global__ void _BatchNormInferenceGrad( ...@@ -325,6 +289,7 @@ __global__ void _BatchNormInferenceGrad(
N, \ N, \
C, \ C, \
S, \ S, \
normalizer, \
reinterpret_cast<const ScalarT*>(x), \ reinterpret_cast<const ScalarT*>(x), \
mu, \ mu, \
rsig, \ rsig, \
...@@ -384,8 +349,10 @@ __global__ void _BatchNormInferenceGrad( ...@@ -384,8 +349,10 @@ __global__ void _BatchNormInferenceGrad(
DEFINE_KERNEL_LAUNCHER(float16, half, float); DEFINE_KERNEL_LAUNCHER(float16, half, float);
DEFINE_KERNEL_LAUNCHER(float, float, float); DEFINE_KERNEL_LAUNCHER(float, float, float);
DEFINE_KERNEL_LAUNCHER(double, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16, half, float); DEFINE_GRAD_KERNEL_LAUNCHER(float16, half, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float, float); DEFINE_GRAD_KERNEL_LAUNCHER(float, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(double, double, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
#undef DISPATCH_BATCHNORM_KERNEL #undef DISPATCH_BATCHNORM_KERNEL
......
...@@ -228,7 +228,9 @@ void GroupNormGrad<float16, float, CPUContext>( ...@@ -228,7 +228,9 @@ void GroupNormGrad<float16, float, CPUContext>(
} }
DEFINE_KERNEL_LAUNCHER(float, float); DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float); DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(double, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
......
...@@ -9,6 +9,8 @@ namespace dragon { ...@@ -9,6 +9,8 @@ namespace dragon {
namespace kernel { namespace kernel {
namespace {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
#define LDG(x, i) __ldg(x + i) #define LDG(x, i) __ldg(x + i)
#define LDG2(x, i) convert::To<AccT>(__ldg(x + i)) #define LDG2(x, i) convert::To<AccT>(__ldg(x + i))
...@@ -17,8 +19,6 @@ namespace kernel { ...@@ -17,8 +19,6 @@ namespace kernel {
#define LDG2(x, i) convert::To<AccT>(x[i]) #define LDG2(x, i) convert::To<AccT>(x[i])
#endif #endif
namespace {
template <typename T> template <typename T>
__global__ void _GroupNormFusedParams( __global__ void _GroupNormFusedParams(
const int N, const int N,
...@@ -99,10 +99,8 @@ __global__ void _GroupNormWGrad( ...@@ -99,10 +99,8 @@ __global__ void _GroupNormWGrad(
AccT* dbeta) { AccT* dbeta) {
const int outer_dim = G * D; const int outer_dim = G * D;
const int inner_dim = N * S; const int inner_dim = N * S;
__shared__ union { __shared__ typename BlockReduce<AccT>::TempStorage dg_storage;
typename BlockReduce<AccT>::TempStorage dg; __shared__ typename BlockReduce<AccT>::TempStorage db_storage;
typename BlockReduce<AccT>::TempStorage db;
} storage;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) { CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
AccT dg_val = AccT(0), db_val = AccT(0); AccT dg_val = AccT(0), db_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, inner_dim) { CUDA_2D_KERNEL_LOOP2(j, inner_dim) {
...@@ -114,8 +112,8 @@ __global__ void _GroupNormWGrad( ...@@ -114,8 +112,8 @@ __global__ void _GroupNormWGrad(
dg_val += LDG2(dy, xi) * (LDG2(x, xi) - LDG(mu, mi)) * LDG(rsig, mi); dg_val += LDG2(dy, xi) * (LDG2(x, xi) - LDG(mu, mi)) * LDG(rsig, mi);
db_val += LDG2(dy, xi); db_val += LDG2(dy, xi);
} }
dg_val = BlockReduce<AccT>(storage.dg).Reduce(dg_val, cub::Sum()); dg_val = BlockReduce<AccT>(dg_storage).Sum(dg_val);
db_val = BlockReduce<AccT>(storage.db).Reduce(db_val, cub::Sum()); db_val = BlockReduce<AccT>(db_storage).Sum(db_val);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
dgamma[i] = dg_val; dgamma[i] = dg_val;
dbeta[i] = db_val; dbeta[i] = db_val;
...@@ -136,10 +134,8 @@ __global__ void _GroupNormInternalGrad( ...@@ -136,10 +134,8 @@ __global__ void _GroupNormInternalGrad(
AccT* db) { AccT* db) {
const int outer_dim = N * G; const int outer_dim = N * G;
const int inner_dim = D * S; const int inner_dim = D * S;
__shared__ union { __shared__ typename BlockReduce<AccT>::TempStorage ds_storage;
typename BlockReduce<AccT>::TempStorage ds; __shared__ typename BlockReduce<AccT>::TempStorage db_storage;
typename BlockReduce<AccT>::TempStorage db;
} storage;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) { CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
AccT ds_val = AccT(0), db_val = AccT(0); AccT ds_val = AccT(0), db_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, inner_dim) { CUDA_2D_KERNEL_LOOP2(j, inner_dim) {
...@@ -150,8 +146,8 @@ __global__ void _GroupNormInternalGrad( ...@@ -150,8 +146,8 @@ __global__ void _GroupNormInternalGrad(
ds_val += LDG(gamma, gi) * LDG2(dy, xi) * LDG2(x, xi); ds_val += LDG(gamma, gi) * LDG2(dy, xi) * LDG2(x, xi);
db_val += LDG(gamma, gi) * LDG2(dy, xi); db_val += LDG(gamma, gi) * LDG2(dy, xi);
} }
ds_val = BlockReduce<AccT>(storage.ds).Reduce(ds_val, cub::Sum()); ds_val = BlockReduce<AccT>(ds_storage).Sum(ds_val);
db_val = BlockReduce<AccT>(storage.db).Reduce(db_val, cub::Sum()); db_val = BlockReduce<AccT>(db_storage).Sum(db_val);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
ds[i] = ds_val; ds[i] = ds_val;
db[i] = db_val; db[i] = db_val;
...@@ -330,8 +326,10 @@ __global__ void _GroupNormGrad( ...@@ -330,8 +326,10 @@ __global__ void _GroupNormGrad(
DEFINE_KERNEL_LAUNCHER(float16, half, float); DEFINE_KERNEL_LAUNCHER(float16, half, float);
DEFINE_KERNEL_LAUNCHER(float, float, float); DEFINE_KERNEL_LAUNCHER(float, float, float);
DEFINE_KERNEL_LAUNCHER(double, double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16, half, float); DEFINE_GRAD_KERNEL_LAUNCHER(float16, half, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float, float); DEFINE_GRAD_KERNEL_LAUNCHER(float, float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(double, double, double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER #undef DEFINE_GRAD_KERNEL_LAUNCHER
#undef DISPATCH_GROUPNORM_KERNEL #undef DISPATCH_GROUPNORM_KERNEL
......
...@@ -12,8 +12,8 @@ void _L1Normalize( ...@@ -12,8 +12,8 @@ void _L1Normalize(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const T scale, const T normalizer,
const T eps, const T epsilon,
const T* x, const T* x,
T* y) { T* y) {
const auto dim = reduce_dim * inner_dim; const auto dim = reduce_dim * inner_dim;
...@@ -24,7 +24,7 @@ void _L1Normalize( ...@@ -24,7 +24,7 @@ void _L1Normalize(
x + offset, 1, reduce_dim, EigenInnerStride(inner_dim)); x + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
EigenStridedVectorMap<T>( EigenStridedVectorMap<T>(
y + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) = y + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) =
X / std::max(X.template lpNorm<1>() * scale, eps); X / std::max(X.template lpNorm<1>() / normalizer, epsilon);
} }
} }
} }
...@@ -34,8 +34,8 @@ void _L2Normalize( ...@@ -34,8 +34,8 @@ void _L2Normalize(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const T scale, const T normalizer,
const T eps, const T epsilon,
const T* x, const T* x,
T* y) { T* y) {
const auto dim = reduce_dim * inner_dim; const auto dim = reduce_dim * inner_dim;
...@@ -46,7 +46,7 @@ void _L2Normalize( ...@@ -46,7 +46,7 @@ void _L2Normalize(
x + offset, 1, reduce_dim, EigenInnerStride(inner_dim)); x + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
EigenStridedVectorMap<T>( EigenStridedVectorMap<T>(
y + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) = y + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) =
X / std::max(std::sqrt(X.squaredNorm() * scale), eps); X / std::max(std::sqrt(X.squaredNorm() / normalizer), epsilon);
} }
} }
} }
...@@ -56,8 +56,8 @@ void _L1NormalizeGrad( ...@@ -56,8 +56,8 @@ void _L1NormalizeGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const T scale, const T normalizer,
const T eps, const T epsilon,
const T* dy, const T* dy,
const T* x, const T* x,
T* dx) { T* dx) {
...@@ -69,11 +69,12 @@ void _L1NormalizeGrad( ...@@ -69,11 +69,12 @@ void _L1NormalizeGrad(
dy + offset, 1, reduce_dim, EigenInnerStride(inner_dim)); dy + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
auto X = ConstEigenStridedVectorMap<T>( auto X = ConstEigenStridedVectorMap<T>(
x + offset, 1, reduce_dim, EigenInnerStride(inner_dim)); x + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
auto norm = std::max(X.template lpNorm<1>() * scale, eps); auto norm = std::max(X.template lpNorm<1>() / normalizer, epsilon);
auto norm2 = std::pow(norm, 2); auto norm2 = std::pow(norm, 2);
EigenStridedVectorMap<T>( EigenStridedVectorMap<T>(
dx + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) = dx + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) =
(dY / norm) - (X.array().sign().matrix() / norm2) * dY.dot(X) * scale; (dY / norm) -
(X.array().sign().matrix() / norm2) * dY.dot(X) / normalizer;
} }
} }
} }
...@@ -83,8 +84,8 @@ void _L2NormalizeGrad( ...@@ -83,8 +84,8 @@ void _L2NormalizeGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const T scale, const T normalizer,
const T eps, const T epsilon,
const T* dy, const T* dy,
const T* x, const T* x,
T* dx) { T* dx) {
...@@ -96,11 +97,11 @@ void _L2NormalizeGrad( ...@@ -96,11 +97,11 @@ void _L2NormalizeGrad(
dy + offset, 1, reduce_dim, EigenInnerStride(inner_dim)); dy + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
auto X = ConstEigenStridedVectorMap<T>( auto X = ConstEigenStridedVectorMap<T>(
x + offset, 1, reduce_dim, EigenInnerStride(inner_dim)); x + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
auto norm = std::max(std::sqrt(X.squaredNorm() * scale), eps); auto norm = std::max(std::sqrt(X.squaredNorm() / normalizer), epsilon);
auto norm3 = std::pow(norm, 3); auto norm3 = std::pow(norm, 3);
EigenStridedVectorMap<T>( EigenStridedVectorMap<T>(
dx + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) = dx + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) =
(dY / norm) - ((X / norm3) * dY.dot(X) * scale); (dY / norm) - ((X / norm3) * dY.dot(X) / normalizer);
} }
} }
} }
...@@ -114,8 +115,8 @@ void L1Normalize<float16, CPUContext>( ...@@ -114,8 +115,8 @@ void L1Normalize<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const float scale, const float normalizer,
const float eps, const float epsilon,
const float16* x, const float16* x,
float16* y, float16* y,
CPUContext* ctx) { CPUContext* ctx) {
...@@ -127,8 +128,8 @@ void L2Normalize<float16, CPUContext>( ...@@ -127,8 +128,8 @@ void L2Normalize<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const float scale, const float normalizer,
const float eps, const float epsilon,
const float16* x, const float16* x,
float16* y, float16* y,
CPUContext* ctx) { CPUContext* ctx) {
...@@ -140,8 +141,8 @@ void L1NormalizeGrad<float16, CPUContext>( ...@@ -140,8 +141,8 @@ void L1NormalizeGrad<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const float scale, const float normalizer,
const float eps, const float epsilon,
const float16* dy, const float16* dy,
const float16* x, const float16* x,
float16* dx, float16* dx,
...@@ -154,8 +155,8 @@ void L2NormalizeGrad<float16, CPUContext>( ...@@ -154,8 +155,8 @@ void L2NormalizeGrad<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const float scale, const float normalizer,
const float eps, const float epsilon,
const float16* dy, const float16* dy,
const float16* x, const float16* x,
float16* dx, float16* dx,
...@@ -163,33 +164,33 @@ void L2NormalizeGrad<float16, CPUContext>( ...@@ -163,33 +164,33 @@ void L2NormalizeGrad<float16, CPUContext>(
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} // L2NormalizeGrad } // L2NormalizeGrad
#define DEFINE_KERNEL_LAUNCHER(name, T) \ #define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \ template <> \
void name<T, CPUContext>( \ void name<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int reduce_dim, \ const int reduce_dim, \
const float scale, \ const float normalizer, \
const float eps, \ const float eps, \
const T* x, \ const T* x, \
T* y, \ T* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name(outer_dim, inner_dim, reduce_dim, (T)scale, (T)eps, x, y); \ _##name<T>(outer_dim, inner_dim, reduce_dim, normalizer, eps, x, y); \
} }
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \ #define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \ template <> \
void name<T, CPUContext>( \ void name<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int reduce_dim, \ const int reduce_dim, \
const float scale, \ const float normalizer, \
const float eps, \ const float eps, \
const T* dy, \ const T* dy, \
const T* x, \ const T* x, \
T* dx, \ T* dx, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name(outer_dim, inner_dim, reduce_dim, (T)scale, (T)eps, dy, x, dx); \ _##name<T>(outer_dim, inner_dim, reduce_dim, normalizer, eps, dy, x, dx); \
} }
DEFINE_KERNEL_LAUNCHER(L1Normalize, float); DEFINE_KERNEL_LAUNCHER(L1Normalize, float);
......
...@@ -16,8 +16,8 @@ __global__ void _L1Normalize( ...@@ -16,8 +16,8 @@ __global__ void _L1Normalize(
const int nblocks, const int nblocks,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const AccT scale, const AccT normalizer,
const AccT eps, const AccT epsilon,
const T* x, const T* x,
T* y) { T* y) {
__shared__ AccT norm; __shared__ AccT norm;
...@@ -30,7 +30,7 @@ __global__ void _L1Normalize( ...@@ -30,7 +30,7 @@ __global__ void _L1Normalize(
} }
sum = BlockReduce<AccT>(storage).Sum(sum); sum = BlockReduce<AccT>(storage).Sum(sum);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
norm = max(sum * scale, eps); norm = max(sum / normalizer, epsilon);
} }
__syncthreads(); __syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) { CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
...@@ -45,8 +45,8 @@ __global__ void _L2Normalize( ...@@ -45,8 +45,8 @@ __global__ void _L2Normalize(
const int nblocks, const int nblocks,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const AccT scale, const AccT normalizer,
const AccT eps, const AccT epsilon,
const T* x, const T* x,
T* y) { T* y) {
__shared__ AccT norm; __shared__ AccT norm;
...@@ -59,7 +59,7 @@ __global__ void _L2Normalize( ...@@ -59,7 +59,7 @@ __global__ void _L2Normalize(
} }
sum = BlockReduce<AccT>(storage).Sum(sum); sum = BlockReduce<AccT>(storage).Sum(sum);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
norm = max(sqrt(sum * scale), eps); norm = max(sqrt(sum / normalizer), epsilon);
} }
__syncthreads(); __syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) { CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
...@@ -74,8 +74,8 @@ __global__ void _L1NormalizeGrad( ...@@ -74,8 +74,8 @@ __global__ void _L1NormalizeGrad(
const int nblocks, const int nblocks,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const AccT scale, const AccT normalizer,
const AccT eps, const AccT epsilon,
const T* dy, const T* dy,
const T* x, const T* x,
T* dx) { T* dx) {
...@@ -92,9 +92,9 @@ __global__ void _L1NormalizeGrad( ...@@ -92,9 +92,9 @@ __global__ void _L1NormalizeGrad(
val1 = BlockReduce<AccT>(storage).Sum(val1); val1 = BlockReduce<AccT>(storage).Sum(val1);
val2 = BlockReduce<AccT>(storage).Sum(val2); val2 = BlockReduce<AccT>(storage).Sum(val2);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
norm = max(val1 * scale, eps); norm = max(val1 / normalizer, epsilon);
norm2 = pow(norm, 2); norm2 = pow(norm, 2);
sum = val2 * scale; sum = val2 / normalizer;
} }
__syncthreads(); __syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) { CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
...@@ -111,8 +111,8 @@ __global__ void _L2NormalizeGrad( ...@@ -111,8 +111,8 @@ __global__ void _L2NormalizeGrad(
const int nblocks, const int nblocks,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const AccT scale, const AccT normalizer,
const AccT eps, const AccT epsilon,
const T* dy, const T* dy,
const T* x, const T* x,
T* dx) { T* dx) {
...@@ -129,9 +129,9 @@ __global__ void _L2NormalizeGrad( ...@@ -129,9 +129,9 @@ __global__ void _L2NormalizeGrad(
val1 = BlockReduce<AccT>(storage).Sum(val1); val1 = BlockReduce<AccT>(storage).Sum(val1);
val2 = BlockReduce<AccT>(storage).Sum(val2); val2 = BlockReduce<AccT>(storage).Sum(val2);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
norm = max(sqrt(val1 * scale), eps); norm = max(sqrt(val1 / normalizer), epsilon);
norm3 = pow(norm, 3); norm3 = pow(norm, 3);
sum = val2 * scale; sum = val2 / normalizer;
} }
__syncthreads(); __syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) { CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
...@@ -153,8 +153,8 @@ __global__ void _L2NormalizeGrad( ...@@ -153,8 +153,8 @@ __global__ void _L2NormalizeGrad(
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int reduce_dim, \ const int reduce_dim, \
const float scale, \ const float normalizer, \
const float eps, \ const float epsilon, \
const T* x, \ const T* x, \
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
...@@ -164,8 +164,8 @@ __global__ void _L2NormalizeGrad( ...@@ -164,8 +164,8 @@ __global__ void _L2NormalizeGrad(
nblocks, \ nblocks, \
inner_dim, \ inner_dim, \
reduce_dim, \ reduce_dim, \
scale, \ AccT(normalizer), \
eps, \ AccT(epsilon), \
reinterpret_cast<const ScalarT*>(x), \ reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<ScalarT*>(y)); \ reinterpret_cast<ScalarT*>(y)); \
} }
...@@ -176,8 +176,8 @@ __global__ void _L2NormalizeGrad( ...@@ -176,8 +176,8 @@ __global__ void _L2NormalizeGrad(
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int reduce_dim, \ const int reduce_dim, \
const float scale, \ const float normalizer, \
const float eps, \ const float epsilon, \
const T* dy, \ const T* dy, \
const T* x, \ const T* x, \
T* dx, \ T* dx, \
...@@ -188,8 +188,8 @@ __global__ void _L2NormalizeGrad( ...@@ -188,8 +188,8 @@ __global__ void _L2NormalizeGrad(
nblocks, \ nblocks, \
inner_dim, \ inner_dim, \
reduce_dim, \ reduce_dim, \
scale, \ AccT(normalizer), \
eps, \ AccT(epsilon), \
reinterpret_cast<const ScalarT*>(dy), \ reinterpret_cast<const ScalarT*>(dy), \
reinterpret_cast<const ScalarT*>(x), \ reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<ScalarT*>(dx)); \ reinterpret_cast<ScalarT*>(dx)); \
......
...@@ -26,6 +26,10 @@ class CollectiveOpBase : public Operator<Context> { ...@@ -26,6 +26,10 @@ class CollectiveOpBase : public Operator<Context> {
public: public:
CollectiveOpBase(const OperatorDef& def, Workspace* ws) CollectiveOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
comm_rank_(0),
comm_size_(1),
comm_root_(0),
enable_nccl_(false),
comm_((MPI_Comm)OP_SINGLE_ARG(int64_t, "comm", 0)), comm_((MPI_Comm)OP_SINGLE_ARG(int64_t, "comm", 0)),
group_((MPI_Group)OP_SINGLE_ARG(int64_t, "group", 0)) { group_((MPI_Group)OP_SINGLE_ARG(int64_t, "group", 0)) {
if ((int64_t)comm_ == 0) return; if ((int64_t)comm_ == 0) return;
......
...@@ -6,8 +6,9 @@ ...@@ -6,8 +6,9 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename InputType, typename ParamType> template <typename T>
void BatchNormOp<Context>::TrainingImpl() { void BatchNormOp<Context>::TrainingImpl() {
using ParamType = typename math::utils::AccmulatorType<T>::type;
TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType);
...@@ -18,7 +19,7 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -18,7 +19,7 @@ void BatchNormOp<Context>::TrainingImpl() {
auto* X_scale = Buffer("X_scale")->Reshape({C_}); auto* X_scale = Buffer("X_scale")->Reshape({C_});
auto* X_bias = Buffer("X_bias")->Reshape({C_}); auto* X_bias = Buffer("X_bias")->Reshape({C_});
auto* x = Input(0).template data<InputType, Context>(); auto* x = Input(0).template data<T, Context>();
auto* rm = Input(3).template mutable_data<ParamType, Context>(); auto* rm = Input(3).template mutable_data<ParamType, Context>();
auto* rv = Input(4).template mutable_data<ParamType, Context>(); auto* rv = Input(4).template mutable_data<ParamType, Context>();
auto* mu = X_mu->template mutable_data<ParamType, Context>(); auto* mu = X_mu->template mutable_data<ParamType, Context>();
...@@ -33,7 +34,7 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -33,7 +34,7 @@ void BatchNormOp<Context>::TrainingImpl() {
N_, N_,
C_, C_,
S_, S_,
ParamType(1) / (N_ * comm_size_ * S_), float(N_ * S_ * comm_size_),
data_format(), data_format(),
x, x,
mu, mu,
...@@ -43,23 +44,23 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -43,23 +44,23 @@ void BatchNormOp<Context>::TrainingImpl() {
ctx()->FinishDeviceComputation(); ctx()->FinishDeviceComputation();
if (enable_nccl_) { if (enable_nccl_) {
#ifdef USE_NCCL #ifdef USE_NCCL
auto nccl_comm_ = this->nccl_comm(); auto coll_comm = this->nccl_comm();
auto nccl_dtype_ = this->template nccl_dtype<ParamType>(); auto coll_dtype = this->template nccl_dtype<ParamType>();
NCCL_CHECK(ncclAllReduce( NCCL_CHECK(ncclAllReduce(
(void*)mu, (void*)mu,
(void*)mu, (void*)mu,
C_, C_,
nccl_dtype_, coll_dtype,
ncclSum, ncclSum,
nccl_comm_, coll_comm,
((CUDAContext*)ctx())->cuda_stream())); ((CUDAContext*)ctx())->cuda_stream()));
NCCL_CHECK(ncclAllReduce( NCCL_CHECK(ncclAllReduce(
(void*)rsig, (void*)rsig,
(void*)rsig, (void*)rsig,
C_, C_,
nccl_dtype_, coll_dtype,
ncclSum, ncclSum,
nccl_comm_, coll_comm,
((CUDAContext*)ctx())->cuda_stream())); ((CUDAContext*)ctx())->cuda_stream()));
#endif // USE_NCCL #endif // USE_NCCL
} else { } else {
...@@ -103,13 +104,14 @@ void BatchNormOp<Context>::TrainingImpl() { ...@@ -103,13 +104,14 @@ void BatchNormOp<Context>::TrainingImpl() {
Input(2).template data<ParamType, Context>(), // beta Input(2).template data<ParamType, Context>(), // beta
scale, scale,
X_bias->template mutable_data<ParamType, Context>(), X_bias->template mutable_data<ParamType, Context>(),
Output(0)->template mutable_data<InputType, Context>(), Output(0)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
template <class Context> template <class Context>
template <typename InputType, typename ParamType> template <typename T>
void BatchNormOp<Context>::InferenceImpl() { void BatchNormOp<Context>::InferenceImpl() {
using ParamType = typename math::utils::AccmulatorType<T>::type;
TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType);
...@@ -130,14 +132,14 @@ void BatchNormOp<Context>::InferenceImpl() { ...@@ -130,14 +132,14 @@ void BatchNormOp<Context>::InferenceImpl() {
C_, C_,
S_, S_,
data_format(), data_format(),
Input(0).template data<InputType, Context>(), Input(0).template data<T, Context>(),
Input(3).template data<ParamType, Context>(), Input(3).template data<ParamType, Context>(),
rsig, rsig,
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamType, Context>(), // gamma
Input(2).template data<ParamType, Context>(), // beta Input(2).template data<ParamType, Context>(), // beta
X_scale->template mutable_data<ParamType, Context>(), X_scale->template mutable_data<ParamType, Context>(),
X_bias->template mutable_data<ParamType, Context>(), X_bias->template mutable_data<ParamType, Context>(),
Output(0)->template mutable_data<InputType, Context>(), Output(0)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -149,80 +151,59 @@ void BatchNormOp<Context>::RunOnDevice() { ...@@ -149,80 +151,59 @@ void BatchNormOp<Context>::RunOnDevice() {
auto* flag = workspace()->GetTensor("/share/flag/recomputing"); auto* flag = workspace()->GetTensor("/share/flag/recomputing");
is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0; is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0;
// Dispatch the training or inference impl // Dispatch the training or inference implementation
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (Input(0).template IsType<float>()) { DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
if (is_training_) {
TrainingImpl<float, float>();
} else {
InferenceImpl<float, float>();
}
} else if (Input(0).template IsType<float16>()) {
if (is_training_) {
TrainingImpl<float16, float>();
} else {
InferenceImpl<float16, float>();
}
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32"});
}
} }
template <class Context> template <class Context>
template <typename InputType, typename ParamType> template <typename T>
void BatchNormGradientOp<Context>::TrainingImpl() { void BatchNormGradientOp<Context>::TrainingImpl() {
using ParamType = typename math::utils::AccmulatorType<T>::type;
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig"); auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig");
auto *X_scale = Buffer("X_scale"), *X_bias = Buffer("X_bias");
auto* x = Input(0).template data<InputType, Context>(); auto* x = Input(0).template data<T, Context>();
auto* gamma = Input(1).template data<ParamType, Context>(); auto* gamma = Input(1).template data<ParamType, Context>();
auto* dy = Input(4).template data<InputType, Context>(); auto* dy = Input(4).template data<T, Context>();
auto* mu = X_mu->template data<ParamType, Context>(); auto* mu = X_mu->template data<ParamType, Context>();
auto* rsig = X_rsig->template data<ParamType, Context>(); auto* rsig = X_rsig->template data<ParamType, Context>();
auto* scale = X_scale->template mutable_data<ParamType, Context>();
auto* bias = X_bias->template mutable_data<ParamType, Context>();
auto* dgamma = dW->Reshape({C_})->template mutable_data<ParamType, Context>(); auto* dgamma = dW->Reshape({C_})->template mutable_data<ParamType, Context>();
auto* dbeta = dB->Reshape({C_})->template mutable_data<ParamType, Context>(); auto* dbeta = dB->Reshape({C_})->template mutable_data<ParamType, Context>();
// Gradient w.r.t. gamma and beta // Gradient w.r.t. gamma and beta
kernel::BatchNormInternalGrad( kernel::BatchNormWGrad(
N_, C_, S_, data_format(), x, mu, rsig, gamma, dy, dgamma, dbeta, ctx()); N_, C_, S_, data_format(), x, mu, rsig, dy, dgamma, dbeta, ctx());
if (sync_stats_ > 0) { if (sync_stats_ > 0) {
#ifdef USE_MPI #ifdef USE_MPI
ctx()->FinishDeviceComputation(); ctx()->FinishDeviceComputation();
if (enable_nccl_) { if (enable_nccl_) {
#ifdef USE_NCCL #ifdef USE_NCCL
auto nccl_comm_ = this->nccl_comm(); auto coll_comm = this->nccl_comm();
auto nccl_dtype_ = this->template nccl_dtype<ParamType>(); auto coll_dtype = this->template nccl_dtype<ParamType>();
NCCL_CHECK(ncclAllReduce( NCCL_CHECK(ncclAllReduce(
(void*)dgamma, (void*)dgamma,
(void*)scale, (void*)dgamma,
C_, C_,
nccl_dtype_, coll_dtype,
ncclSum, ncclSum,
nccl_comm_, coll_comm,
((CUDAContext*)ctx())->cuda_stream())); ((CUDAContext*)ctx())->cuda_stream()));
NCCL_CHECK(ncclAllReduce( NCCL_CHECK(ncclAllReduce(
(void*)dbeta, (void*)dbeta,
(void*)bias, (void*)dbeta,
C_, C_,
nccl_dtype_, coll_dtype,
ncclSum, ncclSum,
nccl_comm_, coll_comm,
((CUDAContext*)ctx())->cuda_stream())); ((CUDAContext*)ctx())->cuda_stream()));
#endif // USE_NCCL #endif // USE_NCCL
} else { } else {
AllReduce(dgamma, scale, C_); AllReduce(dgamma, dgamma, C_);
AllReduce(dbeta, bias, C_); AllReduce(dbeta, dbeta, C_);
} }
math::Scale(C_, ParamType(1) / comm_size_, scale, scale, ctx());
math::Scale(C_, ParamType(1) / comm_size_, bias, bias, ctx());
#endif // USE_MPI #endif // USE_MPI
} else {
scale = dgamma, bias = dbeta;
} }
// Gradient w.r.t. input // Gradient w.r.t. input
...@@ -230,21 +211,27 @@ void BatchNormGradientOp<Context>::TrainingImpl() { ...@@ -230,21 +211,27 @@ void BatchNormGradientOp<Context>::TrainingImpl() {
N_, N_,
C_, C_,
S_, S_,
#ifdef USE_MPI
float(N_ * S_ * comm_size_),
#else
float(N_ * S_),
#endif
data_format(), data_format(),
x, x,
mu, mu,
rsig, rsig,
gamma, gamma,
scale, dgamma,
bias, dbeta,
dy, dy,
Output(0)->template mutable_data<InputType, Context>(), Output(0)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
template <class Context> template <class Context>
template <typename InputType, typename ParamType> template <typename T>
void BatchNormGradientOp<Context>::InferenceImpl() { void BatchNormGradientOp<Context>::InferenceImpl() {
using ParamType = typename math::utils::AccmulatorType<T>::type;
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto* X_scale = Buffer("X_scale")->Reshape({C_}); auto* X_scale = Buffer("X_scale")->Reshape({C_});
...@@ -267,14 +254,14 @@ void BatchNormGradientOp<Context>::InferenceImpl() { ...@@ -267,14 +254,14 @@ void BatchNormGradientOp<Context>::InferenceImpl() {
C_, C_,
S_, S_,
data_format(), data_format(),
Input(0).template data<InputType, Context>(), // x Input(0).template data<T, Context>(), // x
Input(2).template data<ParamType, Context>(), // rm Input(2).template data<ParamType, Context>(), // rm
rsig, rsig,
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamType, Context>(), // gamma
Input(4).template data<InputType, Context>(), // dy Input(4).template data<T, Context>(), // dy
dgamma, dgamma,
dbeta, dbeta,
dX->template mutable_data<InputType, Context>(), dX->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -282,24 +269,9 @@ template <class Context> ...@@ -282,24 +269,9 @@ template <class Context>
void BatchNormGradientOp<Context>::RunOnDevice() { void BatchNormGradientOp<Context>::RunOnDevice() {
DetermineBaseArguments(); DetermineBaseArguments();
// Dispatch the training or inference impl // Dispatch the training or inference implementation
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (Input(0).template IsType<float>()) { DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
if (is_training_ > 0) {
TrainingImpl<float, float>();
} else {
InferenceImpl<float, float>();
}
} else if (Input(0).template IsType<float16>()) {
if (is_training_ > 0) {
TrainingImpl<float16, float>();
} else {
InferenceImpl<float16, float>();
}
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32"});
}
} }
DEPLOY_CPU_OPERATOR(BatchNorm); DEPLOY_CPU_OPERATOR(BatchNorm);
......
...@@ -91,11 +91,20 @@ class BatchNormOp : public BatchNormOpBase<Context> { ...@@ -91,11 +91,20 @@ class BatchNormOp : public BatchNormOpBase<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename InputType, typename ParamType> template <typename T>
void TrainingImpl(); void TrainingImpl();
template <typename InputType, typename ParamType> template <typename T>
void InferenceImpl(); void InferenceImpl();
template <typename T>
void DoRunWithType() {
if (is_training_) {
TrainingImpl<T>();
} else {
InferenceImpl<T>();
}
};
}; };
template <class Context> template <class Context>
...@@ -111,11 +120,20 @@ class BatchNormGradientOp : public BatchNormOpBase<Context> { ...@@ -111,11 +120,20 @@ class BatchNormGradientOp : public BatchNormOpBase<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename InputType, typename ParamType> template <typename T>
void TrainingImpl(); void TrainingImpl();
template <typename InputType, typename ParamType> template <typename T>
void InferenceImpl(); void InferenceImpl();
template <typename T>
void DoRunWithType() {
if (is_training_) {
TrainingImpl<T>();
} else {
InferenceImpl<T>();
}
};
}; };
#ifdef USE_CUDNN #ifdef USE_CUDNN
...@@ -179,6 +197,15 @@ class CuDNNBatchNormGradientOp final : public BatchNormGradientOp<Context> { ...@@ -179,6 +197,15 @@ class CuDNNBatchNormGradientOp final : public BatchNormGradientOp<Context> {
template <typename T> template <typename T>
void TrainingImpl(); void TrainingImpl();
template <typename T>
void DoRunWithType() {
if (is_training_) {
TrainingImpl<T>();
} else {
this->template InferenceImpl<T>();
}
};
protected: protected:
cudnnTensorDescriptor_t input_desc_, bn_desc_; cudnnTensorDescriptor_t input_desc_, bn_desc_;
cudnnBatchNormMode_t bn_mode_; cudnnBatchNormMode_t bn_mode_;
......
...@@ -9,9 +9,13 @@ namespace dragon { ...@@ -9,9 +9,13 @@ namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void CuDNNBatchNormOp<Context>::DoRunWithType() { void CuDNNBatchNormOp<Context>::DoRunWithType() {
typedef typename CuDNNType<T>::BNParamType ParamType; using ParamType = typename CuDNNType<T>::BNParamType;
TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(4), vec64_t({C_}), ParamType);
// Determine the bn desc // Determine the descriptors
if (Input(0).ndim() == 2) { if (Input(0).ndim() == 2) {
bn_mode_ = CUDNN_BATCHNORM_PER_ACTIVATION; bn_mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
CuDNNSetTensorDesc<T>(&input_desc_, vec64_t({N_, C_, 1, 1})); CuDNNSetTensorDesc<T>(&input_desc_, vec64_t({N_, C_, 1, 1}));
...@@ -19,15 +23,9 @@ void CuDNNBatchNormOp<Context>::DoRunWithType() { ...@@ -19,15 +23,9 @@ void CuDNNBatchNormOp<Context>::DoRunWithType() {
bn_mode_ = CUDNN_BATCHNORM_SPATIAL; bn_mode_ = CUDNN_BATCHNORM_SPATIAL;
CuDNNSetTensorDesc<T>(&input_desc_, Input(0).dims(), data_format()); CuDNNSetTensorDesc<T>(&input_desc_, Input(0).dims(), data_format());
} }
// Derive the bn desc
CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(bn_desc_, input_desc_, bn_mode_)); CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(bn_desc_, input_desc_, bn_mode_));
TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType); // Dispatch the training or inference implementation
TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(3), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(4), vec64_t({C_}), ParamType);
if (is_training_ > 0) { if (is_training_ > 0) {
auto* X_mu = Buffer("X_mu")->Reshape({C_}); auto* X_mu = Buffer("X_mu")->Reshape({C_});
auto* X_rsig = Buffer("X_rsig")->Reshape({C_}); auto* X_rsig = Buffer("X_rsig")->Reshape({C_});
...@@ -78,24 +76,17 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() { ...@@ -78,24 +76,17 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() {
// Dispatch the training or inference impl // Dispatch the training or inference impl
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (Input(0).template IsType<float>()) { DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
DoRunWithType<float>();
} else if (Input(0).template IsType<float16>()) {
DoRunWithType<float16>();
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32"});
}
} }
template <class Context> template <class Context>
template <typename T> template <typename T>
void CuDNNBatchNormGradientOp<Context>::TrainingImpl() { void CuDNNBatchNormGradientOp<Context>::TrainingImpl() {
typedef typename CuDNNType<T>::BNParamType ParamType; using ParamType = typename CuDNNType<T>::BNParamType;
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig"); auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig");
// Determine the bn desc // Determine the descriptors
if (Input(0).ndim() == 2) { if (Input(0).ndim() == 2) {
bn_mode_ = CUDNN_BATCHNORM_PER_ACTIVATION; bn_mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
CuDNNSetTensorDesc<T>(&input_desc_, vec64_t({N_, C_, 1, 1})); CuDNNSetTensorDesc<T>(&input_desc_, vec64_t({N_, C_, 1, 1}));
...@@ -103,8 +94,6 @@ void CuDNNBatchNormGradientOp<Context>::TrainingImpl() { ...@@ -103,8 +94,6 @@ void CuDNNBatchNormGradientOp<Context>::TrainingImpl() {
bn_mode_ = CUDNN_BATCHNORM_SPATIAL; bn_mode_ = CUDNN_BATCHNORM_SPATIAL;
CuDNNSetTensorDesc<T>(&input_desc_, Input(0).dims(), data_format()); CuDNNSetTensorDesc<T>(&input_desc_, Input(0).dims(), data_format());
} }
// Derive the bn desc
CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(bn_desc_, input_desc_, bn_mode_)); CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(bn_desc_, input_desc_, bn_mode_));
// Gradient w.r.t. gamma, beta and input // Gradient w.r.t. gamma, beta and input
...@@ -134,25 +123,9 @@ template <class Context> ...@@ -134,25 +123,9 @@ template <class Context>
void CuDNNBatchNormGradientOp<Context>::RunOnDevice() { void CuDNNBatchNormGradientOp<Context>::RunOnDevice() {
DetermineBaseArguments(); DetermineBaseArguments();
// Dispatch the training or inference impl // Dispatch the training or inference implementation
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (Input(0).template IsType<float>()) { DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
if (is_training_ > 0) {
TrainingImpl<float>();
} else {
this->template InferenceImpl<float, float>();
}
} else if (Input(0).template IsType<float16>()) {
if (is_training_ > 0) {
TrainingImpl<float16>();
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
}
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32"});
}
} }
DEPLOY_CUDNN_OPERATOR(BatchNorm); DEPLOY_CUDNN_OPERATOR(BatchNorm);
......
...@@ -6,8 +6,9 @@ ...@@ -6,8 +6,9 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
template <typename InputType, typename ParamType> template <typename T>
void GroupNormOp<Context>::DoRunWithType() { void GroupNormOp<Context>::DoRunWithType() {
using ParamType = typename math::utils::AccmulatorType<T>::type;
TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(1), vec64_t({C_}), ParamType);
TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType); TENSOR_FILL_WITH_TYPE(Input(2), vec64_t({C_}), ParamType);
...@@ -16,7 +17,7 @@ void GroupNormOp<Context>::DoRunWithType() { ...@@ -16,7 +17,7 @@ void GroupNormOp<Context>::DoRunWithType() {
auto* X_scale = Buffer("X_scale")->Reshape({N_, C_}); auto* X_scale = Buffer("X_scale")->Reshape({N_, C_});
auto* X_bias = Buffer("X_bias")->Reshape({N_, C_}); auto* X_bias = Buffer("X_bias")->Reshape({N_, C_});
auto* x = Input(0).template data<InputType, Context>(); auto* x = Input(0).template data<T, Context>();
auto* mu = X_mu->template mutable_data<ParamType, Context>(); auto* mu = X_mu->template mutable_data<ParamType, Context>();
auto* rsig = X_rsig->template mutable_data<ParamType, Context>(); auto* rsig = X_rsig->template mutable_data<ParamType, Context>();
...@@ -48,7 +49,7 @@ void GroupNormOp<Context>::DoRunWithType() { ...@@ -48,7 +49,7 @@ void GroupNormOp<Context>::DoRunWithType() {
Input(2).template data<ParamType, Context>(), // beta Input(2).template data<ParamType, Context>(), // beta
X_scale->template mutable_data<ParamType, Context>(), X_scale->template mutable_data<ParamType, Context>(),
X_bias->template mutable_data<ParamType, Context>(), X_bias->template mutable_data<ParamType, Context>(),
Output(0)->template mutable_data<InputType, Context>(), Output(0)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -56,21 +57,15 @@ template <class Context> ...@@ -56,21 +57,15 @@ template <class Context>
void GroupNormOp<Context>::RunOnDevice() { void GroupNormOp<Context>::RunOnDevice() {
DetermineBaseArguments(); DetermineBaseArguments();
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
if (Input(0).template IsType<float>()) {
DoRunWithType<float, float>();
} else if (Input(0).template IsType<float16>()) {
DoRunWithType<float16, float>();
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32"});
}
} }
template <class Context> template <class Context>
template <typename InputType, typename ParamType> template <typename T>
void GroupNormGradientOp<Context>::DoRunWithType() { void GroupNormGradientOp<Context>::DoRunWithType() {
using ParamType = typename math::utils::AccmulatorType<T>::type;
auto *dX = Output(0), *dW = Output(1), *dB = Output(2); auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig"); auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig");
auto* X_scale = Buffer("X_scale")->Reshape({N_, G_}); auto* X_scale = Buffer("X_scale")->Reshape({N_, G_});
auto* X_bias = Buffer("X_bias")->Reshape({N_, G_}); auto* X_bias = Buffer("X_bias")->Reshape({N_, G_});
...@@ -82,16 +77,16 @@ void GroupNormGradientOp<Context>::DoRunWithType() { ...@@ -82,16 +77,16 @@ void GroupNormGradientOp<Context>::DoRunWithType() {
D_, D_,
S_, S_,
data_format(), data_format(),
Input(0).template data<InputType, Context>(), // x Input(0).template data<T, Context>(), // x
X_mu->template data<ParamType, Context>(), X_mu->template data<ParamType, Context>(),
X_rsig->template data<ParamType, Context>(), X_rsig->template data<ParamType, Context>(),
Input(1).template data<ParamType, Context>(), // gamma Input(1).template data<ParamType, Context>(), // gamma
Input(2).template data<InputType, Context>(), // dy Input(2).template data<T, Context>(), // dy
X_scale->template mutable_data<ParamType, Context>(), X_scale->template mutable_data<ParamType, Context>(),
X_bias->template mutable_data<ParamType, Context>(), X_bias->template mutable_data<ParamType, Context>(),
dW->Reshape({C_})->template mutable_data<ParamType, Context>(), dW->Reshape({C_})->template mutable_data<ParamType, Context>(),
dB->Reshape({C_})->template mutable_data<ParamType, Context>(), dB->Reshape({C_})->template mutable_data<ParamType, Context>(),
dX->template mutable_data<InputType, Context>(), dX->template mutable_data<T, Context>(),
ctx()); ctx());
} }
...@@ -99,15 +94,7 @@ template <class Context> ...@@ -99,15 +94,7 @@ template <class Context>
void GroupNormGradientOp<Context>::RunOnDevice() { void GroupNormGradientOp<Context>::RunOnDevice() {
DetermineBaseArguments(); DetermineBaseArguments();
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
if (Input(0).template IsType<float>()) {
DoRunWithType<float, float>();
} else if (Input(0).template IsType<float16>()) {
DoRunWithType<float16, float>();
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float16", "float32"});
}
} }
DEPLOY_CPU_OPERATOR(GroupNorm); DEPLOY_CPU_OPERATOR(GroupNorm);
......
...@@ -70,7 +70,7 @@ class GroupNormOp final : public GroupNormOpBase<Context> { ...@@ -70,7 +70,7 @@ class GroupNormOp final : public GroupNormOpBase<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename InputType, typename ParamType> template <typename T>
void DoRunWithType(); void DoRunWithType();
}; };
...@@ -84,7 +84,7 @@ class GroupNormGradientOp final : public GroupNormOpBase<Context> { ...@@ -84,7 +84,7 @@ class GroupNormGradientOp final : public GroupNormOpBase<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename InputType, typename ParamType> template <typename T>
void DoRunWithType(); void DoRunWithType();
}; };
......
...@@ -30,7 +30,7 @@ void LpNormalizeOp<Context>::DoRunWithType() { ...@@ -30,7 +30,7 @@ void LpNormalizeOp<Context>::DoRunWithType() {
X.count(0, axis), X.count(0, axis),
X.count(axis + num_axes), X.count(axis + num_axes),
reduce_dim, reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f, reduction_ == "MEAN" ? float(reduce_dim) : 1.f,
epsilon_, epsilon_,
X.template data<T, Context>(), X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(), Y->ReshapeLike(X)->template mutable_data<T, Context>(),
...@@ -40,7 +40,7 @@ void LpNormalizeOp<Context>::DoRunWithType() { ...@@ -40,7 +40,7 @@ void LpNormalizeOp<Context>::DoRunWithType() {
X.count(0, axis), X.count(0, axis),
X.count(axis + num_axes), X.count(axis + num_axes),
reduce_dim, reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f, reduction_ == "MEAN" ? float(reduce_dim) : 1.f,
epsilon_, epsilon_,
X.template data<T, Context>(), X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(), Y->ReshapeLike(X)->template mutable_data<T, Context>(),
...@@ -67,7 +67,7 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() { ...@@ -67,7 +67,7 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
X.count(0, axis), X.count(0, axis),
X.count(axis + num_axes), X.count(axis + num_axes),
reduce_dim, reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f, reduction_ == "MEAN" ? float(reduce_dim) : 1.f,
epsilon_, epsilon_,
dY.template data<T, Context>(), dY.template data<T, Context>(),
X.template data<T, Context>(), X.template data<T, Context>(),
...@@ -78,7 +78,7 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() { ...@@ -78,7 +78,7 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
X.count(0, axis), X.count(0, axis),
X.count(axis + num_axes), X.count(axis + num_axes),
reduce_dim, reduce_dim,
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f, reduction_ == "MEAN" ? float(reduce_dim) : 1.f,
epsilon_, epsilon_,
dY.template data<T, Context>(), dY.template data<T, Context>(),
X.template data<T, Context>(), X.template data<T, Context>(),
......
...@@ -153,7 +153,7 @@ setuptools.setup( ...@@ -153,7 +153,7 @@ setuptools.setup(
package_data={'dragon': find_package_data()}, package_data={'dragon': find_package_data()},
package_dir={'dragon': 'dragon'}, package_dir={'dragon': 'dragon'},
cmdclass={'bdist_wheel': bdist_wheel, 'install': install}, cmdclass={'bdist_wheel': bdist_wheel, 'install': install},
python_requires='>=3.5', python_requires='>=3.6',
install_requires=['numpy', 'protobuf', 'kpl-dataset'], install_requires=['numpy', 'protobuf', 'kpl-dataset'],
classifiers=[ classifiers=[
'Development Status :: 5 - Production/Stable', 'Development Status :: 5 - Production/Stable',
...@@ -164,10 +164,10 @@ setuptools.setup( ...@@ -164,10 +164,10 @@ setuptools.setup(
'Programming Language :: C++', 'Programming Language :: C++',
'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only', 'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Artificial Intelligence',
......
...@@ -165,37 +165,47 @@ void _GenericReduceSum( ...@@ -165,37 +165,47 @@ void _GenericReduceSum(
return; \ return; \
} \ } \
int rows, cols; \ int rows, cols; \
vec32_t y_dims(dims, dims + num_dims); \ vec32_t out_dims(dims, dims + num_dims); \
for (int i = 0; i < num_axes; ++i) \ for (int i = 0; i < num_axes; ++i) { \
y_dims[axes[i]] = 1; \ out_dims[axes[i]] = 1; \
/* Case #1: Rowwise Reduce */ \ } \
if (math::utils::IsRowwiseReduce( \ if (math::utils::IsRowwiseReduce( \
num_dims, dims, y_dims.data(), &rows, &cols)) { \ num_dims, dims, out_dims.data(), &rows, &cols)) { \
_RowwiseReduce##name(rows, cols, scale, x, y); \ _RowwiseReduce##name(rows, cols, scale, x, y); \
return; \ return; \
} \ } \
/* Case #2: Colwise Reduce */ \
if (math::utils::IsColwiseReduce( \ if (math::utils::IsColwiseReduce( \
num_dims, dims, y_dims.data(), &rows, &cols)) { \ num_dims, dims, out_dims.data(), &rows, &cols)) { \
_ColwiseReduce##name(rows, cols, scale, x, y); \ _ColwiseReduce##name(rows, cols, scale, x, y); \
return; \ return; \
} \ } \
/* Case #3: Generic Reduce */ \ vec32_t transpose_axes(num_dims); \
vec32_t axesT(num_dims), stridesT(num_dims), dimsT(num_dims); \ vec32_t transpose_strides(num_dims); \
vec32_t transpose_dims(num_dims); \
math::utils::TransposeAxesForReduce( \ math::utils::TransposeAxesForReduce( \
num_dims, num_axes, axes, axesT.data()); \ num_dims, num_axes, axes, transpose_axes.data()); \
math::utils::ComputeTransposeStrides( \ math::utils::ComputeTransposeStrides( \
num_dims, dims, axesT.data(), stridesT.data()); \ num_dims, dims, transpose_axes.data(), transpose_strides.data()); \
rows = cols = 1; \ rows = cols = 1; \
const int pivot = num_dims - num_axes; \ const int pivot = num_dims - num_axes; \
for (int i = 0; i < pivot; ++i) \ for (int i = 0; i < pivot; ++i) { \
rows *= dims[axesT[i]]; \ rows *= dims[transpose_axes[i]]; \
for (int i = pivot; i < num_dims; ++i) \ } \
cols *= dims[axesT[i]]; \ for (int i = pivot; i < num_dims; ++i) { \
for (int i = 0; i < num_dims; ++i) \ cols *= dims[transpose_axes[i]]; \
dimsT[i] = dims[axesT[i]]; \ } \
for (int i = 0; i < num_dims; ++i) { \
transpose_dims[i] = dims[transpose_axes[i]]; \
} \
_GenericReduce##name( \ _GenericReduce##name( \
rows, cols, num_dims, dimsT.data(), stridesT.data(), scale, x, y); \ rows, \
cols, \
num_dims, \
transpose_dims.data(), \
transpose_strides.data(), \
scale, \
x, \
y); \
} }
DEFINE_REDUCE_FUNC(Max); DEFINE_REDUCE_FUNC(Max);
......
...@@ -34,6 +34,24 @@ namespace math { ...@@ -34,6 +34,24 @@ namespace math {
namespace utils { namespace utils {
template <typename T>
class AccmulatorType {
public:
typedef float type;
};
template <>
class AccmulatorType<int64_t> {
public:
typedef double type;
};
template <>
class AccmulatorType<double> {
public:
typedef double type;
};
template < template <
typename T, typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0> typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
......
...@@ -811,15 +811,15 @@ void CosGrad(const int count, const T* dy, const T* x, T* dx, Context* ctx); ...@@ -811,15 +811,15 @@ void CosGrad(const int count, const T* dy, const T* x, T* dx, Context* ctx);
/* math.moments */ /* math.moments */
template <typename Tx, typename Ty, class Context> template <typename T, typename AccT, class Context>
void Moments( void Moments(
const int num_dims, const int num_dims,
const int* dims, const int* dims,
const int num_axes, const int num_axes,
const int* axes, const int* axes,
const Tx* x, const T* x,
Ty* mean, AccT* mean,
Ty* var, AccT* var,
Context* ctx); Context* ctx);
/* math.reciprocal */ /* math.reciprocal */
...@@ -845,35 +845,35 @@ void SinGrad(const int count, const T* dy, const T* x, T* dx, Context* ctx); ...@@ -845,35 +845,35 @@ void SinGrad(const int count, const T* dy, const T* x, T* dx, Context* ctx);
/* normalization.batch_norm */ /* normalization.batch_norm */
template <typename T, typename AccT, class Context> template <typename T, typename AccT, class Context>
void BatchNorm( void BatchNormExpectation(
const int N, const int N,
const int C, const int C,
const int S, const int S,
const float normalizer,
const string& data_format, const string& data_format,
const T* x, const T* x,
const AccT* mu, AccT* ex,
const AccT* rsig, AccT* ex2,
const AccT* gamma,
const AccT* beta,
AccT* scale,
AccT* bias,
T* y,
Context* ctx); Context* ctx);
template <typename T, typename AccT, class Context> template <typename T, typename AccT, class Context>
void BatchNormExpectation( void BatchNorm(
const int N, const int N,
const int C, const int C,
const int S, const int S,
const AccT denorm,
const string& data_format, const string& data_format,
const T* x, const T* x,
AccT* ex, const AccT* mu,
AccT* ex2, const AccT* rsig,
const AccT* gamma,
const AccT* beta,
AccT* scale,
AccT* bias,
T* y,
Context* ctx); Context* ctx);
template <typename T, typename AccT, class Context> template <typename T, typename AccT, class Context>
void BatchNormInternalGrad( void BatchNormWGrad(
const int N, const int N,
const int C, const int C,
const int S, const int S,
...@@ -881,7 +881,6 @@ void BatchNormInternalGrad( ...@@ -881,7 +881,6 @@ void BatchNormInternalGrad(
const T* x, const T* x,
const AccT* mu, const AccT* mu,
const AccT* rsig, const AccT* rsig,
const AccT* gamma,
const T* dy, const T* dy,
AccT* dgamma, AccT* dgamma,
AccT* dbeta, AccT* dbeta,
...@@ -892,6 +891,7 @@ void BatchNormTrainingGrad( ...@@ -892,6 +891,7 @@ void BatchNormTrainingGrad(
const int N, const int N,
const int C, const int C,
const int S, const int S,
const float normalizer,
const string& data_format, const string& data_format,
const T* x, const T* x,
const AccT* mu, const AccT* mu,
...@@ -964,8 +964,8 @@ void L1Normalize( ...@@ -964,8 +964,8 @@ void L1Normalize(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const float scale, const float normalizer,
const float eps, const float epsilon,
const T* x, const T* x,
T* y, T* y,
Context* ctx); Context* ctx);
...@@ -975,8 +975,8 @@ void L1NormalizeGrad( ...@@ -975,8 +975,8 @@ void L1NormalizeGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const float scale, const float normalizer,
const float eps, const float epsilon,
const T* dy, const T* dy,
const T* x, const T* x,
T* dx, T* dx,
...@@ -987,8 +987,8 @@ void L2Normalize( ...@@ -987,8 +987,8 @@ void L2Normalize(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const float scale, const float normalizer,
const float eps, const float epsilon,
const T* x, const T* x,
T* y, T* y,
Context* ctx); Context* ctx);
...@@ -998,8 +998,8 @@ void L2NormalizeGrad( ...@@ -998,8 +998,8 @@ void L2NormalizeGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int reduce_dim, const int reduce_dim,
const float scale, const float normalizer,
const float eps, const float epsilon,
const T* dy, const T* dy,
const T* x, const T* x,
T* dx, T* dx,
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!