Commit 46feba80 by Ting PAN

Instantiate dispatch template by value for crucial CUDA kernels

Summary:
This commit instantiates CUDA kernels by using constant dimensions
to enable the optimization during compiler-time.
1 parent 936c351b
FROM ubuntu:16.04
FROM ubuntu:18.04
RUN \
apt-get update && apt-get install -y \
......@@ -43,8 +43,8 @@ RUN \
-DPYTHON_EXECUTABLE=/usr/bin/python3 \
-DUSE_CUDA=OFF \
-DUSE_CUDNN=OFF \
-DUSE_AVX2=OFF \
-DUSE_FMA=OFF && \
-DUSE_AVX2=ON \
-DUSE_FMA=ON && \
make install -j $(nproc) && \
cd .. && rm -rf build && \
python3 setup.py install
......
FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
FROM nvidia/cuda:10.2-cudnn8-devel-ubuntu18.04
RUN \
rm /etc/apt/sources.list.d/cuda.list && \
......@@ -48,8 +48,8 @@ RUN \
-DPYTHON_EXECUTABLE=/usr/bin/python3 \
-DUSE_MPI=ON \
-DUSE_NCCL=ON \
-DUSE_AVX2=OFF \
-DUSE_FMA=OFF && \
-DUSE_AVX2=ON \
-DUSE_FMA=ON && \
make install -j $(nproc) && \
cd .. && rm -rf build && \
python3 setup.py install
......
......@@ -62,10 +62,6 @@ class CUDAObjects {
} else {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#elif CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
}
#endif
}
return handles[stream_id];
......@@ -437,7 +433,8 @@ class DRAGON_API CUDAContext {
CUDA_NOT_COMPILED;
}
/*! \brief Switch to the device and select given stream in current thread */
/*! \brief Switch to the device and select given stream in current
* thread */
void SwitchToDevice(int stream_id) {
CUDA_NOT_COMPILED;
}
......
......@@ -10,60 +10,76 @@ namespace kernels {
namespace {
template <typename T, int D>
template <typename T, typename AccT, int D>
__global__ void _ReduceSumGrad(
const int N,
const int num_dims,
const SimpleArray<int, D> X_dims,
const SimpleArray<int, D> Y_dims,
const SimpleArray<int, D> Y_strides,
const float scale,
const AccT scale,
const T* dy,
T* dx) {
CUDA_1D_KERNEL_LOOP(xi, N) {
int yi = 0, tmp = xi;
for (int d = num_dims - 1; d >= 0; --d) {
#pragma unroll
for (int d = D - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(X_dims.data[d], tmp, &tmp, &r);
yi += (r % Y_dims.data[d]) * Y_strides.data[d];
}
dx[xi] = convert::To<T>(convert::To<float>(__ldg(dy + yi)) * scale);
dx[xi] = convert::To<T>(convert::To<AccT>(__ldg(dy + yi)) * scale);
}
}
template <typename T, typename AccT, int D>
void _ReduceSumGradImpl(
const int64_t* x_dims,
const int64_t* y_dims,
const int64_t* y_strides,
const AccT scale,
const T* dy,
T* dx,
CUDAContext* ctx) {
SimpleArray<int, D> X_dims, Y_dims, Y_strides;
const auto N =
std::accumulate(x_dims, x_dims + D, 1, std::multiplies<int64_t>());
for (int i = 0; i < D; ++i) {
X_dims.data[i] = x_dims[i];
Y_dims.data[i] = y_dims[i];
Y_strides.data[i] = y_strides[i];
}
_ReduceSumGrad<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, X_dims, Y_dims, Y_strides, scale, dy, dx);
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void ReduceSumGrad<T, CUDAContext>( \
const int num_dims, \
const int64_t* x_dims, \
const int64_t* y_dims, \
const int64_t* y_strides, \
const float scale, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_dims, Y_dims, Y_strides; \
const auto N = std::accumulate( \
x_dims, x_dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \
X_dims.data[i] = x_dims[i]; \
Y_dims.data[i] = y_dims[i]; \
Y_strides.data[i] = y_strides[i]; \
} \
_ReduceSumGrad<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
num_dims, \
X_dims, \
Y_dims, \
Y_strides, \
scale, \
reinterpret_cast<const math::ScalarType<T>::type*>(dy), \
reinterpret_cast<math::ScalarType<T>::type*>(dx)); \
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void ReduceSumGrad<T, CUDAContext>( \
const int num_dims, \
const int64_t* x_dims, \
const int64_t* y_dims, \
const int64_t* y_strides, \
const float scale, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
DISPATCH_FUNC_BY_VALUE_WITH_TYPE_2( \
_ReduceSumGradImpl, \
math::ScalarType<T>::type, \
math::AccmulatorType<T>::type, \
num_dims, \
x_dims, \
y_dims, \
y_strides, \
convert::To<math::AccmulatorType<T>::type>(scale), \
reinterpret_cast<const math::ScalarType<T>::type*>(dy), \
reinterpret_cast<math::ScalarType<T>::type*>(dx), \
ctx); \
}
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
......
......@@ -12,7 +12,6 @@ namespace {
template <typename T, int D>
__global__ void _Roll(
const int N,
const int num_dims,
const SimpleArray<int, D> X_shifts,
const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_dims,
......@@ -20,7 +19,8 @@ __global__ void _Roll(
T* y) {
CUDA_1D_KERNEL_LOOP(yi, N) {
int xi = 0, tmp = yi;
for (int d = num_dims - 1; d >= 0; --d) {
#pragma unroll
for (int d = D - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], tmp, &tmp, &r);
r -= X_shifts.data[d];
......@@ -31,33 +31,43 @@ __global__ void _Roll(
}
}
template <typename T, int D>
void _RollImpl(
const int64_t* x_shifts,
const int64_t* x_strides,
const int64_t* y_dims,
const T* x,
T* y,
CUDAContext* ctx) {
SimpleArray<int, D> X_shifts, X_strides, Y_dims;
const auto N =
std::accumulate(y_dims, y_dims + D, 1, std::multiplies<int64_t>());
for (int i = 0; i < D; ++i) {
X_shifts.data[i] = x_shifts[i];
X_strides.data[i] = x_strides[i];
Y_dims.data[i] = y_dims[i];
}
_Roll<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, X_shifts, X_strides, Y_dims, x, y);
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Roll<T, CUDAContext>( \
const int num_dims, \
const int64_t* x_shifts, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_shifts; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> Y_dims; \
const auto N = std::accumulate( \
y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \
X_shifts.data[i] = x_shifts[i]; \
X_strides.data[i] = x_strides[i]; \
Y_dims.data[i] = y_dims[i]; \
} \
_Roll<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, num_dims, X_shifts, X_strides, Y_dims, x, y); \
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Roll<T, CUDAContext>( \
const int num_dims, \
const int64_t* x_shifts, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
DISPATCH_FUNC_BY_VALUE_WITH_TYPE_1( \
_RollImpl, T, num_dims, x_shifts, x_strides, y_dims, x, y, ctx); \
}
DEFINE_KERNEL_LAUNCHER(bool);
......
......@@ -13,7 +13,6 @@ namespace {
template <typename T, int D>
__global__ void _Slice(
const int N,
const int num_dims,
const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_dims,
const SimpleArray<int, D> X_starts,
......@@ -21,7 +20,8 @@ __global__ void _Slice(
T* y) {
CUDA_1D_KERNEL_LOOP(yi, N) {
int xi = 0, tmp = yi;
for (int d = num_dims - 1; d >= 0; --d) {
#pragma unroll
for (int d = D - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], tmp, &tmp, &r);
xi += (r + X_starts.data[d]) * X_strides.data[d];
......@@ -33,7 +33,6 @@ __global__ void _Slice(
template <typename T, int D>
__global__ void _SliceGrad(
const int N,
const int num_dims,
const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_dims,
const SimpleArray<int, D> X_starts,
......@@ -41,7 +40,8 @@ __global__ void _SliceGrad(
T* dx) {
CUDA_1D_KERNEL_LOOP(yi, N) {
int xi = 0, tmp = yi;
for (int d = num_dims - 1; d >= 0; --d) {
#pragma unroll
for (int d = D - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], tmp, &tmp, &r);
xi += (r + X_starts.data[d]) * X_strides.data[d];
......@@ -50,31 +50,49 @@ __global__ void _SliceGrad(
}
}
template <typename T, int D>
void _SliceImpl(
const string& routine,
const int64_t* x_strides,
const int64_t* y_dims,
const int64_t* starts,
const T* x,
T* y,
CUDAContext* ctx) {
SimpleArray<int, D> X_strides, Y_dims, X_starts;
const auto N =
std::accumulate(y_dims, y_dims + D, 1, std::multiplies<int64_t>());
for (int i = 0; i < D; ++i) {
X_strides.data[i] = x_strides[i];
Y_dims.data[i] = y_dims[i];
X_starts.data[i] = starts[i];
}
if (routine == "Slice") {
_Slice<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, X_strides, Y_dims, X_starts, x, y);
} else if (routine == "SliceGrad") {
_SliceGrad<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, X_strides, Y_dims, X_starts, x, y);
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CUDAContext>( \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const int64_t* starts, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides, Y_dims, X_starts; \
const auto N = std::accumulate( \
y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \
X_strides.data[i] = x_strides[i]; \
Y_dims.data[i] = y_dims[i]; \
X_starts.data[i] = starts[i]; \
} \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, num_dims, X_strides, Y_dims, X_starts, x, y); \
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CUDAContext>( \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const int64_t* starts, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
DISPATCH_FUNC_BY_VALUE_WITH_TYPE_1( \
_SliceImpl, T, num_dims, #name, x_strides, y_dims, starts, x, y, ctx); \
}
DEFINE_KERNEL_LAUNCHER(Slice, bool);
......
......@@ -31,12 +31,13 @@ __global__ void _Transpose(
template <typename T, int D>
void _TransposeImpl(
const int N,
const int64_t* x_strides,
const int64_t* y_dims,
const T* x,
T* y,
CUDAContext* ctx) {
const auto N =
std::accumulate(y_dims, y_dims + D, 1, std::multiplies<int64_t>());
SimpleArray<int, D> X_strides, Y_dims;
for (int i = 0; i < D; ++i) {
X_strides.data[i] = x_strides[i];
......@@ -50,46 +51,18 @@ void _TransposeImpl(
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Transpose<T, CUDAContext>( \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
const auto N = std::accumulate( \
y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); \
switch (num_dims) { \
case 1: \
_TransposeImpl<T, 1>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 2: \
_TransposeImpl<T, 2>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 3: \
_TransposeImpl<T, 3>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 4: \
_TransposeImpl<T, 4>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 5: \
_TransposeImpl<T, 5>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 6: \
_TransposeImpl<T, 6>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 7: \
_TransposeImpl<T, 7>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 8: \
_TransposeImpl<T, 8>(N, x_strides, y_dims, x, y, ctx); \
break; \
default: \
break; \
} \
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Transpose<T, CUDAContext>( \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
DISPATCH_FUNC_BY_VALUE_WITH_TYPE_1( \
_TransposeImpl, T, num_dims, x_strides, y_dims, x, y, ctx); \
}
DEFINE_KERNEL_LAUNCHER(bool);
......
......@@ -82,7 +82,7 @@ __global__ void _SoftmaxCrossEntropyGrad(
const int S,
const int C,
const int ignore_index,
const InputT* input,
const InputT* /* input */,
const TargetT* target,
InputT* dx,
InputT* mask) {
......
......@@ -38,7 +38,7 @@ __global__ void _NLLLossGrad(
const int S,
const int C,
const int ignore_index,
const InputT* input,
const InputT* /* input */,
const TargetT* target,
InputT* dx,
InputT* mask) {
......
......@@ -67,7 +67,6 @@ template <typename T, typename AccT, int D>
__global__ void _GenericMoments(
const int rows,
const int cols,
const int num_dims,
const SimpleArray<int, D> X_dims,
const SimpleArray<int, D> X_strides,
const T* x,
......@@ -80,7 +79,8 @@ __global__ void _GenericMoments(
AccT m_val = AccT(0), v_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, cols) {
int xi = 0, c = i * cols + j;
for (int d = num_dims - 1; d >= 0; --d) {
#pragma unroll
for (int d = D - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(X_dims.data[d], c, &c, &r);
xi += r * X_strides.data[d];
......@@ -98,9 +98,8 @@ __global__ void _GenericMoments(
}
}
template <typename T, typename AccT>
void _Moments(
const int num_dims,
template <typename T, typename AccT, int D>
void _GenericMomentsImpl(
const int* dims,
const int num_axes,
const int* axes,
......@@ -108,70 +107,72 @@ void _Moments(
AccT* mean,
AccT* var,
CUDAContext* ctx) {
int rows, cols;
vec32_t out_dims(dims, dims + num_dims);
for (int i = 0; i < num_axes; ++i) {
out_dims[axes[i]] = 1;
}
if (math::utils::IsRowwiseReduce(
num_dims, dims, out_dims.data(), &rows, &cols)) {
_RowwiseMoments<<<cols, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
rows, cols, x, mean, var);
return;
}
if (math::utils::IsColwiseReduce(
num_dims, dims, out_dims.data(), &rows, &cols)) {
_ColwiseMoments<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
rows, cols, x, mean, var);
return;
}
CUDA_TENSOR_DIMS_CHECK(num_dims);
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_axes;
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_strides;
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_dims;
math::utils::TransposeAxesForReduce(
num_dims, num_axes, axes, transpose_axes.data);
SimpleArray<int, D> transpose_axes;
SimpleArray<int, D> transpose_strides;
SimpleArray<int, D> transpose_dims;
math::utils::TransposeAxesForReduce(D, num_axes, axes, transpose_axes.data);
math::utils::ComputeTransposeStrides(
num_dims, dims, transpose_axes.data, transpose_strides.data);
rows = cols = 1;
const int pivot = num_dims - num_axes;
D, dims, transpose_axes.data, transpose_strides.data);
int rows = 1, cols = 1;
const int pivot = D - num_axes;
for (int i = 0; i < pivot; ++i) {
rows *= dims[transpose_axes.data[i]];
}
for (int i = pivot; i < num_dims; ++i) {
for (int i = pivot; i < D; ++i) {
cols *= dims[transpose_axes.data[i]];
}
for (int i = 0; i < num_dims; ++i) {
for (int i = 0; i < D; ++i) {
transpose_dims.data[i] = dims[transpose_axes.data[i]];
}
_GenericMoments<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
rows, cols, num_dims, transpose_dims, transpose_strides, x, mean, var);
rows, cols, transpose_dims, transpose_strides, x, mean, var);
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T, AccT) \
template <> \
void Moments<T, AccT, CUDAContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \
const int* axes, \
const T* x, \
AccT* mean, \
AccT* var, \
CUDAContext* ctx) { \
_Moments( \
num_dims, \
dims, \
num_axes, \
axes, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
mean, \
var, \
ctx); \
#define DEFINE_KERNEL_LAUNCHER(T, AccT) \
template <> \
void Moments<T, AccT, CUDAContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \
const int* axes, \
const T* x, \
AccT* mean, \
AccT* var, \
CUDAContext* ctx) { \
int rows, cols; \
vec32_t out_dims(dims, dims + num_dims); \
for (int i = 0; i < num_axes; ++i) { \
out_dims[axes[i]] = 1; \
} \
if (math::utils::IsRowwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_RowwiseMoments<<<cols, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, cols, x, mean, var); \
return; \
} \
if (math::utils::IsColwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_ColwiseMoments<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, cols, x, mean, var); \
return; \
} \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
DISPATCH_FUNC_BY_VALUE_WITH_TYPE_2( \
_GenericMomentsImpl, \
T, \
AccT, \
num_dims, \
dims, \
num_axes, \
axes, \
x, \
mean, \
var, \
ctx); \
}
DEFINE_KERNEL_LAUNCHER(uint8_t, float);
......
......@@ -13,7 +13,7 @@ namespace {
template <typename T>
__global__ void _Im2Col2dNCHW(
const int nthreads,
const int C,
const int /* C */,
const int H,
const int W,
const int out_h,
......@@ -59,7 +59,7 @@ __global__ void _Im2Col2dNHWC(
const int C,
const int H,
const int W,
const int out_h,
const int /* out_h */,
const int out_w,
const int kernel_h,
const int kernel_w,
......@@ -97,7 +97,7 @@ __global__ void _Im2Col2dNHWC(
template <typename T>
__global__ void _Col2Im2dNCHW(
const int nthreads,
const int C,
const int /* C */,
const int H,
const int W,
const int out_h,
......@@ -147,7 +147,7 @@ template <typename T>
__global__ void _Col2Im2dNHWC(
const int nthreads,
const int C,
const int H,
const int /* H */,
const int W,
const int out_h,
const int out_w,
......
......@@ -7,7 +7,7 @@ namespace kernels {
namespace {
template <typename T>
template <typename T, typename AccT>
void _MaxPool2dNCHW(
const int N,
const int C,
......@@ -29,8 +29,7 @@ void _MaxPool2dNCHW(
const auto NxCxHoxWo = N * C * out_h * out_w;
std::array<int, 4> index = {0, 0, 0, 0};
std::array<int, 4> dims = {N, C, out_h, out_w};
T val;
int hstart, hend, wstart, wend, xi, mask_val;
int hstart, hend, wstart, wend;
for (int i = 0; i < NxCxHoxWo; ++i) {
hstart = index[2] * stride_h - pad_h;
wstart = index[3] * stride_w - pad_w;
......@@ -39,23 +38,24 @@ void _MaxPool2dNCHW(
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
const T* offset_x = x + index[0] * CxHxW + index[1] * HxW;
mask_val = -1;
val = T(-FLT_MAX);
int mask_val = -1;
AccT val = AccT(-FLT_MAX);
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
xi = h * W + w;
if (offset_x[xi] > val) {
val = offset_x[mask_val = xi];
const auto xi = h * W + w;
if (convert::To<AccT>(offset_x[xi]) > val) {
mask_val = xi;
val = convert::To<AccT>(offset_x[xi]);
}
}
}
y[i] = val;
y[i] = convert::To<T>(val);
mask[i] = mask_val;
math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
}
}
template <typename T>
template <typename T, typename AccT>
void _MaxPool2dNHWC(
const int N,
const int C,
......@@ -76,8 +76,7 @@ void _MaxPool2dNHWC(
const auto NxHoxWoxC = N * C * out_h * out_w;
std::array<int, 4> index = {0, 0, 0, 0};
std::array<int, 4> dims = {N, out_h, out_w, C};
T val;
int hstart, hend, wstart, wend, xi, mask_val;
int hstart, hend, wstart, wend;
for (int i = 0; i < NxHoxWoxC; ++i) {
hstart = index[1] * stride_h - pad_h;
wstart = index[2] * stride_w - pad_w;
......@@ -86,23 +85,24 @@ void _MaxPool2dNHWC(
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
const T* offset_x = x + index[0] * HxWxC;
mask_val = -1;
val = T(-FLT_MAX);
int mask_val = -1;
AccT val = AccT(-FLT_MAX);
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
xi = (h * W + w) * C + index[3];
if (offset_x[xi] > val) {
val = offset_x[mask_val = xi];
const auto xi = (h * W + w) * C + index[3];
if (convert::To<AccT>(offset_x[xi]) > val) {
mask_val = xi;
val = convert::To<AccT>(offset_x[xi]);
}
}
}
y[i] = val;
y[i] = convert::To<T>(val);
mask[i] = mask_val;
math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
}
}
template <typename T>
template <typename T, typename AccT>
void _MaxPool2dGradNCHW(
const int N,
const int C,
......@@ -127,13 +127,15 @@ void _MaxPool2dGradNCHW(
memset(dx, 0, sizeof(T) * N * CxHxW);
for (int i = 0; i < NxCxHoxWo; ++i) {
if (mask[i] != -1) {
dx[index[0] * CxHxW + index[1] * HxW + mask[i]] += dy[i];
const auto xi = index[0] * CxHxW + index[1] * HxW + mask[i];
dx[xi] =
convert::To<T>(convert::To<AccT>(dx[xi]) + convert::To<AccT>(dy[i]));
}
math::utils::IncreaseIndexInDims(3, dims.data(), index.data());
}
}
template <typename T>
template <typename T, typename AccT>
void _MaxPool2dGradNHWC(
const int N,
const int C,
......@@ -157,13 +159,15 @@ void _MaxPool2dGradNHWC(
memset(dx, 0, sizeof(T) * N * HxWxC);
for (int i = 0; i < NxHoxWoxC; ++i) {
if (mask[i] != -1) {
dx[index[0] * HxWxC + mask[i]] += dy[i];
const auto xi = index[0] * HxWxC + mask[i];
dx[xi] =
convert::To<T>(convert::To<AccT>(dx[xi]) + convert::To<AccT>(dy[i]));
}
math::utils::IncreaseIndexInDims(2, dims.data(), index.data());
}
}
template <typename T>
template <typename T, typename AccT>
void _MaxPool3dNCHW(
const int N,
const int C,
......@@ -190,8 +194,7 @@ void _MaxPool3dNCHW(
const auto NxCxDoxHoxWo = N * C * out_d * out_h * out_w;
std::array<int, 5> index = {0, 0, 0, 0, 0};
std::array<int, 5> dims = {N, C, out_d, out_h, out_w};
T val;
int dstart, dend, hstart, hend, wstart, wend, xi, mask_val;
int dstart, dend, hstart, hend, wstart, wend;
for (int i = 0; i < NxCxDoxHoxWo; ++i) {
dstart = index[2] * stride_d - pad_d;
hstart = index[3] * stride_h - pad_h;
......@@ -203,25 +206,26 @@ void _MaxPool3dNCHW(
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
const T* offset_x = x + index[0] * CxDxHxW + index[1] * DxHxW;
mask_val = -1;
val = T(-FLT_MAX);
int mask_val = -1;
AccT val = AccT(-FLT_MAX);
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
xi = (d * H + h) * W + w;
if (offset_x[xi] > val) {
val = offset_x[mask_val = xi];
const auto xi = (d * H + h) * W + w;
if (convert::To<AccT>(offset_x[xi]) > val) {
mask_val = xi;
val = convert::To<AccT>(offset_x[xi]);
}
}
}
}
y[i] = val;
y[i] = convert::To<T>(val);
mask[i] = mask_val;
math::utils::IncreaseIndexInDims(5, dims.data(), index.data());
}
}
template <typename T>
template <typename T, typename AccT>
void _MaxPool3dNHWC(
const int N,
const int C,
......@@ -247,8 +251,7 @@ void _MaxPool3dNHWC(
const auto NxDoxHoxWoxC = N * C * out_d * out_h * out_w;
std::array<int, 5> index = {0, 0, 0, 0, 0};
std::array<int, 5> dims = {N, out_d, out_h, out_w, C};
T val;
int dstart, dend, hstart, hend, wstart, wend, xi, mask_val;
int dstart, dend, hstart, hend, wstart, wend;
for (int i = 0; i < NxDoxHoxWoxC; ++i) {
dstart = index[1] * stride_d - pad_d;
hstart = index[2] * stride_h - pad_h;
......@@ -260,25 +263,26 @@ void _MaxPool3dNHWC(
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
const T* offset_x = x + index[0] * DxHxWxC;
mask_val = -1;
val = T(-FLT_MAX);
int mask_val = -1;
AccT val = AccT(-FLT_MAX);
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
xi = ((d * H + h) * W + w) * C + index[4];
if (offset_x[xi] > val) {
val = offset_x[mask_val = xi];
const auto xi = ((d * H + h) * W + w) * C + index[4];
if (convert::To<AccT>(offset_x[xi]) > val) {
mask_val = xi;
val = convert::To<AccT>(offset_x[xi]);
}
}
}
}
y[i] = val;
y[i] = convert::To<T>(val);
mask[i] = mask_val;
math::utils::IncreaseIndexInDims(5, dims.data(), index.data());
}
}
template <typename T>
template <typename T, typename AccT>
void _MaxPool3dGradNCHW(
const int N,
const int C,
......@@ -308,13 +312,15 @@ void _MaxPool3dGradNCHW(
memset(dx, 0, sizeof(T) * N * CxDxHxW);
for (int i = 0; i < NxCxDoxHoxWo; ++i) {
if (mask[i] != -1) {
dx[index[0] * CxDxHxW + index[1] * DxHxW + mask[i]] += dy[i];
const auto xi = index[0] * CxDxHxW + index[1] * DxHxW + mask[i];
dx[xi] =
convert::To<T>(convert::To<AccT>(dx[xi]) + convert::To<AccT>(dy[i]));
}
math::utils::IncreaseIndexInDims(3, dims.data(), index.data());
}
}
template <typename T>
template <typename T, typename AccT>
void _MaxPool3dGradNHWC(
const int N,
const int C,
......@@ -343,7 +349,9 @@ void _MaxPool3dGradNHWC(
memset(dx, 0, sizeof(T) * N * DxHxWxC);
for (int i = 0; i < NxDoxHoxWoxC; ++i) {
if (mask[i] != -1) {
dx[index[0] * DxHxWxC + mask[i]] += dy[i];
const auto xi = index[0] * DxHxWxC + mask[i];
dx[xi] =
convert::To<T>(convert::To<AccT>(dx[xi]) + convert::To<AccT>(dy[i]));
}
math::utils::IncreaseIndexInDims(2, dims.data(), index.data());
}
......@@ -353,11 +361,11 @@ void _MaxPool3dGradNHWC(
/* ------------------- Launcher Separator ------------------- */
#define DISPATCH_POOL_KERNEL(name, ...) \
#define DISPATCH_POOL_KERNEL(name, T, AccT, ...) \
if (data_format == "NCHW") { \
name##NCHW(__VA_ARGS__); \
name##NCHW<T, AccT>(__VA_ARGS__); \
} else if (data_format == "NHWC") { \
name##NHWC(__VA_ARGS__); \
name##NHWC<T, AccT>(__VA_ARGS__); \
} else { \
LOG(FATAL) << "Unknown DataFormat: " << data_format; \
}
......@@ -384,6 +392,8 @@ void _MaxPool3dGradNHWC(
CPUContext* ctx) { \
DISPATCH_POOL_KERNEL( \
_##name, \
math::ScalarType<T>::type, \
math::AccmulatorType<T>::type, \
N, \
C, \
H, \
......@@ -401,8 +411,10 @@ void _MaxPool3dGradNHWC(
y); \
}
DEFINE_KERNEL_LAUNCHER(MaxPool2d, float16);
DEFINE_KERNEL_LAUNCHER(MaxPool2d, float);
DEFINE_KERNEL_LAUNCHER(MaxPool2d, double);
DEFINE_KERNEL_LAUNCHER(MaxPool2dGrad, float16); // MaxPool2dGrad
DEFINE_KERNEL_LAUNCHER(MaxPool2dGrad, float); // MaxPool2dGrad
DEFINE_KERNEL_LAUNCHER(MaxPool2dGrad, double); // MaxPool2dGrad
#undef DEFINE_KERNEL_LAUNCHER
......@@ -434,6 +446,8 @@ DEFINE_KERNEL_LAUNCHER(MaxPool2dGrad, double); // MaxPool2dGrad
CPUContext* ctx) { \
DISPATCH_POOL_KERNEL( \
_##name, \
math::ScalarType<T>::type, \
math::AccmulatorType<T>::type, \
N, \
C, \
D, \
......@@ -456,8 +470,10 @@ DEFINE_KERNEL_LAUNCHER(MaxPool2dGrad, double); // MaxPool2dGrad
y); \
}
DEFINE_KERNEL_LAUNCHER(MaxPool3d, float16);
DEFINE_KERNEL_LAUNCHER(MaxPool3d, float);
DEFINE_KERNEL_LAUNCHER(MaxPool3d, double);
DEFINE_KERNEL_LAUNCHER(MaxPool3dGrad, float16); // MaxPool3dGrad
DEFINE_KERNEL_LAUNCHER(MaxPool3dGrad, float); // MaxPool3dGrad
DEFINE_KERNEL_LAUNCHER(MaxPool3dGrad, double); // MaxPool3dGrad
#undef DEFINE_KERNEL_LAUNCHER
......
......@@ -85,7 +85,7 @@ __global__ void _RoiPoolGrad(
const int W,
const int out_h,
const int out_w,
const float spatial_scale,
const float /* spatial_scale */,
const T* dy,
const float* rois,
const int* mask,
......
......@@ -11,14 +11,13 @@ void TransposeOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
int num_axes, num_dims = X.ndim();
vec64_t X_strides(num_dims), Y_dims(num_dims);
perm(0, &num_axes);
CHECK(num_axes == 0 || num_axes == num_dims)
<< "\nProviding " << num_axes << " dimensions to permute, "
<< "while Tensor(" << X.name() << ")'s dims are " << X.DimString();
vec64_t new_axes(num_dims);
vec64_t new_axes(num_dims), new_dims(num_dims);
for (int i = 0; i < num_dims; ++i) {
new_axes[i] = num_axes > 0 ? perm(i) : num_dims - i - 1;
}
......@@ -31,13 +30,27 @@ void TransposeOp<Context>::DoRunWithType() {
}
for (int i = 0; i < num_dims; ++i) {
X_strides[i] = X.stride(new_axes[i]);
Y_dims[i] = X.dim(new_axes[i]);
new_dims[i] = X.dim(new_axes[i]);
}
vec64_t transpose_dims, transpose_axes;
math::utils::CollapseTransposeAxes(
num_dims,
X.dims().data(),
new_axes.data(),
transpose_dims,
transpose_axes);
Tensor X_collapse(transpose_dims);
num_dims = X_collapse.ndim();
vec64_t X_strides(num_dims), Y_dims(num_dims);
for (int i = 0; i < num_dims; ++i) {
X_strides[i] = X_collapse.stride(transpose_axes[i]);
Y_dims[i] = X_collapse.dim(transpose_axes[i]);
}
auto* scratch = ((void*)&X == (void*)Y)
? ctx()->workspace()->template data<T, Context>({X.count()})[0]
: Y->Reshape(Y_dims)->template mutable_data<T, Context>();
: Y->Reshape(new_dims)->template mutable_data<T, Context>();
kernels::Transpose(
num_dims,
......@@ -51,7 +64,7 @@ void TransposeOp<Context>::DoRunWithType() {
math::Copy(
X.count(),
scratch,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
Y->Reshape(new_dims)->template mutable_data<T, Context>(),
ctx());
}
}
......
......@@ -107,11 +107,6 @@ void PoolOp<Context>::DoRunWithType() {
}
template <class Context>
void PoolOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::TypesBase<float, double>>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void PoolGradientOp<Context>::DoRunWithType() {
ComputeOutShape();
......@@ -212,11 +207,6 @@ void PoolGradientOp<Context>::DoRunWithType() {
}
}
template <class Context>
void PoolGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::TypesBase<float, double>>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Pool);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Pool);
......
......@@ -27,7 +27,9 @@ class PoolOp final : public PoolOpBase<Context> {
USE_OPERATOR_FUNCTIONS;
USE_POOL_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -43,7 +45,9 @@ class PoolGradientOp final : public PoolOpBase<Context> {
USE_OPERATOR_FUNCTIONS;
USE_POOL_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -70,7 +74,9 @@ class CuDNNPoolOp final : public CuDNNPoolOpBase<Context> {
CUDNN_CHECK(cudnnDestroyPoolingDescriptor(pool_desc_));
}
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -99,7 +105,9 @@ class CuDNNPoolGradientOp final : public CuDNNPoolOpBase<Context> {
CUDNN_CHECK(cudnnDestroyPoolingDescriptor(pool_desc_));
}
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......
#ifdef USE_CUDNN
#include "dragon/core/workspace.h"
#include "dragon/operators/vision/pool_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -10,6 +12,56 @@ void CuDNNPoolOp<Context>::DoRunWithType() {
ComputeOutShape();
auto &X = Input(0), *Y = Output(0);
// CuDNN NHWC pooling is slow.
// Temporarily fallback to the naive implementation.
if (data_format() == "NHWC" && mode_ == "AVG") {
if (num_axes_ == 1 || num_axes_ == 2) {
kernels::AvgPool2d(
in_dims_[0],
in_dims_[1],
in_dims_[2],
num_axes_ == 1 ? 1 : in_dims_[3],
out_dims_[2],
num_axes_ == 1 ? 1 : out_dims_[3],
kshape_[0],
num_axes_ == 1 ? 1 : kshape_[1],
strides_[0],
num_axes_ == 1 ? 1 : strides_[1],
pads_begin_[0],
num_axes_ == 1 ? 0 : pads_begin_[1],
data_format(),
X.template data<T, Context>(),
Y->Reshape(out_shape_)->template mutable_data<T, Context>(),
ctx());
} else if (num_axes_ == 3) {
kernels::AvgPool3d(
in_dims_[0],
in_dims_[1],
in_dims_[2],
in_dims_[3],
in_dims_[4],
out_dims_[2],
out_dims_[3],
out_dims_[4],
kshape_[0],
kshape_[1],
kshape_[2],
strides_[0],
strides_[1],
strides_[2],
pads_begin_[0],
pads_begin_[1],
pads_begin_[2],
data_format(),
X.template data<T, Context>(),
Y->Reshape(out_shape_)->template mutable_data<T, Context>(),
ctx());
} else {
LOG(FATAL) << "AvgPool" << num_axes_ << "d is not supported.";
}
return;
}
SetPoolDesc();
CuDNNSetTensorDesc<T>(&input_desc_, X.dims(), data_format());
CuDNNSetTensorDesc<T>(&output_desc_, out_shape_, data_format());
......@@ -26,11 +78,6 @@ void CuDNNPoolOp<Context>::DoRunWithType() {
}
template <class Context>
void CuDNNPoolOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void CuDNNPoolGradientOp<Context>::DoRunWithType() {
ComputeOutShape();
......@@ -56,11 +103,6 @@ void CuDNNPoolGradientOp<Context>::DoRunWithType() {
dX->ReshapeLike(X)->template mutable_data<T, Context>()));
}
template <class Context>
void CuDNNPoolGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CUDNN_OPERATOR(Pool);
DEPLOY_CUDNN_OPERATOR(PoolGradient);
......
......@@ -60,12 +60,19 @@ void SpaceToDepthOp<Context>::DoRunWithType() {
CHECK_EQ(X_reshape.count(), X.count())
<< "\nCould not rearrange " << X.DimString() << " to "
<< X_reshape.DimString() << " with block size " << block_size_ << ".";
vec64_t X_strides(in_dims.size());
vec64_t Y_dims(in_dims.size());
for (int i = 0; i < X_reshape.ndim(); i++) {
X_strides[i] = X_reshape.stride(perm[i]);
Y_dims[i] = X_reshape.dim(perm[i]);
vec64_t transpose_dims, transpose_axes;
math::utils::CollapseTransposeAxes(
X_reshape.ndim(),
X_reshape.dims().data(),
perm.data(),
transpose_dims,
transpose_axes);
Tensor X_collapse(transpose_dims);
num_dims = X_collapse.ndim();
vec64_t X_strides(num_dims), Y_dims(num_dims);
for (int i = 0; i < num_dims; ++i) {
X_strides[i] = X_collapse.stride(transpose_axes[i]);
Y_dims[i] = X_collapse.dim(transpose_axes[i]);
}
auto* scratch = ((void*)&X == (void*)Y)
......@@ -73,7 +80,7 @@ void SpaceToDepthOp<Context>::DoRunWithType() {
: Y->Reshape(out_shape)->template mutable_data<T, Context>();
kernels::Transpose(
X_strides.size(),
num_dims,
X_strides.data(),
Y_dims.data(),
X.template data<T, Context>(),
......@@ -135,12 +142,19 @@ void DepthToSpaceOp<Context>::DoRunWithType() {
CHECK_EQ(X_reshape.count(), X.count())
<< "\nCould not rearrange " << X.DimString() << " to "
<< X_reshape.DimString() << " with block size " << block_size_ << ".";
vec64_t X_strides(in_dims.size());
vec64_t Y_dims(in_dims.size());
for (int i = 0; i < in_dims.size(); i++) {
X_strides[i] = X_reshape.stride(perm[i]);
Y_dims[i] = X_reshape.dim(perm[i]);
vec64_t transpose_dims, transpose_axes;
math::utils::CollapseTransposeAxes(
X_reshape.ndim(),
X_reshape.dims().data(),
perm.data(),
transpose_dims,
transpose_axes);
Tensor X_collapse(transpose_dims);
num_dims = X_collapse.ndim();
vec64_t X_strides(num_dims), Y_dims(num_dims);
for (int i = 0; i < num_dims; ++i) {
X_strides[i] = X_collapse.stride(transpose_axes[i]);
Y_dims[i] = X_collapse.dim(transpose_axes[i]);
}
auto* scratch = ((void*)&X == (void*)Y)
......@@ -148,7 +162,7 @@ void DepthToSpaceOp<Context>::DoRunWithType() {
: Y->Reshape(out_shape)->template mutable_data<T, Context>();
kernels::Transpose(
X_strides.size(),
num_dims,
X_strides.data(),
Y_dims.data(),
X.template data<T, Context>(),
......
......@@ -158,6 +158,129 @@ class CUDADeviceGuard {
int prev_id_;
};
#define DISPATCH_FUNC_BY_VALUE_WITH_TYPE_1(Func, T, val, ...) \
do { \
switch (val) { \
case 1: { \
Func<T, 1>(__VA_ARGS__); \
break; \
} \
case 2: { \
Func<T, 2>(__VA_ARGS__); \
break; \
} \
case 3: { \
Func<T, 3>(__VA_ARGS__); \
break; \
} \
case 4: { \
Func<T, 4>(__VA_ARGS__); \
break; \
} \
case 5: { \
Func<T, 5>(__VA_ARGS__); \
break; \
} \
case 6: { \
Func<T, 6>(__VA_ARGS__); \
break; \
} \
case 7: { \
Func<T, 7>(__VA_ARGS__); \
break; \
} \
case 8: { \
Func<T, 8>(__VA_ARGS__); \
break; \
} \
default: { \
break; \
} \
} \
} while (false)
#define DISPATCH_FUNC_BY_VALUE_WITH_TYPE_2(Func, T1, T2, val, ...) \
do { \
switch (val) { \
case 1: { \
Func<T1, T2, 1>(__VA_ARGS__); \
break; \
} \
case 2: { \
Func<T1, T2, 2>(__VA_ARGS__); \
break; \
} \
case 3: { \
Func<T1, T2, 3>(__VA_ARGS__); \
break; \
} \
case 4: { \
Func<T1, T2, 4>(__VA_ARGS__); \
break; \
} \
case 5: { \
Func<T1, T2, 5>(__VA_ARGS__); \
break; \
} \
case 6: { \
Func<T1, T2, 6>(__VA_ARGS__); \
break; \
} \
case 7: { \
Func<T1, T2, 7>(__VA_ARGS__); \
break; \
} \
case 8: { \
Func<T1, T2, 8>(__VA_ARGS__); \
break; \
} \
default: { \
break; \
} \
} \
} while (false)
#define DISPATCH_FUNC_BY_VALUE_WITH_TYPE_3(Func, T1, T2, T3, val, ...) \
do { \
switch (val) { \
case 1: { \
Func<T1, T2, T3, 1>(__VA_ARGS__); \
break; \
} \
case 2: { \
Func<T1, T2, T3, 2>(__VA_ARGS__); \
break; \
} \
case 3: { \
Func<T1, T2, T3, 3>(__VA_ARGS__); \
break; \
} \
case 4: { \
Func<T1, T2, T3, 4>(__VA_ARGS__); \
break; \
} \
case 5: { \
Func<T1, T2, T3, 5>(__VA_ARGS__); \
break; \
} \
case 6: { \
Func<T1, T2, T3, 6>(__VA_ARGS__); \
break; \
} \
case 7: { \
Func<T1, T2, T3, 7>(__VA_ARGS__); \
break; \
} \
case 8: { \
Func<T1, T2, T3, 8>(__VA_ARGS__); \
break; \
} \
default: { \
break; \
} \
} \
} while (false)
#else
#define CUDA_NOT_COMPILED LOG(FATAL) << "CUDA library is not compiled with."
......
......@@ -62,7 +62,6 @@ template <typename T, typename AccT, class Reducer, int D>
__global__ void _GenericReduce(
const int rows,
const int cols,
const int num_dims,
const SimpleArray<int, D> x_dims,
const SimpleArray<int, D> x_strides,
const Reducer reducer,
......@@ -75,7 +74,8 @@ __global__ void _GenericReduce(
AccT val = init;
CUDA_2D_KERNEL_LOOP2(j, cols) {
int xi = 0, c = i * cols + j;
for (int d = num_dims - 1; d >= 0; --d) {
#pragma unroll
for (int d = D - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(x_dims.data[d], c, &c, &r);
xi += r * x_strides.data[d];
......@@ -89,66 +89,92 @@ __global__ void _GenericReduce(
}
}
#define DEFINE_REDUCE_DISPATCHER(name) \
template <typename T, typename AccT, typename Reducer> \
void _Reduce##name( \
const int num_dims, \
const int* dims, \
const int num_axes, \
const int* axes, \
const Reducer reducer, \
const AccT init, \
const AccT scale, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
int rows, cols; \
vec32_t out_dims(dims, dims + num_dims); \
for (int i = 0; i < num_axes; ++i) { \
out_dims[axes[i]] = 1; \
} \
if (math::utils::IsRowwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_RowwiseReduce<<<cols, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, cols, reducer, init, scale, x, y); \
return; \
} \
if (math::utils::IsColwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_ColwiseReduce<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, cols, reducer, init, scale, x, y); \
return; \
} \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_axes; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_strides; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> transpose_dims; \
math::utils::TransposeAxesForReduce( \
num_dims, num_axes, axes, transpose_axes.data); \
math::utils::ComputeTransposeStrides( \
num_dims, dims, transpose_axes.data, transpose_strides.data); \
rows = cols = 1; \
const int pivot = num_dims - num_axes; \
for (int i = 0; i < pivot; ++i) { \
rows *= dims[transpose_axes.data[i]]; \
} \
for (int i = pivot; i < num_dims; ++i) { \
cols *= dims[transpose_axes.data[i]]; \
} \
for (int i = 0; i < num_dims; ++i) { \
transpose_dims.data[i] = dims[transpose_axes.data[i]]; \
} \
_GenericReduce<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, \
cols, \
num_dims, \
transpose_dims, \
transpose_strides, \
reducer, \
init, \
scale, \
x, \
y); \
template <typename T, typename AccT, class Reducer, int D>
void _GenericReduceImpl(
const int* dims,
const int num_axes,
const int* axes,
const Reducer reducer,
const AccT init,
const AccT scale,
const T* x,
T* y,
CUDAContext* ctx) {
SimpleArray<int, D> transpose_axes;
SimpleArray<int, D> transpose_strides;
SimpleArray<int, D> transpose_dims;
math::utils::TransposeAxesForReduce(D, num_axes, axes, transpose_axes.data);
math::utils::ComputeTransposeStrides(
D, dims, transpose_axes.data, transpose_strides.data);
int rows = 1, cols = 1;
const int pivot = D - num_axes;
for (int i = 0; i < pivot; ++i) {
rows *= dims[transpose_axes.data[i]];
}
for (int i = pivot; i < D; ++i) {
cols *= dims[transpose_axes.data[i]];
}
for (int i = 0; i < D; ++i) {
transpose_dims.data[i] = dims[transpose_axes.data[i]];
}
_GenericReduce<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
rows,
cols,
transpose_dims,
transpose_strides,
reducer,
init,
scale,
x,
y);
}
#define DEFINE_REDUCE_DISPATCHER(name) \
template <typename T, typename AccT, typename Reducer> \
void _Reduce##name( \
const int num_dims, \
const int* dims, \
const int num_axes, \
const int* axes, \
const Reducer reducer, \
const AccT init, \
const AccT scale, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
int rows, cols; \
vec32_t out_dims(dims, dims + num_dims); \
for (int i = 0; i < num_axes; ++i) { \
out_dims[axes[i]] = 1; \
} \
if (math::utils::IsRowwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_RowwiseReduce<<<cols, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, cols, reducer, init, scale, x, y); \
return; \
} \
if (math::utils::IsColwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_ColwiseReduce<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, cols, reducer, init, scale, x, y); \
return; \
} \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
DISPATCH_FUNC_BY_VALUE_WITH_TYPE_3( \
_GenericReduceImpl, \
T, \
AccT, \
Reducer, \
num_dims, \
dims, \
num_axes, \
axes, \
reducer, \
init, \
scale, \
x, \
y, \
ctx); \
}
DEFINE_REDUCE_DISPATCHER(Max);
......
......@@ -311,14 +311,41 @@ inline void ComputeTransposeStrides(
}
}
template <typename DimT, typename AxisT>
inline void CollapseTransposeAxes(
const int num_dims,
const DimT* dims,
const AxisT* axes,
vector<DimT>& new_dims,
vector<AxisT>& new_axes) {
new_dims = vector<DimT>(dims, dims + num_dims);
new_axes = vector<AxisT>({axes[0]});
vector<AxisT> collapse_axes;
for (int i = 1; i < num_dims; ++i) {
if (axes[i] - 1 == axes[i - 1]) {
collapse_axes.push_back(axes[i]);
new_dims[axes[i]] *= new_dims[axes[i] - 1];
new_dims[axes[i] - 1] = -1;
} else {
new_axes.push_back(axes[i]);
}
}
const auto& erase_iter = std::remove_if(
new_dims.begin(), new_dims.end(), [](int x) { return x == -1; });
new_dims.erase(erase_iter, new_dims.end());
for (int i = 0; i < new_axes.size(); ++i) {
for (auto collapse_axis : collapse_axes) {
if (new_axes[i] > collapse_axis) new_axes[i]--;
}
}
}
template <typename DimT, typename IndexT>
inline IndexT
GetIndexFromDims(const int num_dims, const DimT* dims, IndexT* index) {
IndexT ret = 0;
for (int i = 0; i < num_dims; ++i) {
if (dims[i] > 1) {
ret = ret * dims[i] + index[i];
}
if (dims[i] > 1) ret = ret * dims[i] + index[i];
}
return ret;
}
......
......@@ -267,7 +267,7 @@ def uniform_(tensor, a=0, b=1):
----------
tensor : dragon.vm.torch.Tensor
The input tensor.
a : number, optional, default=-1
a : number, optional, default=0
The value to :math:`\alpha`.
b : number, optional, default=1
The value to :math:`\beta`.
......
......@@ -390,7 +390,7 @@ class MultiheadAttention(Module):
self.in_proj_bias = Parameter(Tensor(3 * embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.out_proj = Linear(embed_dim, embed_dim, bias=True)
self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
self.reset_parameters()
def reset_parameters(self):
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!