Commit 5cbbef4b by Ting PAN

Use block reduction for ArgMax and ArgMin Operator

Summary:
This commit reimplements the CUDA ArgMax/ArgMin kernels via BlockReduce,
replacing the naive per-output reduction in the kernel loop.
1 parent b4019faa
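Before the diff, a quick illustration of the pattern the commit adopts. The following is a minimal standalone sketch (not the commit's code; the kernel name, the float-only element type, and the row-major (rows, cols) layout are assumptions): one thread block reduces one row, and each thread first folds a strided slice of the row into a single (index, value) candidate before the block-wide CUB reduction.

#include <cfloat>
#include <cub/cub.cuh>

// One thread block computes the argmax of one row of a (rows, cols) matrix.
template <int kBlockThreads>
__global__ void RowArgMax(const float* x, const int cols, int64_t* y) {
  typedef cub::KeyValuePair<int64_t, float> KV;
  typedef cub::BlockReduce<KV, kBlockThreads> BlockReduce;
  __shared__ typename BlockReduce::TempStorage storage;
  const float* row = x + blockIdx.x * cols;
  // Fold this thread's strided slice into a single (index, value) candidate.
  KV best(-1, -FLT_MAX);
  for (int j = threadIdx.x; j < cols; j += kBlockThreads) {
    if (row[j] > best.value) best = KV(j, row[j]);
  }
  // Block-wide argmax; cub::ArgMax breaks ties toward the smaller index.
  const KV out = BlockReduce(storage).Reduce(best, cub::ArgMax());
  if (threadIdx.x == 0) y[blockIdx.x] = out.key;
}

// Launch with one block per row, e.g.:
//   RowArgMax<256><<<rows, 256, 0, stream>>>(x, cols, y);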
 #ifdef USE_CUDA
 #include "dragon/core/context_cuda.h"
+#include "dragon/utils/device/common_cub.h"
+#include "dragon/utils/math/functional.h"
 #include "dragon/utils/op_kernels.h"
 namespace dragon {
@@ -10,98 +12,58 @@ namespace kernel {
 namespace {
 
-template <typename T>
-__global__ void _ArgMax(
-    const int nthreads,
-    const int inner_dim,
-    const int axis_dim,
-    const T* x,
-    int64_t* y) {
-  CUDA_1D_KERNEL_LOOP(yi, nthreads) {
-    const int i = yi / inner_dim;
-    const int j = yi % inner_dim;
-    const T* offset_x = x + (i * axis_dim * inner_dim + j);
-    auto max_val = offset_x[0];
-    auto max_idx = int64_t(0);
-    for (int k = 1; k < axis_dim; ++k) {
-      const T val = offset_x[k * inner_dim];
-      if (val > max_val) {
-        max_val = val;
-        max_idx = k;
-      }
-    }
-    y[yi] = max_idx;
-  }
-}
-
-template <>
-__global__ void _ArgMax<half>(
-    const int nthreads,
-    const int inner_dim,
-    const int axis_dim,
-    const half* x,
-    int64_t* y) {
-  CUDA_1D_KERNEL_LOOP(yi, nthreads) {
-    const int i = yi / inner_dim;
-    const int j = yi % inner_dim;
-    const half* offset_x = x + (i * axis_dim * inner_dim + j);
-    auto max_val = __half2float(offset_x[0]);
-    auto max_idx = int64_t(0);
-    for (int k = 1; k < axis_dim; ++k) {
-      const float val = __half2float(offset_x[k * inner_dim]);
-      if (val > max_val) {
-        max_val = val;
-        max_idx = k;
-      }
-    }
-    y[yi] = max_idx;
-  }
-}
-
-template <typename T>
-__global__ void _ArgMin(
-    const int nthreads,
-    const int inner_dim,
-    const int axis_dim,
-    const T* x,
-    int64_t* y) {
-  CUDA_1D_KERNEL_LOOP(yi, nthreads) {
-    const int i = yi / inner_dim;
-    const int j = yi % inner_dim;
-    const T* offset_x = x + (i * axis_dim * inner_dim + j);
-    auto min_val = offset_x[0];
-    auto min_idx = int64_t(0);
-    for (int k = 1; k < axis_dim; ++k) {
-      const T val = offset_x[k * inner_dim];
-      if (val < min_val) {
-        min_val = val;
-        min_idx = k;
-      }
-    }
-    y[yi] = min_idx;
-  }
-}
-
-template <>
-__global__ void _ArgMin<half>(
-    const int nthreads,
-    const int inner_dim,
-    const int axis_dim,
-    const half* x,
-    int64_t* y) {
-  CUDA_1D_KERNEL_LOOP(yi, nthreads) {
-    const int i = yi / inner_dim;
-    const int j = yi % inner_dim;
-    const half* offset_x = x + (i * axis_dim * inner_dim + j);
-    auto min_val = __half2float(offset_x[0]);
-    auto min_idx = int64_t(0);
-    for (int k = 1; k < axis_dim; ++k) {
-      const float val = __half2float(offset_x[k * inner_dim]);
-      if (val < min_val) {
-        min_val = val;
-        min_idx = k;
-      }
-    }
-    y[yi] = min_idx;
-  }
-}
+template <typename T>
+struct ArgMaxFunctor {
+  inline __device__ cub::KeyValuePair<int64_t, T> operator()(
+      const cub::KeyValuePair<int64_t, T>& lhs,
+      const cub::KeyValuePair<int64_t, T>& rhs) const {
+    if ((greater_(rhs.value, lhs.value)) ||
+        (equal_(lhs.value, rhs.value) && (rhs.key < lhs.key))) {
+      return rhs;
+    }
+    return lhs;
+  }
+  math::GreaterFunctor<T> greater_;
+  math::EqualFunctor<T> equal_;
+};
+
+template <typename T>
+struct ArgMinFunctor {
+  inline __device__ cub::KeyValuePair<int64_t, T> operator()(
+      const cub::KeyValuePair<int64_t, T>& lhs,
+      const cub::KeyValuePair<int64_t, T>& rhs) const {
+    if ((less_(rhs.value, lhs.value)) ||
+        (equal_(lhs.value, rhs.value) && (rhs.key < lhs.key))) {
+      return rhs;
+    }
+    return lhs;
+  }
+  math::LessFunctor<T> less_;
+  math::EqualFunctor<T> equal_;
+};
+
+template <typename T, class Reducer>
+__global__ void _ArgReduce(
+    const int rows,
+    const int cols,
+    const int inner_dim,
+    const Reducer reducer,
+    const T init,
+    const T* x,
+    int64_t* y) {
+  typedef cub::KeyValuePair<int64_t, T> KeyValuePair;
+  __shared__ typename BlockReduce<KeyValuePair>::TempStorage storage;
+  CUDA_2D_KERNEL_LOOP1(i, rows) {
+    auto key_val = KeyValuePair(-1, init);
+    CUDA_2D_KERNEL_LOOP2(j, cols) {
+      key_val = reducer(
+          key_val,
+          KeyValuePair(
+              j, x[((i / inner_dim) * cols + j) * inner_dim + i % inner_dim]));
+    }
+    key_val = BlockReduce<KeyValuePair>(storage).Reduce(key_val, reducer);
+    if (threadIdx.x == 0) {
+      y[i] = key_val.key;
+    }
+  }
+}
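The index arithmetic above fuses the outer and inner dimensions into a single row index i = outer * inner_dim + inner, with cols equal to axis_dim. A host-side reference of the same reduction (an illustrative sketch for checking the offsets, not part of the commit) reads:

#include <cstdint>
#include <vector>

std::vector<int64_t> ArgMaxRef(
    const std::vector<float>& x,
    const int outer_dim,
    const int axis_dim,
    const int inner_dim) {
  std::vector<int64_t> y(outer_dim * inner_dim);
  for (int i = 0; i < outer_dim * inner_dim; ++i) {
    // Element j of logical row i, using the kernel's flattened offset.
    const auto at = [&](int j) {
      return x[((i / inner_dim) * axis_dim + j) * inner_dim + i % inner_dim];
    };
    int64_t best = 0;
    for (int j = 1; j < axis_dim; ++j) {
      if (at(j) > at(best)) best = j;  // strict '>' keeps the smallest index on ties
    }
    y[i] = best;
  }
  return y;
}

The strict comparison matches ArgMaxFunctor's tie-break: rhs wins only with a strictly greater value, or an equal value and a smaller index.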
@@ -109,34 +71,111 @@ __global__ void _ArgMin<half>(
 /* ------------------- Launcher Separator ------------------- */
-#define DEFINE_KERNEL_LAUNCHER(name, T1, T2)                                 \
+#define DEFINE_KERNEL_LAUNCHER(name, T1, T2, Reducer, kInit)                 \
   template <>                                                                \
   void name<T1, CUDAContext>(                                                \
       const int outer_dim,                                                   \
       const int inner_dim,                                                   \
       const int axis_dim,                                                    \
       const T1* x,                                                           \
       int64_t* y,                                                            \
       CUDAContext* ctx) {                                                    \
-    auto nthreads = outer_dim * inner_dim;                                   \
-    _##name<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
-        nthreads, inner_dim, axis_dim, reinterpret_cast<const T2*>(x), y);   \
+    const auto rows = outer_dim * inner_dim;                                 \
+    const auto cols = axis_dim;                                              \
+    _ArgReduce<<<CUDA_2D_BLOCKS(rows), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
+        rows,                                                                \
+        cols,                                                                \
+        inner_dim,                                                           \
+        Reducer<T2>(),                                                       \
+        kInit,                                                               \
+        reinterpret_cast<const T2*>(x),                                      \
+        y);                                                                  \
   }
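For concreteness, the float ArgMax instantiation below expands to roughly the following (the reinterpret_cast is elided here since T1 and T2 coincide):

template <>
void ArgMax<float, CUDAContext>(
    const int outer_dim,
    const int inner_dim,
    const int axis_dim,
    const float* x,
    int64_t* y,
    CUDAContext* ctx) {
  const auto rows = outer_dim * inner_dim;
  const auto cols = axis_dim;
  _ArgReduce<<<CUDA_2D_BLOCKS(rows), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
      rows,
      cols,
      inner_dim,
      ArgMaxFunctor<float>(),
      std::numeric_limits<float>::lowest(),
      x,
      y);
}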
-DEFINE_KERNEL_LAUNCHER(ArgMax, int8_t, int8_t);
-DEFINE_KERNEL_LAUNCHER(ArgMax, uint8_t, uint8_t);
-DEFINE_KERNEL_LAUNCHER(ArgMax, int, int);
-DEFINE_KERNEL_LAUNCHER(ArgMax, int64_t, int64_t);
-DEFINE_KERNEL_LAUNCHER(ArgMax, float16, half);
-DEFINE_KERNEL_LAUNCHER(ArgMax, float, float);
-DEFINE_KERNEL_LAUNCHER(ArgMax, double, double);
-DEFINE_KERNEL_LAUNCHER(ArgMin, int8_t, int8_t);
-DEFINE_KERNEL_LAUNCHER(ArgMin, uint8_t, uint8_t);
-DEFINE_KERNEL_LAUNCHER(ArgMin, int, int);
-DEFINE_KERNEL_LAUNCHER(ArgMin, int64_t, int64_t);
-DEFINE_KERNEL_LAUNCHER(ArgMin, float16, half);
-DEFINE_KERNEL_LAUNCHER(ArgMin, float, float);
-DEFINE_KERNEL_LAUNCHER(ArgMin, double, double);
+DEFINE_KERNEL_LAUNCHER(
+    ArgMax,
+    int8_t,
+    int8_t,
+    ArgMaxFunctor,
+    std::numeric_limits<int8_t>::lowest());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMax,
+    uint8_t,
+    uint8_t,
+    ArgMaxFunctor,
+    std::numeric_limits<uint8_t>::lowest());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMax,
+    int,
+    int,
+    ArgMaxFunctor,
+    std::numeric_limits<int>::lowest());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMax,
+    int64_t,
+    int64_t,
+    ArgMaxFunctor,
+    std::numeric_limits<int64_t>::lowest());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMax,
+    float16,
+    half,
+    ArgMaxFunctor,
+    cub::Traits<half>::Lowest());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMax,
+    float,
+    float,
+    ArgMaxFunctor,
+    std::numeric_limits<float>::lowest());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMax,
+    double,
+    double,
+    ArgMaxFunctor,
+    std::numeric_limits<double>::lowest());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMin,
+    int8_t,
+    int8_t,
+    ArgMinFunctor,
+    std::numeric_limits<int8_t>::max());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMin,
+    uint8_t,
+    uint8_t,
+    ArgMinFunctor,
+    std::numeric_limits<uint8_t>::max());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMin,
+    int,
+    int,
+    ArgMinFunctor,
+    std::numeric_limits<int>::max());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMin,
+    int64_t,
+    int64_t,
+    ArgMinFunctor,
+    std::numeric_limits<int64_t>::max());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMin,
+    float16,
+    half,
+    ArgMinFunctor,
+    cub::Traits<half>::Max());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMin,
+    float,
+    float,
+    ArgMinFunctor,
+    std::numeric_limits<float>::max());
+DEFINE_KERNEL_LAUNCHER(
+    ArgMin,
+    double,
+    double,
+    ArgMinFunctor,
+    std::numeric_limits<double>::max());
 #undef DEFINE_KERNEL_LAUNCHER
 } // namespace kernel
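A hypothetical call site (a sketch; the tensor shape is an assumption): taking ArgMax over axis 1 of a contiguous (N, C, H, W) tensor passes outer_dim = N, inner_dim = H * W, and axis_dim = C, producing N * H * W indices in [0, C):

kernel::ArgMax<float, CUDAContext>(
    N,      // outer_dim
    H * W,  // inner_dim
    C,      // axis_dim
    x,
    y,
    ctx);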
...
@@ -44,6 +44,7 @@ void SyncBatchNormOp<Context>::TrainingImpl() {
       ctx());
   // Compute D(X) = E(X^2) - E(X)^2
+  ctx()->FinishDeviceComputation();
   if (enable_nccl_) {
 #ifdef USE_NCCL
     auto nccl_comm_ = this->nccl_comm();
@@ -138,6 +139,7 @@ void SyncBatchNormGradientOp<Context>::TrainingImpl() {
       N_, C_, S_, data_format(), x, mu, rsig, gamma, dy, dgamma, dbeta, ctx());
   // Gradient w.r.t. gamma and beta of global batch
+  ctx()->FinishDeviceComputation();
   if (enable_nccl_) {
 #ifdef USE_NCCL
     auto nccl_comm_ = this->nccl_comm();
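The two added FinishDeviceComputation() calls appear to make the host wait for the asynchronous statistics/gradient kernels before the NCCL collectives that follow, so each rank contributes finished local sums to the all-reduce. Roughly (a sketch; the helper name is hypothetical, ncclAllReduce as in NCCL 2.x):

ComputeLocalMoments(x, mu, rsig, ctx);  // async kernels on the context's stream
ctx->FinishDeviceComputation();         // host blocks until they complete
// mu now holds the finished local sums that the collective reads.
ncclAllReduce(mu, mu, C, ncclFloat, ncclSum, nccl_comm, stream);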
...
@@ -709,7 +709,7 @@ class TestArrayOps(OpTestCase):
         self.assertEqual(x.shape, (4,))
 
     @unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
-    def test_range_cuda(self):
+    def test_permutation_cuda(self):
         with dragon.device('cuda'):
             self.test_permutation()
...