Commit d56e67d1 by Ting PAN

Use FP32 accumulator for FP16 ReduceSum

Summary:
This commit adds a fallback with FP32 accumulator
for FP16 ReduceSum to avoid dropping too many small values.
Besides, FP16 kernels for arch < 530 are almost available.
1 parent 9ca4b60f
Showing with 683 additions and 782 deletions
......@@ -47,7 +47,7 @@ class Dropout(Layer):
param = layer_param.dropout_param
if not param.scale_train:
raise ValueError('Unscaled dropout is not supported.')
self.arguments = {'prob': param.dropout_ratio}
self.arguments = {'ratio': param.dropout_ratio}
def __call__(self, bottom):
return activation_ops.dropout(bottom, **self.arguments)
......
......@@ -22,27 +22,28 @@ __global__ void _DropPath(
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = x[i] * (T)(__ldg(mask + (i / cols)) > thresh) * scale;
y[i] = x[i] * T(__ldg(mask + (i / cols)) > thresh) * scale;
#else
y[i] = x[i] * (T)(mask[i / cols] > thresh) * scale;
y[i] = x[i] * T(mask[i / cols] > thresh) * scale;
#endif
}
}
template <>
__global__ void _DropPath<half>(
__global__ void _DropPath(
const int nthreads,
const int cols,
const float thresh,
const half scale,
const float scale,
const half* x,
const float* mask,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(
__hmul(x[i], scale),
__float2half((float)(__ldg(mask + (i / cols)) > thresh)));
#if __CUDA_ARCH__ >= 350
y[i] = __float2half(
__half2float(x[i]) * float(__ldg(mask + (i / cols)) > thresh) * scale);
#else
y[i] = __float2half(
__half2float(x[i]) * float(mask[i / cols] > thresh) * scale);
#endif
}
}
......@@ -60,13 +61,13 @@ void DropPath<float16, CUDAContext>(
const float* mask,
float16* y,
CUDAContext* ctx) {
auto nthreads = rows * cols;
auto thresh = 1.f - (1.f / scale);
const auto nthreads = rows * cols;
const auto thresh = 1.f - (1.f / scale);
_DropPath<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
nthreads,
cols,
thresh,
cast::to<half>(scale),
scale,
reinterpret_cast<const half*>(x),
mask,
reinterpret_cast<half*>(y));
......@@ -82,8 +83,8 @@ void DropPath<float16, CUDAContext>(
const float* mask, \
T* y, \
CUDAContext* ctx) { \
auto nthreads = rows * cols; \
auto thresh = 1.f - (1.f / scale); \
const auto nthreads = rows * cols; \
const auto thresh = 1.f - (1.f / scale); \
_DropPath<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, thresh, cast::to<T>(scale), x, mask, y); \
}
......
......@@ -37,20 +37,20 @@ void _ApplyMask<float16>(
template <typename T>
void _Dropout(
const int count,
const T prob,
const T ratio,
const T scale,
const T* x,
uint8_t* mask,
T* y,
CPUContext* ctx) {
math::RandomBernoulli(count, T(1) - prob, mask, ctx);
math::RandomBernoulli(count, T(1) - ratio, mask, ctx);
_ApplyMask(count, scale, x, mask, y);
}
template <>
void _Dropout<float16>(
const int count,
const float16 prob,
const float16 ratio,
const float16 scale,
const float16* x,
uint8_t* mask,
......@@ -63,28 +63,28 @@ void _Dropout<float16>(
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ApplyMask<T, CPUContext>( \
const int count, \
const float scale, \
const T* x, \
const uint8_t* mask, \
T* y, \
CPUContext* ctx) { \
_ApplyMask(count, cast::to<T>(scale), x, mask, y); \
} \
template <> \
void Dropout<T, CPUContext>( \
const int count, \
const float prob, \
const float scale, \
const T* x, \
uint8_t* mask, \
T* y, \
uint32_t* r, \
CPUContext* ctx) { \
_Dropout(count, cast::to<T>(prob), cast::to<T>(scale), x, mask, y, ctx); \
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void ApplyMask<T, CPUContext>( \
const int count, \
const float scale, \
const T* x, \
const uint8_t* mask, \
T* y, \
CPUContext* ctx) { \
_ApplyMask(count, cast::to<T>(scale), x, mask, y); \
} \
template <> \
void Dropout<T, CPUContext>( \
const int count, \
const float ratio, \
const float scale, \
const T* x, \
uint8_t* mask, \
T* y, \
uint32_t* r, \
CPUContext* ctx) { \
_Dropout(count, cast::to<T>(ratio), cast::to<T>(scale), x, mask, y, ctx); \
}
DEFINE_KERNEL_LAUNCHER(float16);
......
......@@ -23,17 +23,14 @@ __global__ void _ApplyMask(
}
}
template <>
__global__ void _ApplyMask<half>(
__global__ void _ApplyMask(
const int nthreads,
const half scale,
const float scale,
const half* x,
const uint8_t* mask,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(__hmul(x[i], scale), __float2half((float)mask[i]));
#endif
y[i] = __float2half(__half2float(x[i]) * (float)mask[i] * scale);
}
}
......@@ -47,25 +44,21 @@ __global__ void _Dropout(
uint8_t* mask,
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = x[i] * (T)(mask[i] = (r[i] > threshold)) * scale;
y[i] = x[i] * T(mask[i] = (r[i] > threshold)) * scale;
}
}
template <>
__global__ void _Dropout<half>(
__global__ void _Dropout(
const int nthreads,
const uint32_t threshold,
const half scale,
const float scale,
const half* x,
const uint32_t* r,
uint8_t* mask,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(
__hmul(x[i], scale),
__float2half((float)(mask[i] = (r[i] > threshold))));
#endif
y[i] = __float2half(
__half2float(x[i]) * float(mask[i] = (r[i] > threshold)) * scale);
}
}
......@@ -83,7 +76,7 @@ void ApplyMask<float16, CUDAContext>(
CUDAContext* ctx) {
_ApplyMask<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
cast::to<half>(scale),
scale,
reinterpret_cast<const half*>(x),
mask,
reinterpret_cast<half*>(y));
......@@ -92,7 +85,7 @@ void ApplyMask<float16, CUDAContext>(
template <>
void Dropout<float16, CUDAContext>(
const int count,
const float prob,
const float ratio,
const float scale,
const float16* x,
uint8_t* mask,
......@@ -102,8 +95,8 @@ void Dropout<float16, CUDAContext>(
math::Random(count, r, ctx);
_Dropout<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
static_cast<uint32_t>(UINT_MAX * prob),
cast::to<half>(scale),
static_cast<uint32_t>(UINT_MAX * ratio),
scale,
reinterpret_cast<const half*>(x),
r,
mask,
......@@ -125,7 +118,7 @@ void Dropout<float16, CUDAContext>(
template <> \
void Dropout<T, CUDAContext>( \
const int count, \
const float prob, \
const float ratio, \
const float scale, \
const T* x, \
uint8_t* mask, \
......@@ -133,7 +126,7 @@ void Dropout<float16, CUDAContext>(
uint32_t* r, \
CUDAContext* ctx) { \
math::Random(count, r, ctx); \
auto threshold = static_cast<uint32_t>(UINT_MAX * prob); \
auto threshold = static_cast<uint32_t>(UINT_MAX * ratio); \
_Dropout<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, threshold, cast::to<T>(scale), x, r, mask, y); \
}
......
......@@ -24,11 +24,15 @@ __global__ void _PRelu(const int nthreads, const T* x, const T* w, T* y) {
template <>
__global__ void
_PRelu<half>(const int nthreads, const half* x, const half* w, half* y) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hgt(__ldg(x + i), kZero) ? __ldg(x + i)
: __hmul(__ldg(x + i), __ldg(w));
#if __CUDA_ARCH__ >= 350
y[i] = __half2float(__ldg(x + i)) > 0.f
? __ldg(x + i)
: __float2half(__half2float(__ldg(x + i)) * __half2float(__ldg(w)));
#else
y[i] = __half2float(x[i]) > 0.f
? x[i]
: __float2half(__half2float(x[i]) * __half2float(w[0]));
#endif
}
}
......@@ -59,12 +63,17 @@ __global__ void _PReluNCHW<half>(
const half* x,
const half* w,
half* y) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hgt(__ldg(x + i), kZero)
#if __CUDA_ARCH__ >= 350
y[i] = __half2float(__ldg(x + i)) > 0.f
? __ldg(x + i)
: __hmul(__ldg(x + i), __ldg(w + ((i / S) % C)));
: __float2half(
__half2float(__ldg(x + i)) *
__half2float(__ldg(w + ((i / S) % C))));
#else
y[i] = __half2float(x[i]) > 0.f
? x[i]
: __float2half(__half2float(x[i]) * __half2float(w[(i / S) % C]));
#endif
}
}
......@@ -89,12 +98,16 @@ __global__ void _PReluNHWC<half>(
const half* x,
const half* w,
half* y) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hgt(__ldg(x + i), kZero)
#if __CUDA_ARCH__ >= 350
y[i] = __half2float(__ldg(x + i)) > 0.f
? __ldg(x + i)
: __hmul(__ldg(x + i), __ldg(w + (i % C)));
: __float2half(
__half2float(__ldg(x + i)) * __half2float(__ldg(w + (i % C))));
#else
y[i] = __half2float(x[i]) > 0.f
? x[i]
: __float2half(__half2float(x[i]) * __half2float(w[i % C]));
#endif
}
}
......@@ -118,11 +131,15 @@ __global__ void _PReluGrad<half>(
const half* x,
const half* w,
half* dx) {
const half kOne = __float2half(1.f);
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul(dy[i], (__hgt(__ldg(x + i), kZero) ? kOne : __ldg(w)));
#if __CUDA_ARCH__ >= 350
dx[i] = __float2half(
__half2float(dy[i]) *
(__half2float(x[i]) > 0.f ? 1.f : __half2float(__ldg(w))));
#else
dx[i] = __float2half(
__half2float(dy[i]) *
(__half2float(x[i]) > 0.f ? 1.f : __half2float(w[0])));
#endif
}
}
......@@ -154,12 +171,16 @@ __global__ void _PReluGradNCHW<half>(
const half* x,
const half* w,
half* dx) {
const half kOne = __float2half(1.f);
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] =
dy[i] * (__hgt(__ldg(x + i), kZero) ? kOne : __ldg(w + ((i / S) % C)));
#if __CUDA_ARCH__ >= 350
dx[i] = __float2half(
__half2float(dy[i]) *
(__half2float(x[i]) > 0.f ? 1.f
: __half2float(__ldg(w + ((i / S) % C)))));
#else
dx[i] = __float2half(
__half2float(dy[i]) *
(__half2float(x[i]) > 0.f ? 1.f : __half2float(w[(i / S) % C])));
#endif
}
}
......@@ -189,11 +210,15 @@ __global__ void _PReluGradNHWC<half>(
const half* x,
const half* w,
half* dx) {
const half kOne = __float2half(1.f);
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = dy[i] * (__hgt(__ldg(x + i), kZero) ? kOne : __ldg(w + (i % C)));
#if __CUDA_ARCH__ >= 350
dx[i] = __float2half(
__half2float(dy[i]) *
(__half2float(x[i]) > 0.f ? 1.f : __half2float(__ldg(w + (i % C)))));
#else
dx[i] = __float2half(
__half2float(dy[i]) *
(__half2float(x[i]) > 0.f ? 1.f : __half2float(w[i % C])));
#endif
}
}
......@@ -201,12 +226,14 @@ __global__ void _PReluGradNHWC<half>(
template <typename T>
__global__ void _PReluWGrad(const int N, const T* dy, const T* x, T* dw) {
__shared__ typename BlockReduce<T>::TempStorage storage;
T val = T(0);
CUDA_2D_KERNEL_LOOP2(i, N) {
#if __CUDA_ARCH__ >= 350
val += __ldg(x + i) < T(0) ? dy[i] * __ldg(x + i) : T(0);
#else
val += x[i] < T(0) ? dy[i] * x[i] : T(0);
#endif
}
val = BlockReduce<T>(storage).Sum(val);
if (threadIdx.x == 0) *dw = val;
}
......@@ -214,16 +241,18 @@ __global__ void _PReluWGrad(const int N, const T* dy, const T* x, T* dw) {
template <>
__global__ void
_PReluWGrad<half>(const int N, const half* dy, const half* x, half* dw) {
const half kZero = __float2half(0.f);
__shared__ typename BlockReduce<float>::TempStorage storage;
float val = 0.f;
CUDA_2D_KERNEL_LOOP2(i, N) {
#if __CUDA_ARCH__ >= 530
val += __hlt(x[i], kZero) ? __half2float(__hmul(dy[i], x[i])) : 0.f;
#if __CUDA_ARCH__ >= 350
val += __half2float(__ldg(x + i)) < 0.f
? __half2float(dy[i]) * __half2float(__ldg(x + i))
: 0.f;
#else
val += __half2float(x[i]) < 0.f ? __half2float(dy[i]) * __half2float(x[i])
: 0.f;
#endif
}
val = BlockReduce<float>(storage).Sum(val);
if (threadIdx.x == 0) *dw = __float2half(val);
}
......@@ -241,7 +270,11 @@ __global__ void _PReluWGradNCHW(
T val = T(0);
CUDA_2D_KERNEL_LOOP2(j, NS) {
const int yi = ((j / S) * C + i) * S + j % S;
#if __CUDA_ARCH__ >= 350
val += __ldg(x + yi) < T(0) ? dy[yi] * __ldg(x + yi) : T(0);
#else
val += x[yi] < T(0) ? dy[yi] * x[yi] : T(0);
#endif
}
val = BlockReduce<T>(storage).Sum(val);
if (threadIdx.x == 0) dw[i] = val;
......@@ -256,14 +289,19 @@ __global__ void _PReluWGradNCHW<half>(
const half* dy,
const half* x,
half* dw) {
const half kZero = __float2half(0.f);
__shared__ typename BlockReduce<float>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, C) {
float val = 0.f;
CUDA_2D_KERNEL_LOOP2(j, NS) {
#if __CUDA_ARCH__ >= 530
const int yi = ((j / S) * C + i) * S + j % S;
val += __hlt(x[yi], kZero) ? __half2float(__hmul(dy[yi], x[yi])) : 0.f;
#if __CUDA_ARCH__ >= 350
val += __half2float(__ldg(x + yi)) < 0.f
? __half2float(dy[yi]) * __half2float(__ldg(x + yi))
: 0.f;
#else
val += __half2float(x[yi]) < 0.f
? __half2float(dy[yi]) * __half2float(x[yi])
: 0.f;
#endif
}
val = BlockReduce<float>(storage).Sum(val);
......@@ -279,7 +317,11 @@ _PReluWGradNHWC(const int NS, const int C, const T* dy, const T* x, T* dw) {
T val = T(0);
CUDA_2D_KERNEL_LOOP2(j, NS) {
const int yi = j * C + i;
#if __CUDA_ARCH__ >= 350
val += __ldg(x + yi) < 0 ? dy[yi] * __ldg(x + yi) : T(0);
#else
val += x[yi] < 0 ? dy[yi] * x[yi] : T(0);
#endif
}
val = BlockReduce<T>(storage).Sum(val);
if (threadIdx.x == 0) dw[i] = val;
......@@ -298,9 +340,15 @@ __global__ void _PReluWGradNHWC<half>(
CUDA_2D_KERNEL_LOOP1(i, C) {
float val = 0.f;
CUDA_2D_KERNEL_LOOP2(j, NS) {
#if __CUDA_ARCH__ >= 530
const int yi = j * C + i;
val += __hlt(x[yi], kZero) ? __half2float(__hmul(dy[yi], x[yi])) : 0.f;
#if __CUDA_ARCH__ >= 350
val += __half2float(__ldg(x + yi)) < 0.f
? __half2float(dy[yi]) * __half2float(__ldg(x + yi))
: 0.f;
#else
val += __half2float(x[yi]) < 0.f
? __half2float(dy[yi]) * __half2float(x[yi])
: 0.f;
#endif
}
val = BlockReduce<float>(storage).Sum(val);
......
......@@ -25,9 +25,12 @@ template <>
__global__ void
_Relu<half>(const int nthreads, const float alpha, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
#if __CUDA_ARCH__ >= 350
const float val = __half2float(__ldg(x + i));
y[i] = val > 0.f ? __ldg(x + i) : __float2half(val * alpha);
#else
const float val = __half2float(x[i]);
y[i] = val > 0.f ? x[i] : __float2half(val * alpha);
#endif
}
}
......@@ -36,12 +39,10 @@ template <>
__global__ void
_Relu<half2>(const int nthreads, const float alpha, const half2* x, half2* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float2 val = __half22float2(x[i]);
y[i] = __floats2half2_rn(
val.x > 0.f ? val.x : val.x * alpha,
val.y > 0.f ? val.y : val.y * alpha);
#endif
}
}
......@@ -63,13 +64,25 @@ template <>
__global__ void
_ReluN<half>(const int nthreads, const half max_value, const half* x, half* y) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = __hgt(__ldg(x + i), kZero)
? (__hlt(__ldg(x + i), max_value) ? __ldg(x + i) : max_value)
: kZero;
#endif
}
#elif __CUDA_ARCH__ >= 350
const float kMax = __half2float(max_value);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float val = __half2float(__ldg(x + i));
y[i] = val > 0.f ? ((val < kMax) ? __ldg(x + i) : max_value) : kZero;
}
#else
const float kMax = __half2float(max_value);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float val = __half2float(x[i]);
y[i] = val > 0.f ? ((val < kMax) ? x[i] : max_value) : kZero;
}
#endif
}
__global__ void _ReluNHalf2(
......@@ -91,6 +104,25 @@ __global__ void _ReluNHalf2(
? __high2half(__ldg(x + i))
: max_value)
: kZero);
#elif __CUDA_ARCH__ >= 350
const float kMax = __half2float(max_value);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float2 val = __half22float2(__ldg(x + i));
y[i] = __halves2half2(
val.x > 0.f ? ((val.x < kMax) ? __low2half(__ldg(x + i)) : max_value)
: kZero,
val.y > 0.f ? ((val.y < kMax) ? __high2half(__ldg(x + i)) : max_value)
: kZero);
}
#else
const float kMax = __half2float(max_value);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float2 val = __half22float2(x[i]);
y[i] = __halves2half2(
val.x > 0.f ? ((val.x < kMax) ? __low2half(x[i]) : max_value) : kZero,
val.y > 0.f ? ((val.y < kMax) ? __high2half(x[i]) : max_value)
: kZero);
}
#endif
}
}
......@@ -103,7 +135,7 @@ __global__ void _ReluGrad(
const T* y,
T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
#if __CUDA_ARCH__ >= 350
dx[i] = __ldg(dy + i) * ((__ldg(y + i) > 0) + alpha * (__ldg(y + i) <= 0));
#else
dx[i] = dy[i] * ((y[i] > 0) + alpha * (y[i] <= 0));
......@@ -118,14 +150,10 @@ __global__ void _ReluGrad<half>(
const half* dy,
const half* y,
half* dx) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul(
dy[i],
__float2half(
__hgt(__ldg(y + i), kZero) + __hle(__ldg(y + i), kZero) * alpha));
#endif
const float val = __half2float(y[i]);
dx[i] = __float2half(
__half2float(dy[i]) * ((val > 0.f) + alpha * (val <= 0.f)));
}
} // ReluGrad
......@@ -136,17 +164,12 @@ __global__ void _ReluGrad<half2>(
const half2* dy,
const half2* y,
half2* dx) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul2(
dy[i],
__floats2half2_rn(
__hgt(__low2half(__ldg(y + i)), kZero) +
__hle(__low2half(__ldg(y + i)), kZero) * alpha,
__hgt(__high2half(__ldg(y + i)), kZero) +
__hle(__high2half(__ldg(y + i)), kZero) * alpha));
#endif
const float2 val = __half22float2(y[i]);
const float2 grad = __half22float2(dy[i]);
dx[i] = __floats2half2_rn(
grad.x * ((val.x > 0.f) + alpha * (val.x <= 0.f)),
grad.y * ((val.y > 0.f) + alpha * (val.y <= 0.f)));
}
} // ReluGrad
......@@ -158,7 +181,7 @@ __global__ void _ReluNGrad(
const T* y,
T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
#if __CUDA_ARCH__ >= 350
dx[i] = (__ldg(y + i) > 0 && __ldg(y + i) < max_value) ? dy[i] : T(0);
#else
dx[i] = (y[i] > 0 && y[i] < max_value) ? dy[i] : T(0);
......@@ -174,14 +197,20 @@ __global__ void _ReluNGrad<half>(
const half* y,
half* dx) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = (__hgt(__ldg(y + i), kZero) && __hlt(__ldg(y + i), max_value))
? dy[i]
: kZero;
#endif
}
} // ReluNGrad
#else
const float kMax = __half2float(max_value);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float val = __half2float(y[i]);
dx[i] = (val > 0.f && val < kMax) ? dy[i] : kZero;
}
#endif
}
template <>
__global__ void _ReluNGrad<half2>(
......@@ -190,15 +219,33 @@ __global__ void _ReluNGrad<half2>(
const half2* dy,
const half2* y,
half2* dx) {
#if __CUDA_ARCH__ >= 530
const half2 kZero = __float2half2_rn(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul2(
__hmul2(__hgt2(__ldg(y + i), kZero), __hlt2(__ldg(y + i), max_value)),
dy[i]);
#endif
}
} // ReluNGrad
#elif __CUDA_ARCH__ >= 350
const half kZero = __float2half(0.f);
const float kMax = __half2float(__low2half(max_value));
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float2 val = __half22float2(y[i]);
dx[i] = __halves2half2(
(val.x > 0.f && val.x < kMax) ? __low2half(__ldg(dy + i)) : kZero,
(val.y > 0.f && val.y < kMax) ? __high2half(__ldg(dy + i)) : kZero);
}
#else
const half kZero = __float2half(0.f);
const float kMax = __half2float(__low2half(max_value));
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const float2 val = __half22float2(y[i]);
dx[i] = __halves2half2(
(val.x > 0.f && val.x < kMax) ? __low2half(dy[i]) : kZero,
(val.y > 0.f && val.y < kMax) ? __high2half(dy[i]) : kZero);
}
#endif
}
} // namespace
......
......@@ -35,11 +35,9 @@ __global__ void _Selu<half>(
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float val = __half2float(x[i]);
y[i] =
__float2half(val > 0.f ? gamma * val : alphaXgamma * (exp(val) - 1.f));
#endif
}
}
......@@ -51,12 +49,10 @@ __global__ void _Selu<half2>(
const half2* x,
half2* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float2 val = __half22float2(x[i]);
y[i] = __floats2half2_rn(
val.x > 0.f ? gamma * val.x : alphaXgamma * (exp(val.x) - 1.f),
val.y > 0.f ? gamma * val.y : alphaXgamma * (exp(val.y) - 1.f));
#endif
}
}
......@@ -86,11 +82,9 @@ __global__ void _SeluGrad<half>(
const half* y,
half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float val = __half2float(y[i]);
dx[i] =
__hmul(dy[i], __float2half(val > 0.f ? gamma : (alphaXgamma + val)));
#endif
dx[i] = __float2half(
__half2float(dy[i]) * (val > 0.f ? gamma : (alphaXgamma + val)));
}
} // SeluGrad
......@@ -103,14 +97,11 @@ __global__ void _SeluGrad<half2>(
const half2* y,
half2* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const float2 val = __half22float2(y[i]);
dx[i] = __hmul2(
dy[i],
__floats2half2_rn(
val.x > 0.f ? gamma : (alphaXgamma + val.x),
val.y > 0.f ? gamma : (alphaXgamma + val.y)));
#endif
const float2 grad = __half22float2(dy[i]);
dx[i] = __floats2half2_rn(
grad.x * (val.x > 0.f ? gamma : (alphaXgamma + val.x)),
grad.y * (val.y > 0.f ? gamma : (alphaXgamma + val.y)));
}
} // SeluGrad
......
......@@ -18,21 +18,17 @@ __global__ void _Sigmoid(const int nthreads, const T* x, T* y) {
template <>
__global__ void _Sigmoid<half>(const int nthreads, const half* x, half* y) {
const half kOne = __float2half(1.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hdiv(kOne, __hadd(kOne, hexp(__hneg(x[i]))));
#endif
y[i] = __float2half(1.f / (1.f + exp(-__half2float(x[i]))));
}
}
template <>
__global__ void _Sigmoid<half2>(const int nthreads, const half2* x, half2* y) {
const half2 kOne = __float2half2_rn(1.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __h2div(kOne, __hadd2(kOne, h2exp(__hneg2(x[i]))));
#endif
const float2 val = __half22float2(x[i]);
y[i] =
__floats2half2_rn(1.f / (1.f + exp(-val.x)), 1.f / (1.f + exp(-val.y)));
}
}
......@@ -54,11 +50,9 @@ __global__ void _SigmoidGrad<half>(
const half* dy,
const half* y,
half* dx) {
const half kOne = __float2half(1.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul(dy[i], __hmul(__ldg(y + i), __hsub(kOne, __ldg(y + i))));
#endif
const float val = __half2float(y[i]);
dx[i] = __float2half(__half2float(dy[i]) * val * (1.f - val));
}
} // SigmoidGrad
......@@ -68,11 +62,11 @@ __global__ void _SigmoidGrad<half2>(
const half2* dy,
const half2* y,
half2* dx) {
const half2 kOne = __float2half2_rn(1.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul2(dy[i], __hmul2(__ldg(y + i), __hsub2(kOne, __ldg(y + i))));
#endif
const float2 val = __half22float2(y[i]);
const float2 grad = __half22float2(dy[i]);
dx[i] = __floats2half2_rn(
grad.x * val.x * (1.f - val.x), grad.y * val.y * (1.f - val.y));
}
} // SigmoidGrad
......
......@@ -71,7 +71,6 @@ __global__ void _Softmax<half>(
const half lowest,
const half* x,
half* y) {
#if __CUDA_ARCH__ >= 530
__shared__ float block_val;
__shared__ typename BlockReduce<float>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, rows) {
......@@ -105,7 +104,6 @@ __global__ void _Softmax<half>(
y[yi] = __float2half(__half2float(y[yi]) / block_val);
}
}
#endif
}
template <typename T>
......@@ -153,7 +151,6 @@ __global__ void _SoftmaxGrad<half>(
const half* dy,
const half* y,
half* dx) {
#if __CUDA_ARCH__ >= 530
__shared__ float block_val;
__shared__ typename BlockReduce<float>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, rows) {
......@@ -162,7 +159,11 @@ __global__ void _SoftmaxGrad<half>(
float val = 0.f;
CUDA_2D_KERNEL_LOOP2(j, cols) {
const int yi = c + j * inner_dim;
#if __CUDA_ARCH__ >= 350
val += __half2float(__ldg(dy + yi)) * __half2float(__ldg(y + yi));
#else
val += __half2float(dy[yi]) * __half2float(y[yi]);
#endif
}
val = BlockReduce<float>(storage).Sum(val);
if (threadIdx.x == 0) block_val = val;
......@@ -170,12 +171,16 @@ __global__ void _SoftmaxGrad<half>(
CUDA_2D_KERNEL_LOOP2(j, cols) {
const int yi = c + j * inner_dim;
#if __CUDA_ARCH__ >= 350
dx[yi] = __float2half(
(__half2float(__ldg(dy + yi)) - block_val) *
__half2float(__ldg(y + yi)));
#else
dx[yi] = __float2half(
(__half2float(dy[yi]) - block_val) * __half2float(y[yi]));
#endif
}
}
#endif
} // SoftmaxGrad
} // namespace
......
......@@ -20,22 +20,15 @@ __global__ void _Tanh(const int nthreads, const T* x, T* y) {
template <>
__global__ void _Tanh<half>(const int nthreads, const half* x, half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const half a = hexp(__ldg(x + i));
const half b = hexp(__hneg(__ldg(x + i)));
y[i] = __hdiv(__hsub(a, b), __hadd(a, b));
#endif
y[i] = __float2half(tanh(__half2float(x[i])));
}
}
template <>
__global__ void _Tanh<half2>(const int nthreads, const half2* x, half2* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const half2 a = h2exp(__ldg(x + i));
const half2 b = h2exp(__hneg2(__ldg(x + i)));
y[i] = __h2div(__hsub2(a, b), __hadd2(a, b));
#endif
const float2 val = __half22float2(x[i]);
y[i] = __floats2half2_rn(tanh(val.x), tanh(val.y));
}
}
......@@ -49,11 +42,9 @@ __global__ void _TanhGrad(const int nthreads, const T* dy, const T* y, T* dx) {
template <>
__global__ void
_TanhGrad<half>(const int nthreads, const half* dy, const half* y, half* dx) {
const half kOne = __float2half(1.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul(dy[i], __hsub(kOne, utils::math::Square(y[i])));
#endif
dx[i] = __float2half(
__half2float(dy[i]) * (1.f - utils::math::Square(__half2float(y[i]))));
}
}
......@@ -63,11 +54,12 @@ __global__ void _TanhGrad<half2>(
const half2* dy,
const half2* y,
half2* dx) {
const half2 kOne = __float2half2_rn(1.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul2(dy[i], __hsub2(kOne, utils::math::Square(y[i])));
#endif
const float2 val = __half22float2(y[i]);
const float2 grad = __half22float2(dy[i]);
dx[i] = __floats2half2_rn(
grad.x * (1.f - utils::math::Square(val.x)),
grad.y * (1.f - utils::math::Square(val.y)));
}
}
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -22,7 +23,7 @@ __global__ void _CumSum(
y[c] = exclusive ? T(0) : x[c];
for (int j = 1; j < cols; ++j) {
const int yi = c + inner_dim;
y[yi] = y[c] + x[exclusive ? c : yi];
y[yi] = math::PlusFunctor<T>()(y[c], x[exclusive ? c : yi]);
c = yi;
}
}
......@@ -38,15 +39,13 @@ __global__ void _CumSum<half>(
half* y) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, rows) {
#if __CUDA_ARCH__ >= 530
int c = (i / inner_dim) * cols * inner_dim + (i % inner_dim);
y[c] = exclusive ? kZero : x[c];
for (int j = 1; j < cols; ++j) {
const int yi = c + inner_dim;
y[yi] = __hadd(y[c], x[exclusive ? c : yi]);
y[yi] = math::PlusFunctor<half>()(y[c], x[exclusive ? c : yi]);
c = yi;
}
#endif
}
}
......@@ -63,7 +62,7 @@ __global__ void _CumSumReverse(
y[c] = exclusive ? T(0) : x[c];
for (int j = cols - 2; j >= 0; --j) {
const int yi = c - inner_dim;
y[yi] = y[c] + x[exclusive ? c : yi];
y[yi] = math::PlusFunctor<T>()(y[c], x[exclusive ? c : yi]);
c = yi;
}
}
......@@ -79,15 +78,13 @@ __global__ void _CumSumReverse<half>(
half* y) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, rows) {
#if __CUDA_ARCH__ >= 530
int c = ((i / inner_dim) * cols + (cols - 1)) * inner_dim + (i % inner_dim);
y[c] = exclusive ? kZero : x[c];
for (int j = cols - 2; j >= 0; --j) {
const int yi = c - inner_dim;
y[yi] = __hadd(y[c], x[exclusive ? c : yi]);
y[yi] = math::PlusFunctor<half>()(y[c], x[exclusive ? c : yi]);
c = yi;
}
#endif
}
}
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -52,40 +53,8 @@ __global__ void _IndexSelectGrad(
int pos = index[k];
#endif
pos = pos >= 0 ? pos : pos + axis_dim;
dx[x_offset + pos * inner_dim] += (*offset_dy);
offset_dy += inner_dim;
}
}
}
template <>
__global__ void _IndexSelectGrad<half>(
const int nthreads,
const int inner_dim,
const int axis_dim,
const int select_dim,
const int64_t* index,
const half* dy,
half* dx) {
CUDA_1D_KERNEL_LOOP(ti, nthreads) {
const int i = ti / inner_dim;
const int j = ti % inner_dim;
const int x_offset = i * axis_dim * inner_dim + j;
const half* offset_dy = dy + i * select_dim * inner_dim + j;
for (int k = 0; k < select_dim; ++k) {
#if __CUDA_ARCH__ >= 350
int pos = __ldg(index + k);
#else
int pos = index[k];
#endif
pos = pos >= 0 ? pos : pos + axis_dim;
pos = x_offset + pos * inner_dim;
#if __CUDA_ARCH__ >= 530
dx[pos] = __hadd(dx[pos], *(offset_dy));
#else
dx[pos] =
__float2half(__half2float(dx[pos]) + __half2float(*(offset_dy)));
#endif
dx[pos] = math::PlusFunctor<T>()(dx[pos], *(offset_dy));
offset_dy += inner_dim;
}
}
......@@ -95,31 +64,6 @@ __global__ void _IndexSelectGrad<half>(
/* ------------------- Launcher Separator ------------------- */
template <>
void IndexSelectGrad<float16, CUDAContext>(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int select_dim,
const int64_t* index,
const float16* dy,
float16* dx,
CUDAContext* ctx) {
const int nthreads = outer_dim * inner_dim;
_IndexSelectGrad<<<
CUDA_BLOCKS(nthreads),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
nthreads,
inner_dim,
axis_dim,
select_dim,
index,
reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx));
} // IndexSelectGrad
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void IndexSelect<T, CUDAContext>( \
......@@ -169,6 +113,7 @@ DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
......
......@@ -41,8 +41,9 @@ __global__ void _RepeatGrad(
const int i = xi / x_inner_dim / axis_dim;
const T* offset_dy = dy + ((i * axis_dim + j) * y_inner_dim + k);
T val = T(0);
for (int r = 0; r < repeats; ++r)
for (int r = 0; r < repeats; ++r) {
val += offset_dy[r * x_inner_dim];
}
dx[xi] = val;
}
}
......@@ -57,16 +58,15 @@ __global__ void _RepeatGrad<half>(
const half* dy,
half* dx) {
CUDA_1D_KERNEL_LOOP(xi, nthreads) {
#if __CUDA_ARCH__ >= 530
const int k = xi % x_inner_dim;
const int j = (xi / x_inner_dim) % axis_dim;
const int i = xi / x_inner_dim / axis_dim;
const half* offset_dy = dy + ((i * axis_dim + j) * y_inner_dim + k);
float val = 0.f;
for (int r = 0; r < repeats; ++r)
for (int r = 0; r < repeats; ++r) {
val += __half2float(offset_dy[r * x_inner_dim]);
}
dx[xi] = __float2half(val);
#endif
}
}
......
......@@ -42,15 +42,12 @@ __global__ void _TileGrad(
const int i = xi / x_cols;
const int j = xi % x_cols;
const T* offset_dy = dy + i * y_cols + j;
T val = (*offset_dy);
offset_dy += x_cols;
for (int k = 1; k < multiple; ++k) {
val += (*offset_dy);
offset_dy += x_cols;
}
dx[xi] = val;
}
}
......@@ -64,21 +61,16 @@ __global__ void _TileGrad<half>(
const half* dy,
half* dx) {
CUDA_1D_KERNEL_LOOP(xi, nthreads) {
#if __CUDA_ARCH__ >= 530
const int i = xi / x_cols;
const int j = xi % x_cols;
const half* offset_dy = dy + i * y_cols + j;
half val = (*offset_dy);
float val = __half2float(*offset_dy);
offset_dy += x_cols;
for (int k = 1; k < multiple; ++k) {
val += __hadd((*offset_dy), val);
val += __half2float(*offset_dy);
offset_dy += x_cols;
}
dx[xi] = val;
#endif
dx[xi] = __float2half(val);
}
}
......
......@@ -25,12 +25,19 @@ __global__ void _Clip<half>(
const half high,
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
const half val = __hlt(__ldg(x + i), high) ? __ldg(x + i) : high;
y[i] = __hgt(val, low) ? val : low;
#endif
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = __hlt(__ldg(x + i), high)
? (__hgt(__ldg(x + i), low) ? __ldg(x + i) : low)
: high;
}
#else
const float kLow = __half2float(low);
const float kHigh = __half2float(high);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = __float2half(max(kLow, min(__half2float(x[i]), kHigh)));
}
#endif
}
template <typename T>
......@@ -59,12 +66,28 @@ __global__ void _ClipGrad<half>(
const half* x,
half* dx) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] =
__hlt(__ldg(x + i), low) || __hgt(__ldg(x + i), high) ? kZero : dy[i];
#endif
(__hlt(__ldg(x + i), low) || __hgt(__ldg(x + i), high)) ? kZero : dy[i];
}
#elif __CUDA_ARCH__ >= 350
const float kLow = __half2float(low);
const float kHigh = __half2float(high);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = (__half2float(__ldg(x + i)) < kLow ||
__half2float(__ldg(x + i)) > kHigh)
? kZero
: dy[i];
}
#else
const float kLow = __half2float(low);
const float kHigh = __half2float(high);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] = (__half2float(x[i]) < kLow || __half2float(x[i]) > kHigh) ? kZero
: dy[i];
}
#endif
}
} // namespace
......
......@@ -20,11 +20,8 @@ __global__ void _CosGrad(const int nthreads, const T* dy, const T* x, T* dx) {
template <>
__global__ void
_CosGrad<half>(const int nthreads, const half* dy, const half* x, half* dx) {
const half kFactor = __float2half(-1.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul(__hmul(dy[i], kFactor), hsin(x[i]));
#endif
dx[i] = __float2half(-__half2float(dy[i]) * sin(__half2float(x[i])));
}
}
......@@ -34,11 +31,10 @@ __global__ void _CosGrad<half2>(
const half2* dy,
const half2* x,
half2* dx) {
const half2 kFactor = __float2half2_rn(-1.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul2(__hmul2(dy[i], kFactor), h2sin(x[i]));
#endif
const float2 val = __half22float2(x[i]);
const float2 grad = __half22float2(dy[i]);
dx[i] = __floats2half2_rn(-grad.x * sin(val.x), -grad.y * sin(val.y));
}
}
......@@ -53,9 +49,7 @@ template <>
__global__ void
_SinGrad<half>(const int nthreads, const half* dy, const half* x, half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul(dy[i], hcos(x[i]));
#endif
dx[i] = __float2half(__half2float(dy[i]) * cos(__half2float(x[i])));
}
}
......@@ -66,9 +60,9 @@ __global__ void _SinGrad<half2>(
const half2* x,
half2* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul2(dy[i], h2cos(x[i]));
#endif
const float2 val = __half22float2(x[i]);
const float2 grad = __half22float2(dy[i]);
dx[i] = __floats2half2_rn(grad.x * cos(val.x), grad.y * cos(val.y));
}
}
......@@ -86,11 +80,9 @@ __global__ void _ReciprocalGrad<half>(
const half* dy,
const half* y,
half* dx) {
const half c = __float2half(-1.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul(__hmul(c, dy[i]), utils::math::Square(y[i]));
#endif
dx[i] = __float2half(
-__half2float(dy[i]) * utils::math::Square(__half2float(y[i])));
}
}
......@@ -100,11 +92,11 @@ __global__ void _ReciprocalGrad<half2>(
const half2* dy,
const half2* y,
half2* dx) {
const half2 c = __float2half2_rn(-1.f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul2(__hmul2(c, dy[i]), utils::math::Square(y[i]));
#endif
const float2 val = __half22float2(y[i]);
const float2 grad = __half22float2(dy[i]);
dx[i] =
__floats2half2_rn(-grad.x * (val.x * val.x), -grad.y * (val.y * val.y));
}
}
......@@ -118,11 +110,9 @@ __global__ void _RsqrtGrad(const int nthreads, const T* dy, const T* y, T* dx) {
template <>
__global__ void
_RsqrtGrad<half>(const int nthreads, const half* dy, const half* y, half* dx) {
const half c = __float2half(-0.5f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul(__hmul(c, dy[i]), utils::math::Cube(y[i]));
#endif
dx[i] = __float2half(
-0.5f * __half2float(dy[i]) * utils::math::Cube(__half2float(y[i])));
}
}
......@@ -132,11 +122,12 @@ __global__ void _RsqrtGrad<half2>(
const half2* dy,
const half2* y,
half2* dx) {
const half2 c = __float2half2_rn(-0.5f);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hmul2(__hmul2(c, dy[i]), utils::math::Cube(y[i]));
#endif
const float2 val = __half22float2(y[i]);
const float2 grad = __half22float2(dy[i]);
dx[i] = __floats2half2_rn(
-0.5f * grad.x * (val.x * val.x * val.x),
-0.5f * grad.y * (val.y * val.y * val.y));
}
}
......
......@@ -51,7 +51,6 @@ __global__ void _RowwiseMoments<half, float>(
const half* x,
float* mean,
float* var) {
#if __CUDA_ARCH__ >= 530
__shared__ typename BlockReduce<float>::TempStorage m_storage;
__shared__ typename BlockReduce<float>::TempStorage v_storage;
const float scale = 1.f / (float)rows;
......@@ -70,7 +69,6 @@ __global__ void _RowwiseMoments<half, float>(
var[i] = v_val * scale - mu * mu;
}
}
#endif
}
template <typename Tx, typename Ty>
......@@ -112,7 +110,6 @@ __global__ void _ColwiseMoments<half, float>(
const half* x,
float* mean,
float* var) {
#if __CUDA_ARCH__ >= 530
__shared__ typename BlockReduce<float>::TempStorage m_storage;
__shared__ typename BlockReduce<float>::TempStorage v_storage;
const float scale = 1.f / (float)cols;
......@@ -131,7 +128,6 @@ __global__ void _ColwiseMoments<half, float>(
var[i] = v_val * scale - mu * mu;
}
}
#endif
}
template <typename Tx, typename Ty, int D>
......
......@@ -10,9 +10,9 @@ namespace dragon {
namespace kernel {
#if __CUDA_ARCH__ >= 350
#define L(x, i) __ldg(x + i)
#define LOAD(x, i) __ldg(x + i)
#else
#define L(x, i) x[i]
#define LOAD(x, i) x[i]
#endif
namespace {
......@@ -34,8 +34,8 @@ __global__ void _BatchNormExpectation(
CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S
: j * C + i;
ex_val += L(x, xi);
ex2_val += utils::math::Square(L(x, xi));
ex_val += LOAD(x, xi);
ex2_val += utils::math::Square(LOAD(x, xi));
}
ex_val = BlockReduce<Tp>(ex_storage).Reduce(ex_val, cub::Sum());
ex2_val = BlockReduce<Tp>(ex2_storage).Reduce(ex2_val, cub::Sum());
......@@ -66,8 +66,8 @@ __global__ void _BatchNormInternalGrad(
CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S
: j * C + i;
dg_val += L(dy, xi) * (L(x, xi) - L(mu, i)) * L(rsig, i);
db_val += L(dy, xi);
dg_val += LOAD(dy, xi) * (LOAD(x, xi) - LOAD(mu, i)) * LOAD(rsig, i);
db_val += LOAD(dy, xi);
}
dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
......@@ -95,9 +95,9 @@ __global__ void _BatchNormTrainingGrad(
const Tp denom = Tp(1) / Tp(N * S);
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
const Tp x_norm = (L(x, i) - L(mu, pi)) * L(rsig, pi);
dx[i] = L(gamma, pi) * L(rsig, pi) *
(L(dy, i) - fma(x_norm, L(dgamma, pi), L(dbeta, pi)) * denom);
const Tp x_norm = (LOAD(x, i) - LOAD(mu, pi)) * LOAD(rsig, pi);
dx[i] = LOAD(gamma, pi) * LOAD(rsig, pi) *
(LOAD(dy, i) - fma(x_norm, LOAD(dgamma, pi), LOAD(dbeta, pi)) * denom);
}
}
......@@ -120,8 +120,8 @@ __global__ void _BatchNormWGrad(
CUDA_2D_KERNEL_LOOP2(j, outer_dim) {
const int xi = kOrder == StorageOrder::NCHW ? (j / S * C + i) * S + j % S
: j * C + i;
dg_val += L(dy, xi) * (L(x, xi) - L(mu, i)) * L(rsig, i);
db_val += L(dy, xi);
dg_val += LOAD(dy, xi) * (LOAD(x, xi) - LOAD(mu, i)) * LOAD(rsig, i);
db_val += LOAD(dy, xi);
}
dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
......@@ -143,10 +143,12 @@ __global__ void _BatchNormInferenceGrad(
Tx* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int pi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
dx[i] = L(gamma, pi) * L(dy, i) * L(rsig, pi);
dx[i] = LOAD(gamma, pi) * LOAD(dy, i) * LOAD(rsig, pi);
}
}
#undef LOAD
} // namespace
/* ------------------- Launcher Separator ------------------- */
......@@ -294,7 +296,6 @@ __global__ void _BatchNormInferenceGrad(
DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
#undef L
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -10,11 +10,11 @@ namespace dragon {
namespace kernel {
#if __CUDA_ARCH__ >= 350
#define L(x, i) __ldg(x + i)
#define LF(x, i) __half2float(__ldg(x + i))
#define LOAD(x, i) __ldg(x + i)
#define LOADF(x, i) __half2float(__ldg(x + i))
#else
#define L(x, i) x[i]
#define LF(x, i) __half2float(x[i])
#define LOAD(x, i) x[i]
#define LOADF(x, i) __half2float(x[i])
#endif
namespace {
......@@ -33,14 +33,14 @@ __global__ void _GroupNormFusedParams(
const int outer_dim = N * G;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const int g = i % G;
const T mu_val = L(mu, i);
const T rsig_val = L(rsig, i);
const T mu_val = LOAD(mu, i);
const T rsig_val = LOAD(rsig, i);
CUDA_2D_KERNEL_LOOP2(j, D) {
const int wi = i * D + j;
const int gi = g * D + j;
const T w = L(gamma, gi) * rsig_val;
const T w = LOAD(gamma, gi) * rsig_val;
scale[wi] = w;
bias[wi] = fma(-w, mu_val, L(beta, gi));
bias[wi] = fma(-w, mu_val, LOAD(beta, gi));
}
}
}
......@@ -56,11 +56,11 @@ __global__ void _GroupNormForwardNCHW(
Tx* y) {
const int outer_dim = N * C;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const Tp w = L(scale, i);
const Tp b = L(bias, i);
const Tp w = LOAD(scale, i);
const Tp b = LOAD(bias, i);
CUDA_2D_KERNEL_LOOP2(j, S) {
const int xi = i * S + j;
y[xi] = fma(L(x, xi), w, b);
y[xi] = fma(LOAD(x, xi), w, b);
}
}
}
......@@ -76,11 +76,11 @@ __global__ void _GroupNormForwardNCHW<half, float>(
half* y) {
const int outer_dim = N * C;
CUDA_2D_KERNEL_LOOP1(i, outer_dim) {
const float w = L(scale, i);
const float b = L(bias, i);
const float w = LOAD(scale, i);
const float b = LOAD(bias, i);
CUDA_2D_KERNEL_LOOP2(j, S) {
const int xi = i * S + j;
y[xi] = __float2half(fmaf(LF(x, xi), w, b));
y[xi] = __float2half(fmaf(LOADF(x, xi), w, b));
}
}
}
......@@ -100,7 +100,7 @@ __global__ void _GroupNormForwardNHWC(
CUDA_2D_KERNEL_LOOP2(j, C) {
const int xi = i * C + j;
const int wi = n * C + j;
y[xi] = fma(L(x, xi), L(scale, wi), L(bias, wi));
y[xi] = fma(LOAD(x, xi), LOAD(scale, wi), LOAD(bias, wi));
}
}
}
......@@ -120,7 +120,7 @@ __global__ void _GroupNormForwardNHWC<half, float>(
CUDA_2D_KERNEL_LOOP2(j, C) {
const int xi = i * C + j;
const int wi = n * C + j;
y[xi] = __float2half(fmaf(LF(x, xi), L(scale, wi), L(bias, wi)));
y[xi] = __float2half(fmaf(LOADF(x, xi), LOAD(scale, wi), LOAD(bias, wi)));
}
}
}
......@@ -149,8 +149,8 @@ __global__ void _GroupNormWGrad(
? (n * outer_dim + i) * S + j % S
: j * outer_dim + i;
const int mi = n * G + i / D;
dg_val += L(dy, xi) * (L(x, xi) - L(mu, mi)) * L(rsig, mi);
db_val += L(dy, xi);
dg_val += LOAD(dy, xi) * (LOAD(x, xi) - LOAD(mu, mi)) * LOAD(rsig, mi);
db_val += LOAD(dy, xi);
}
dg_val = BlockReduce<Tp>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
......@@ -185,8 +185,8 @@ __global__ void _GroupNormWGradHalf(
? (n * outer_dim + i) * S + j % S
: j * outer_dim + i;
const int mi = n * G + i / D;
dg_val += LF(dy, xi) * (LF(x, xi) - L(mu, mi)) * L(rsig, mi);
db_val += LF(dy, xi);
dg_val += LOADF(dy, xi) * (LOADF(x, xi) - LOAD(mu, mi)) * LOAD(rsig, mi);
db_val += LOADF(dy, xi);
}
dg_val = BlockReduce<float>(dg_storage).Reduce(dg_val, cub::Sum());
db_val = BlockReduce<float>(db_storage).Reduce(db_val, cub::Sum());
......@@ -219,8 +219,8 @@ __global__ void _GroupNormInternalGrad(
const int xi = kOrder == StorageOrder::NCHW
? i * inner_dim + j
: (i / G * S + j % S) * G * D + gi;
ds_val += L(gamma, gi) * L(dy, xi) * L(x, xi);
db_val += L(gamma, gi) * L(dy, xi);
ds_val += LOAD(gamma, gi) * LOAD(dy, xi) * LOAD(x, xi);
db_val += LOAD(gamma, gi) * LOAD(dy, xi);
}
ds_val = BlockReduce<Tp>(ds_storage).Reduce(ds_val, cub::Sum());
db_val = BlockReduce<Tp>(db_storage).Reduce(db_val, cub::Sum());
......@@ -253,8 +253,8 @@ __global__ void _GroupNormInternalGradHalf(
const int xi = kOrder == StorageOrder::NCHW
? i * inner_dim + j
: (i / G * S + j % S) * G * D + gi;
ds_val += L(gamma, gi) * LF(dy, xi) * LF(x, xi);
db_val += L(gamma, gi) * LF(dy, xi);
ds_val += LOAD(gamma, gi) * LOADF(dy, xi) * LOADF(x, xi);
db_val += LOAD(gamma, gi) * LOADF(dy, xi);
}
ds_val = BlockReduce<float>(ds_storage).Reduce(ds_val, cub::Sum());
db_val = BlockReduce<float>(db_storage).Reduce(db_val, cub::Sum());
......@@ -285,10 +285,10 @@ __global__ void _GroupNormGrad(
const int mi = kOrder == StorageOrder::NCHW ? i / (D * S)
: i / (C * S) * G + (i / D % G);
const int gi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
const Tp u = fma(L(db, mi), L(mu, mi), -L(ds, mi)) * (L(x, i) - L(mu, mi)) *
utils::math::Cube(L(rsig, mi));
const Tp v = L(db, mi) * L(rsig, mi);
dx[i] = L(gamma, gi) * L(dy, i) * L(rsig, mi) + (u - v) * denom;
const Tp u = fma(LOAD(db, mi), LOAD(mu, mi), -LOAD(ds, mi)) *
(LOAD(x, i) - LOAD(mu, mi)) * utils::math::Cube(LOAD(rsig, mi));
const Tp v = LOAD(db, mi) * LOAD(rsig, mi);
dx[i] = LOAD(gamma, gi) * LOAD(dy, i) * LOAD(rsig, mi) + (u - v) * denom;
}
}
......@@ -312,14 +312,17 @@ __global__ void _GroupNormGradHalf(
const int mi = kOrder == StorageOrder::NCHW ? i / (D * S)
: i / (C * S) * G + (i / D % G);
const int gi = kOrder == StorageOrder::NCHW ? (i / S) % C : i % C;
const float u = fmaf(L(db, mi), L(mu, mi), -L(ds, mi)) *
(LF(x, i) - L(mu, mi)) * utils::math::Cube(L(rsig, mi));
const float v = L(db, mi) * L(rsig, mi);
dx[i] =
__float2half(L(gamma, gi) * LF(dy, i) * L(rsig, mi) + (u - v) * denom);
const float u = fmaf(LOAD(db, mi), LOAD(mu, mi), -LOAD(ds, mi)) *
(LOADF(x, i) - LOAD(mu, mi)) * utils::math::Cube(LOAD(rsig, mi));
const float v = LOAD(db, mi) * LOAD(rsig, mi);
dx[i] = __float2half(
LOAD(gamma, gi) * LOADF(dy, i) * LOAD(rsig, mi) + (u - v) * denom);
}
}
#undef LOAD
#undef LOADF
} // namespace
/* ------------------- Launcher Separator ------------------- */
......@@ -543,8 +546,6 @@ void GroupNormBackward<float16, float, CUDAContext>(
DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
#undef L
#undef LF
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -14,23 +15,9 @@ __global__ void
_BiasAdd(const int nthreads, const int axis_dim, const T* x, const T* b, T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = x[i] + __ldg(b + i % axis_dim);
y[i] = math::PlusFunctor<T>()(x[i], __ldg(b + i % axis_dim));
#else
y[i] = x[i] + b[i % axis_dim];
#endif
}
}
template <>
__global__ void _BiasAdd<half>(
const int nthreads,
const int axis_dim,
const half* x,
const half* b,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hadd(x[i], __ldg(b + i % axis_dim));
y[i] = math::PlusFunctor<T>()(x[i], b[i % axis_dim]);
#endif
}
}
......@@ -45,24 +32,9 @@ __global__ void _BiasAdd(
T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 350
y[i] = x[i] + __ldg(b + (i / inner_dim) % axis_dim);
y[i] = math::PlusFunctor<T>()(x[i], __ldg(b + (i / inner_dim) % axis_dim));
#else
y[i] = x[i] + b[(i / inner_dim) % axis_dim];
#endif
}
}
template <>
__global__ void _BiasAdd<half>(
const int nthreads,
const int inner_dim,
const int axis_dim,
const half* x,
const half* b,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
y[i] = __hadd(x[i], __ldg(b + (i / inner_dim) % axis_dim));
y[i] = math::PlusFunctor<T>()(x[i], b[(i / inner_dim) % axis_dim]);
#endif
}
}
......
......@@ -10,6 +10,14 @@ namespace kernel {
namespace {
#if __CUDA_ARCH__ >= 350
#define LOAD(x, i) __ldg(x + i)
#define LOADF(x, i) __half2float(__ldg(x + i))
#else
#define LOAD(x, i) x[i]
#define LOADF(x, i) __half2float(x[i])
#endif
template <typename T>
float ComputeScale(T in_size, T out_size, bool align_corners) {
if (align_corners) {
......@@ -62,17 +70,10 @@ __global__ void _ResizeLinearNCHW(
const float u = w_in - li;
const int offset = (n * C + c) * H;
#if __CUDA_ARCH__ >= 350
const float tl = __ldg(x + ((offset + ti) * W + li));
const float tr = __ldg(x + ((offset + ti) * W + ri));
const float bl = __ldg(x + ((offset + bi) * W + li));
const float br = __ldg(x + ((offset + bi) * W + ri));
#else
const float tl = x[(offset + ti) * W + li];
const float tr = x[(offset + ti) * W + ri];
const float bl = x[(offset + bi) * W + li];
const float br = x[(offset + bi) * W + ri];
#endif
const float tl = LOAD(x, ((offset + ti) * W + li));
const float tr = LOAD(x, ((offset + ti) * W + ri));
const float bl = LOAD(x, ((offset + bi) * W + li));
const float br = LOAD(x, ((offset + bi) * W + ri));
const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u;
y[yi] = (T)(t + (b - t) * v);
......@@ -93,7 +94,6 @@ __global__ void _ResizeLinearNCHW<half>(
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
#if __CUDA_ARCH__ >= 530
const int w = yi % out_w;
const int h = (yi / out_w) % out_h;
const int c = (yi / out_w / out_h) % C;
......@@ -110,15 +110,13 @@ __global__ void _ResizeLinearNCHW<half>(
const float u = w_in - li;
const int offset = (n * C + c) * H;
const float tl = __half2float(__ldg(x + ((offset + ti) * W + li)));
const float tr = __half2float(__ldg(x + ((offset + ti) * W + ri)));
const float bl = __half2float(__ldg(x + ((offset + bi) * W + li)));
const float br = __half2float(__ldg(x + ((offset + bi) * W + ri)));
const float tl = LOADF(x, ((offset + ti) * W + li));
const float tr = LOADF(x, ((offset + ti) * W + ri));
const float bl = LOADF(x, ((offset + bi) * W + li));
const float br = LOADF(x, ((offset + bi) * W + ri));
const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u;
y[yi] = __float2half(t + (b - t) * v);
#endif
}
}
......@@ -152,17 +150,10 @@ __global__ void _ResizeLinearNHWC(
const float u = w_in - li;
const int offset = n * H;
#if __CUDA_ARCH__ >= 350
const float tl = __ldg(x + (((offset + ti) * W + li) * C + c));
const float tr = __ldg(x + (((offset + ti) * W + ri) * C + c));
const float bl = __ldg(x + (((offset + bi) * W + li) * C + c));
const float br = __ldg(x + (((offset + bi) * W + ri) * C + c));
#else
const float tl = x[((offset + ti) * W + li) * C + c];
const float tr = x[((offset + ti) * W + ri) * C + c];
const float bl = x[((offset + bi) * W + li) * C + c];
const float br = x[((offset + bi) * W + ri) * C + c];
#endif
const float tl = LOAD(x, (((offset + ti) * W + li) * C + c));
const float tr = LOAD(x, (((offset + ti) * W + ri) * C + c));
const float bl = LOAD(x, (((offset + bi) * W + li) * C + c));
const float br = LOAD(x, (((offset + bi) * W + ri) * C + c));
const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u;
y[yi] = (T)(t + (b - t) * v);
......@@ -183,7 +174,6 @@ __global__ void _ResizeLinearNHWC<half>(
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
#if __CUDA_ARCH__ >= 530
const int c = yi % C;
const int w = (yi / C) % out_w;
const int h = (yi / C / out_w) % out_h;
......@@ -200,19 +190,13 @@ __global__ void _ResizeLinearNHWC<half>(
const float u = w_in - li;
const int offset = n * H;
const float tl =
__half2float(__ldg(x + (((offset + ti) * W + li) * C + c)));
const float tr =
__half2float(__ldg(x + (((offset + ti) * W + ri) * C + c)));
const float bl =
__half2float(__ldg(x + (((offset + bi) * W + li) * C + c)));
const float br =
__half2float(__ldg(x + (((offset + bi) * W + ri) * C + c)));
const float tl = LOADF(x, (((offset + ti) * W + li) * C + c));
const float tr = LOADF(x, (((offset + ti) * W + ri) * C + c));
const float bl = LOADF(x, (((offset + bi) * W + li) * C + c));
const float br = LOADF(x, (((offset + bi) * W + ri) * C + c));
const float t = tl + (tr - tl) * u;
const float b = bl + (br - bl) * u;
y[yi] = __float2half(t + (b - t) * v);
#endif
}
}
......@@ -245,13 +229,8 @@ __global__ void _ResizeLinearGradNCHW(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li;
#if __CUDA_ARCH__ >= 350
const float dt = (1.f - v) * ((float)__ldg(dy + yi));
const float db = v * ((float)__ldg(dy + yi));
#else
const float dt = (1.f - v) * ((float)dy[yi]);
const float db = v * ((float)dy[yi]);
#endif
const float dt = (1.f - v) * LOAD(dy, yi);
const float db = v * LOAD(dy, yi);
const int offset = (n * C + c) * H;
atomicAdd(&dx[(offset + ti) * W + li], (1.f - u) * dt);
......@@ -275,7 +254,6 @@ __global__ void _ResizeLinearGradNCHW<half>(
const half* dy,
float* dx) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
#if __CUDA_ARCH__ >= 530
const int w = yi % out_w;
const int h = (yi / out_w) % out_h;
const int c = (yi / out_w / out_h) % C;
......@@ -291,15 +269,14 @@ __global__ void _ResizeLinearGradNCHW<half>(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li;
const float dt = (1.f - v) * __half2float(__ldg(dy + yi));
const float db = v * __half2float(__ldg(dy + yi));
const float dt = (1.f - v) * LOADF(dy, yi);
const float db = v * LOADF(dy, yi);
const int offset = (n * C + c) * H;
atomicAdd(&dx[(offset + ti) * W + li], (1.f - u) * dt);
atomicAdd(&dx[(offset + ti) * W + ri], u * dt);
atomicAdd(&dx[(offset + bi) * W + li], (1.f - u) * db);
atomicAdd(&dx[(offset + bi) * W + ri], u * db);
#endif
}
}
......@@ -332,13 +309,8 @@ __global__ void _ResizeLinearGradNHWC(
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li;
#if __CUDA_ARCH__ >= 350
const float dt = (1.f - v) * ((float)__ldg(dy + yi));
const float db = v * ((float)__ldg(dy + yi));
#else
const float dt = (1.f - v) * ((float)dy[yi]);
const float db = v * ((float)dy[yi]);
#endif
const float dt = (1.f - v) * LOAD(dy, yi);
const float db = v * LOAD(dy, yi);
const int offset = n * H;
atomicAdd(&dx[((offset + ti) * W + li) * C + c], (1.f - u) * dt);
......@@ -348,6 +320,49 @@ __global__ void _ResizeLinearGradNHWC(
}
}
template <>
__global__ void _ResizeLinearGradNHWC<half>(
const int nthreads,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const float scale_h,
const float scale_w,
const bool align_corners,
const half* dy,
float* dx) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int c = yi % C;
const int w = (yi / C) % out_w;
const int h = (yi / C / out_w) % out_h;
const int n = yi / C / out_w / out_h;
const float h_in = TransformCoordinate(h, scale_h, align_corners);
const int ti = floorf(h_in);
const int bi = (h_in < H - 1) ? ceilf(h_in) : H - 1;
const float v = h_in - ti;
const float w_in = TransformCoordinate(w, scale_w, align_corners);
const int li = floorf(w_in);
const int ri = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float u = w_in - li;
const float dt = (1.f - v) * LOADF(dy, yi);
const float db = v * LOADF(dy, yi);
const int offset = n * H;
atomicAdd(&dx[((offset + ti) * W + li) * C + c], (1.f - u) * dt);
atomicAdd(&dx[((offset + ti) * W + ri) * C + c], u * dt);
atomicAdd(&dx[((offset + bi) * W + li) * C + c], (1.f - u) * db);
atomicAdd(&dx[((offset + bi) * W + ri) * C + c], u * db);
}
}
#undef LOAD
#undef LOADF
} // namespace
/* ------------------- Launcher Separator ------------------- */
......
......@@ -104,15 +104,17 @@ __global__ void _ResizeNearestGradNCHW<half>(
const half* dy,
float* dx) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
#if __CUDA_ARCH__ >= 530
const int w = yi % out_w;
const int h = (yi / out_w) % out_h;
const int c = (yi / out_w / out_h) % C;
const int n = yi / out_w / out_h / C;
const int h_in = min(int(h * scale_h), H - 1);
const int w_in = min(int(w * scale_w), W - 1);
#if __CUDA_ARCH__ >= 350
atomicAdd(
&dx[((n * C + c) * H + h_in) * W + w_in], __half2float(__ldg(dy + yi)));
#else
atomicAdd(&dx[((n * C + c) * H + h_in) * W + w_in], __half2float(dy[yi]));
#endif
}
}
......@@ -157,15 +159,17 @@ __global__ void _ResizeNearestGradNHWC<half>(
const half* dy,
float* dx) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
#if __CUDA_ARCH__ >= 530
const int c = yi % C;
const int w = (yi / C) % out_w;
const int h = (yi / C / out_w) % out_h;
const int n = yi / C / out_w / out_h;
const int h_in = min(int(h * scale_h), H - 1);
const int w_in = min(int(w * scale_w), W - 1);
#if __CUDA_ARCH__ >= 350
atomicAdd(
&dx[((n * H + h_in) * W + w_in) * C + c], __half2float(__ldg(dy + yi)));
#else
atomicAdd(&dx[((n * H + h_in) * W + w_in) * C + c], __half2float(dy[yi]));
#endif
}
}
......
......@@ -9,6 +9,14 @@ namespace kernel {
namespace {
#if __CUDA_ARCH__ >= 350
#define LOAD(x, i) __ldg(x + i)
#define LOADF(x, i) __half2float(__ldg(x + i))
#else
#define LOAD(x, i) x[i]
#define LOADF(x, i) __half2float(x[i])
#endif
template <typename T>
__device__ float
_RoiAlignIntp(const int H, const int W, float h, float w, const T* x) {
......@@ -34,17 +42,10 @@ _RoiAlignIntp(const int H, const int W, float h, float w, const T* x) {
w = (float)li;
}
#if __CUDA_ARCH__ >= 350
const float tl = __ldg(x + (ti * W + li));
const float tr = __ldg(x + (ti * W + ri));
const float bl = __ldg(x + (bi * W + li));
const float br = __ldg(x + (bi * W + ri));
#else
const float tl = x[ti * W + li];
const float tr = x[ti * W + ri];
const float bl = x[bi * W + li];
const float br = x[bi * W + ri];
#endif
const float tl = LOAD(x, (ti * W + li));
const float tr = LOAD(x, (ti * W + ri));
const float bl = LOAD(x, (bi * W + li));
const float br = LOAD(x, (bi * W + ri));
const float v = h - ti;
const float u = w - li;
......@@ -79,17 +80,10 @@ _RoiAlignIntp<half>(const int H, const int W, float h, float w, const half* x) {
w = (float)li;
}
#if __CUDA_ARCH__ >= 350
const float tl = __half2float(__ldg(x + (ti * W + li)));
const float tr = __half2float(__ldg(x + (ti * W + ri)));
const float bl = __half2float(__ldg(x + (bi * W + li)));
const float br = __half2float(__ldg(x + (bi * W + ri)));
#else
const float tl = __half2float(x[ti * W + li]);
const float tr = __half2float(x[ti * W + ri]);
const float bl = __half2float(x[bi * W + li]);
const float br = __half2float(x[bi * W + ri]);
#endif
const float tl = LOADF(x, (ti * W + li));
const float tr = LOADF(x, (ti * W + ri));
const float bl = LOADF(x, (bi * W + li));
const float br = LOADF(x, (bi * W + ri));
const float v = h - ti;
const float u = w - li;
......@@ -389,6 +383,9 @@ __global__ void _RoiAlignGrad<half>(
}
}
#undef LOAD
#undef LOADF
} // namespace
/* ------------------- Launcher Separator ------------------- */
......
......@@ -58,30 +58,29 @@ __global__ void _RoiPool(
wend = min(max(wend + roi_start_w, 0), W);
const bool empty = (hend <= hstart) || (wend <= wstart);
int max_idx = empty ? -1 : 0;
const T* offset_x = x + (batch_ind * C + c) * H * W;
int maxi = empty ? -1 : 0;
T val = empty ? T(0) : offset_x[0];
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int xi = h * W + w;
const int xi = h * W + w;
#if __CUDA_ARCH__ >= 350
if (__ldg(offset_x + xi) > val) {
maxi = xi;
val = __ldg(offset_x + xi);
max_idx = xi;
}
#else
if (x[xi] > val) {
maxi = xi;
if (offset_x[xi] > val) {
val = offset_x[xi];
max_idx = xi;
}
#endif
}
}
y[yi] = val;
mask[yi] = maxi;
mask[yi] = max_idx;
}
}
......@@ -98,9 +97,7 @@ __global__ void _RoiPool<half>(
const float* rois,
int* mask,
half* y) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
#if __CUDA_ARCH__ >= 530
const int ow = yi % out_w;
const int oh = (yi / out_w) % out_h;
const int c = (yi / out_w / out_h) % C;
......@@ -136,24 +133,42 @@ __global__ void _RoiPool<half>(
wend = min(max(wend + roi_start_w, 0), W);
const bool empty = (hend <= hstart) || (wend <= wstart);
int max_idx = empty ? -1 : 0;
const half* offset_x = x + ((batch_ind * C + c) * H * W);
int maxi = empty ? -1 : 0;
half val = empty ? kZero : __ldg(offset_x);
#if __CUDA_ARCH__ >= 530
half val = empty ? __float2half(0.f) : __ldg(offset_x);
#else
float val = empty ? 0.f : __half2float(*offset_x);
#endif
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int xi = h * W + w;
const int xi = h * W + w;
#if __CUDA_ARCH__ >= 530
if (__hgt(__ldg(offset_x + xi), val)) {
maxi = xi;
val = __ldg(offset_x + xi);
max_idx = xi;
}
#elif __CUDA_ARCH__ >= 350
if (__half2float(__ldg(offset_x + xi)) > val) {
val = __half2float(__ldg(offset_x + xi));
max_idx = xi;
}
#else
if (__half2float(offset_x[xi]) > val) {
val = __half2float(offset_x[xi]);
max_idx = xi;
}
#endif
}
}
#if __CUDA_ARCH__ >= 530
y[yi] = val;
mask[yi] = maxi;
#else
y[yi] = __float2half(val);
#endif
mask[yi] = max_idx;
}
}
......@@ -205,19 +220,22 @@ __global__ void _RoiPoolGrad<half>(
const int* mask,
float* dx) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
#if __CUDA_ARCH__ >= 530
const int c = (yi / out_w / out_h) % C;
const int n = yi / out_w / out_h / C;
const float* roi = rois + n * 5;
const int batch_ind = roi[0];
if (batch_ind < 0) continue;
float* offset_dx = dx + (batch_ind * C + c) * H * W;
#if __CUDA_ARCH__ >= 350
if (__ldg(mask + yi) != -1) {
atomicAdd(offset_dx + __ldg(mask + yi), __half2float(dy[yi]));
}
#else
if (mask[yi] != -1) {
atomicAdd(offset_dx + mask[yi], __half2float(dy[yi]));
}
#endif
}
}
......
......@@ -12,27 +12,20 @@ void DropBlock2dOp<Context>::DoRunWithType() {
if (phase() == "TEST") {
Y->ReshapeLike(X)->CopyFrom(X, ctx());
} else if (phase() == "TRAIN") {
int64_t feat_h, feat_w, seed_h, seed_w;
int64_t feature_h, feature_w, seed_h, seed_w;
if (data_format() == "NCHW") {
feat_h = X.dim(2), feat_w = X.dim(3);
feature_h = X.dim(2), feature_w = X.dim(3);
} else if (data_format() == "NHWC") {
feat_h = X.dim(1), feat_w = X.dim(2);
feature_h = X.dim(1), feature_w = X.dim(2);
} else {
LOG(FATAL) << "Unknown DataFormat: " << data_format();
}
seed_h = feat_h - block_size_ + 1;
seed_w = feat_w - block_size_ + 1;
CHECK(seed_h > 0 && seed_w > 0) << "\nExcepted block_size <= feat_size.";
// Schedule the keep ratio
auto kp = keep_prob();
if (decrement_ > 0.f && prob_ > kp) {
prob_ -= decrement_;
} else {
prob_ = kp; // Fixed to the limit value
}
seed_h = feature_h - block_size_ + 1;
seed_w = feature_w - block_size_ + 1;
CHECK(seed_h > 0 && seed_w > 0) << "\nExcepted block_size <= feature_size.";
// Compute the drop ratio
float gamma = (1.f - prob_) / std::pow(block_size_, 2);
gamma *= (alpha_ * (feat_h * feat_w) / (seed_h * seed_w));
float gamma = ratio() / std::pow(block_size_, 2);
gamma *= (float(feature_h * feature_w) / float(seed_h * seed_w));
// Prepare buffers
auto* mask = Buffer("mask")
->ReshapeLike(X)
......@@ -50,8 +43,8 @@ void DropBlock2dOp<Context>::DoRunWithType() {
kernel::DropBlock2d(
X.dim(0),
data_format() == "NCHW" ? X.dim(1) : X.dim(-1),
feat_h,
feat_w,
feature_h,
feature_w,
seed_h,
seed_w,
block_size_,
......
......@@ -22,10 +22,8 @@ class DropBlock2dOp final : public Operator<Context> {
public:
DropBlock2dOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
block_size_(OP_SINGLE_ARG(int64_t, "block_size", 7)),
alpha_(OP_SINGLE_ARG(float, "alpha", 1.f)),
decrement_(OP_SINGLE_ARG(float, "decrement", 0.f)) {
INIT_OP_SINGLE_ARG_WITH_DESC(float, keep_prob, 0.9f);
block_size_(OP_SINGLE_ARG(int64_t, "block_size", 7)) {
INIT_OP_SINGLE_ARG_WITH_DESC(float, ratio, 0.1f);
}
USE_OPERATOR_FUNCTIONS;
......@@ -36,8 +34,7 @@ class DropBlock2dOp final : public Operator<Context> {
protected:
int64_t block_size_;
float alpha_, decrement_, prob_ = 1.;
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, keep_prob);
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, ratio);
};
template <class Context>
......@@ -52,7 +49,7 @@ class DropBlock2dGradientOp final : public Operator<Context> {
void DoRunWithType();
};
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropBlock2dOp, keep_prob);
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropBlock2dOp, ratio);
} // namespace dragon
......
......@@ -12,27 +12,19 @@ void DropPathOp<Context>::DoRunWithType() {
if (phase() == "TEST") {
Y->ReshapeLike(X)->CopyFrom(X, ctx());
} else if (phase() == "TRAIN") {
// Schedule the drop ratio
auto dp = prob();
if (inc_ > 0.f && drop_prob_ < dp) {
drop_prob_ += inc_;
} else {
drop_prob_ = dp; // Fixed to the limit value
}
auto* mask = Buffer("mask")
->Reshape({X.dim(0)})
->template mutable_data<float, Context>();
auto* scale = Buffer("scale")
->Reshape({})
->template mutable_data<float, CPUContext>();
scale[0] = 1.f / (1.f - drop_prob_);
// Generate mask for each example
math::RandomUniform(X.dim(0), 0.f, 1.f, mask, ctx());
// Apply mask to the feature
kernel::DropPath(
X.dim(0),
X.stride(0),
scale[0],
1.f / (1.f - ratio()),
X.template data<T, Context>(),
mask,
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
......@@ -57,7 +49,7 @@ void DropPathGradientOp<Context>::DoRunWithType() {
kernel::DropPath(
dY.dim(0),
dY.stride(0),
Buffer("scale")->template data<float, CPUContext>()[0],
1.f / (1.f - ratio()),
dY.template data<T, Context>(),
Buffer("mask")->template data<float, Context>(),
dX->ReshapeLike(dY)->template mutable_data<T, Context>(),
......
......@@ -10,8 +10,8 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_DROPPATH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_DROPPATH_OP_H_
#ifndef DRAGON_OPERATORS_ACTIVATION_DROP_PATH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_DROP_PATH_OP_H_
#include "dragon/core/operator.h"
......@@ -21,9 +21,8 @@ template <class Context>
class DropPathOp final : public Operator<Context> {
public:
DropPathOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
inc_(OP_SINGLE_ARG(float, "increment", 0.f)) {
INIT_OP_SINGLE_ARG_WITH_DESC(float, prob, 0.2f);
: Operator<Context>(def, ws) {
INIT_OP_SINGLE_ARG_WITH_DESC(float, ratio, 0.2f);
}
USE_OPERATOR_FUNCTIONS;
......@@ -33,24 +32,30 @@ class DropPathOp final : public Operator<Context> {
void DoRunWithType();
protected:
float inc_, drop_prob_ = 0.f;
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, prob);
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, ratio);
};
template <class Context>
class DropPathGradientOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(DropPathGradientOp);
DropPathGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
INIT_OP_SINGLE_ARG_WITH_DESC(float, ratio, 0.5f);
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, ratio);
};
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropPathOp, prob);
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropPathOp, ratio);
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropPathGradientOp, ratio);
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_DROPPATH_OP_H_
#endif // DRAGON_OPERATORS_ACTIVATION_DROP_PATH_OP_H_
......@@ -15,8 +15,8 @@ void DropoutOp<Context>::DoRunWithType() {
Buffer("mask")->ReshapeLike(X);
kernel::Dropout(
X.count(),
prob(),
1.f / (1.f - prob()),
ratio(),
1.f / (1.f - ratio()),
X.template data<T, Context>(),
Buffer("mask")->template mutable_data<uint8_t, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
......@@ -41,7 +41,7 @@ void DropoutGradientOp<Context>::DoRunWithType() {
} else if (phase() == "TRAIN") {
kernel::ApplyMask(
dY.count(),
1.f / (1.f - prob()),
1.f / (1.f - ratio()),
dY.template data<T, Context>(),
Buffer("mask")->template data<uint8_t, Context>(),
dX->ReshapeLike(dY)->template mutable_data<T, Context>(),
......
......@@ -22,7 +22,7 @@ class DropoutOp : public Operator<Context> {
public:
DropoutOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
INIT_OP_SINGLE_ARG_WITH_DESC(float, prob, 0.5f);
INIT_OP_SINGLE_ARG_WITH_DESC(float, ratio, 0.5f);
}
USE_OPERATOR_FUNCTIONS;
......@@ -32,7 +32,7 @@ class DropoutOp : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, prob);
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, ratio);
};
template <class Context>
......@@ -40,7 +40,7 @@ class DropoutGradientOp : public Operator<Context> {
public:
DropoutGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
INIT_OP_SINGLE_ARG_WITH_DESC(float, prob, 0.5f);
INIT_OP_SINGLE_ARG_WITH_DESC(float, ratio, 0.5f);
}
USE_OPERATOR_FUNCTIONS;
......@@ -50,11 +50,11 @@ class DropoutGradientOp : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, prob);
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, ratio);
};
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropoutOp, prob);
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropoutGradientOp, prob);
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropoutOp, ratio);
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropoutGradientOp, ratio);
#ifdef USE_CUDNN
......
......@@ -28,7 +28,7 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
dropout_desc_,
ctx()->cudnn_handle(),
this->prob(),
this->ratio(),
X_states->template mutable_data<uint8_t, Context>(),
states_size,
rng_seed_));
......@@ -37,7 +37,7 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
CUDNN_CHECK(cudnnSetDropoutDescriptor(
dropout_desc_,
ctx()->cudnn_handle(),
this->prob(),
this->ratio(),
X_states->template mutable_data<uint8_t, Context>(),
states_size,
rng_seed_));
......@@ -86,7 +86,7 @@ void CuDNNDropoutGradientOp<Context>::DoRunWithType() {
CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
dropout_desc_,
ctx()->cudnn_handle(),
this->prob(),
this->ratio(),
X_states->template mutable_data<uint8_t, Context>(),
states_size,
rng_seed_));
......
......@@ -24,14 +24,14 @@ from dragon.core.ops.utils import parse_args
@OpSchema.num_inputs(1)
@ArgHelper.desc('prob', as_target=False)
def dropout(inputs, prob=0.5, scale=True, **kwargs):
@ArgHelper.desc('ratio', as_target=False)
def dropout(inputs, ratio=0.5, **kwargs):
r"""Set the elements of the input to zero randomly.
`[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
The **Dropout** function is defined as:
.. math:: \text{Dropout}(x) = x * (r \sim \mathcal{B}(1, 1 - \text{prob}))
.. math:: \text{Dropout}(x) = x * (r \sim \mathcal{B}(1, 1 - \text{ratio}))
Examples:
......@@ -44,10 +44,8 @@ def dropout(inputs, prob=0.5, scale=True, **kwargs):
----------
inputs : dragon.Tensor
The input tensor.
prob : Union[float, dragon.Tensor], optional, default=0.2
The dropping probability.
scale : bool, optional, default=True
Whether to scale the output during training.
ratio : Union[float, dragon.Tensor], optional, default=0.5
The dropping ratio.
Returns
-------
......@@ -56,56 +54,39 @@ def dropout(inputs, prob=0.5, scale=True, **kwargs):
"""
args = parse_args(locals())
args['prob'] = float(prob)
inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.Dropout
if context.executing_eagerly():
return op_lib \
.instantiate(prob=args['prob'], scale=scale) \
.apply([inputs], inplace=inplace)
.instantiate() \
.apply([inputs], ratio, inplace=inplace)
else:
return op_lib.blend(**args)
@OpSchema.num_inputs(1)
@ArgHelper.desc('keep_prob', as_target=False)
def drop_block2d(
inputs,
block_size=7,
keep_prob=0.9,
alpha=1.,
decrement=0.,
data_format='NCHW',
**kwargs
):
@ArgHelper.desc('ratio', as_target=False)
def drop_block2d(inputs, ratio=0.1, block_size=7, data_format='NCHW', **kwargs):
r"""Set the spatial blocks over input to zero randomly.
`[Ghiasi et.al, 2018] <https://arxiv.org/abs/1810.12890>`_.
The **DropBlock** function is defined as:
.. math::
\text{DropBlock}(x_{ijk} =
x_{ijk} * (r_{ik} \sim \mathcal{B}(1, \alpha\gamma)) \\ \quad \\
\text{DropBlock}(x_{ijk}) =
x_{ijk} * (r_{ik} \sim \mathcal{B}(1, 1 - \gamma)) \\ \quad \\
\text{where}\quad \gamma =
\frac{\text{keep\_prob}}{\text{block\_size}^{n}}
\frac{\text{ratio}}{\text{block\_size}^{n}}
\frac{\text{feat\_size}^{n}}{(\text{feat\_size} - \text{block\_size} + 1)^n}
Set the ``decrement`` to schedule ``keep_prob`` from **1.0**.
Set the ``alpha`` to decrease :math:`\gamma` for different stages.
Parameters
----------
inputs : dragon.Tensor
The input tensor.
ratio : Union[float, dragon.Tensor], optional, default=0.1
The dropping ratio.
block_size : int, optional, default=7
The size of a spatial block.
keep_prob : Union[float, dragon.Tensor], optional, default=0.9
The keeping prob.
alpha : float, optional, default=1.
The value to :math:`\gamma`.
decrement : float, optional, default=0.
The decrement value to ``keep_prob``.
The spatial block size.
data_format : {'NCHW', 'NHWC'}, optional
The optional data format.
......@@ -116,43 +97,34 @@ def drop_block2d(
"""
args = parse_args(locals())
args['alpha'] = float(alpha)
args['decrement'] = float(decrement)
inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.DropBlock2d
if context.executing_eagerly():
return op_lib \
.instantiate(
block_size=block_size,
keep_prob=float(args['keep_prob']),
alpha=args['alpha'],
decrement=args['decrement'],
data_format=data_format,
slot=args.get('slot', None),
).apply([inputs], args.get('inplace', False))
).apply([inputs], ratio, inplace=inplace)
else:
return op_lib.blend(**args)
@OpSchema.num_inputs(1)
@ArgHelper.desc('prob', as_target=False)
def drop_path(inputs, prob=0.2, increment=0., **kwargs):
@ArgHelper.desc('ratio', as_target=False)
def drop_path(inputs, ratio=0.2, **kwargs):
r"""Set the examples over the input to zero randomly.
`[Larsson et.al, 2016] <https://arxiv.org/abs/1605.07648>`_.
The **DropPath** function is defined as:
.. math:: \text{DropPath}(x_{ij}) = x_{ij} * (r_{i} \sim \mathcal{B}(1, 1 - \text{prob}))
Set the ``increment`` to schedule ``prob`` from **0.0** after each run.
.. math:: \text{DropPath}(x_{ij}) = x_{ij} * (r_{i} \sim \mathcal{B}(1, 1 - \text{ratio}))
Parameters
----------
inputs : dragon.Tensor
The input tensor.
prob : Union[float, dragon.Tensor], optional, default=0.2
The dropping prob.
increment : float, optional, default=0.0
The increment ``prob``.
ratio : Union[float, dragon.Tensor], optional, default=0.2
The dropping ratio.
Returns
-------
......@@ -161,23 +133,18 @@ def drop_path(inputs, prob=0.2, increment=0., **kwargs):
"""
args = parse_args(locals())
args['prob'] = float(prob)
args['increment'] = float(increment)
inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.DropPath
if context.executing_eagerly():
return op_lib \
.instantiate(
prob=args['prob'],
increment=args['increment'],
slot=args.get('slot', None),
).apply([inputs], inplace=inplace)
.instantiate() \
.apply([inputs], ratio, inplace=inplace)
else:
return op_lib.blend(**args)
@OpSchema.num_inputs(1)
def elu(inputs, alpha=1., **kwargs):
def elu(inputs, alpha=1.0, **kwargs):
r"""Apply the exponential linear unit.
`[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
......@@ -201,7 +168,7 @@ def elu(inputs, alpha=1., **kwargs):
----------
inputs : dragon.Tensor
The input tensor.
alpha : float, optional, default=1.
alpha : float, optional, default=1.0
The value to :math:`\alpha`.
Returns
......
......@@ -32,63 +32,57 @@ class Activation(Operator):
return self.dispatch(inputs, outputs)
class Dropout(Activation):
class Dropout(Operator):
"""Dropout operator."""
def __init__(self, key, dev, **kwargs):
super(Dropout, self).__init__(key, dev, **kwargs)
self.prob = kwargs.get('prob', 0.5)
self.scale = kwargs.get('scale', True)
def attributes(self):
return {
'op_type': 'Dropout',
'arguments': {
'prob': self.prob,
'scale': self.scale,
}
'arguments': {'ratio_desc': '${HANDLE}/ratio'},
}
def feed(self, ws, handle, ratio):
self.feed_arg(ws, '{}/ratio'.format(handle), ratio, 'float32')
def forward(self, inputs, ratio, inplace=False):
outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs,
callback=lambda ws, handle:
self.feed(ws, handle, ratio))
class DropBlock2d(Activation):
class DropBlock2d(Dropout):
"""DropBlock2d operator."""
def __init__(self, key, dev, **kwargs):
super(DropBlock2d, self).__init__(key, dev, **kwargs)
self.block_size = kwargs.get('block_size', 7)
self.keep_prob = kwargs.get('keep_prob', 0.9)
self.alpha = kwargs.get('alpha', 1.)
self.decrement = kwargs.get('decrement', 0.)
self.data_format = kwargs.get('data_format', 'NCHW')
def attributes(self):
return {
'op_type': 'DropBlock2d',
'arguments': {
'ratio_desc': '${HANDLE}/ratio',
'block_size': self.block_size,
'keep_prob': self.keep_prob,
'alpha': self.alpha,
'decrement': self.decrement,
'data_format': self.data_format,
},
}
class DropPath(Activation):
class DropPath(Dropout):
"""DropPath operator."""
def __init__(self, key, dev, **kwargs):
super(DropPath, self).__init__(key, dev, **kwargs)
self.prob = kwargs.get('prob', 0.2)
self.increment = kwargs.get('increment', 0.)
def attributes(self):
return {
'op_type': 'DropPath',
'arguments': {
'prob': self.prob,
'increment': self.increment,
}
'arguments': {'ratio_desc': '${HANDLE}/ratio'},
}
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_MATH_SORT_H_
#define DRAGON_UTILS_MATH_SORT_H_
#endif // DRAGON_UTILS_MATH_SORT_H_
......@@ -33,7 +33,7 @@ void ApplyMask(
template <typename T, class Context>
void Dropout(
const int count,
const float prob,
const float ratio,
const float scale,
const T* x,
uint8_t* mask,
......
......@@ -133,7 +133,7 @@ class Dropout(Layer):
The **Dropout** function is defined as:
.. math:: \text{Dropout}(x) = x * \text{Bernoulli}(p=1 - prob)
.. math:: \text{Dropout}(x) = x * (r \sim \mathcal{B}(1, 1 - \text{rate}))
Examples:
......@@ -150,7 +150,7 @@ class Dropout(Layer):
Parameters
----------
rate : Union[float, dragon.Tensor]
The dropping probability.
The dropping ratio.
"""
super(Dropout, self).__init__(**kwargs)
......@@ -161,7 +161,7 @@ class Dropout(Layer):
if self.trainable:
return activation_ops.dropout(
inputs,
prob=self.rate,
ratio=self.rate,
inplace=self.inplace,
)
return inputs
......@@ -238,8 +238,7 @@ class Permute(Layer):
if sorted(dims) != list(range(1, len(dims) + 1)):
raise ValueError(
'Argument <dims> should be consecutive and start from 1.\n'
'Got {}'.format(str(dims))
)
'Got {}'.format(str(dims)))
self.input_spec = InputSpec(ndim=len(self.dims) + 1)
def call(self, inputs):
......
......@@ -408,7 +408,7 @@ def dropout(x, rate, name=None, **kwargs):
x : dragon.Tensor
The tensor :math:`x`.
rate : Union[float, dragon.Tensor]
The dropping probability.
The dropping ratio.
name : str, optional
The operation name.
......
......@@ -97,14 +97,14 @@ class TestActivationOps(OpTestCase):
self.cudnn_ws = dragon.Workspace()
def test_dropout(self):
prob = 0.
ratio = 0.
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data = uniform((2, 3))
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.nn.dropout(x, prob=prob)
y = dragon.nn.dropout(x, ratio=ratio)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
self.assertEqual([y, dx], [data, data])
......@@ -121,7 +121,7 @@ class TestActivationOps(OpTestCase):
self.test_dropout()
def test_drop_block2d(self):
keep_prob, block_size = 1., 2
ratio, block_size = 0., 2
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for data_format in ('NCHW', 'NHWC'):
......@@ -132,8 +132,8 @@ class TestActivationOps(OpTestCase):
tape.watch(x)
y = dragon.nn.drop_block2d(
x,
ratio=ratio,
block_size=block_size,
keep_prob=keep_prob,
data_format=data_format)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
self.assertEqual([y, dx], [data, data])
......@@ -145,14 +145,14 @@ class TestActivationOps(OpTestCase):
self.test_drop_block2d()
def test_drop_path(self):
prob = 0.
ratio = 0.
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data = uniform((2, 3))
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.nn.drop_path(x, prob=prob)
y = dragon.nn.drop_path(x, ratio=ratio)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
self.assertEqual([y, dx], [data, data])
......
......@@ -35,6 +35,10 @@ TESTS_AND_SOURCES = [
('torch/test_torch', 'dragon.vm.torch.core'),
]
DISTRIBUTED_BLOCKLIST = [
'dragon/test_distributed',
]
TESTS = [t[0] for t in TESTS_AND_SOURCES]
SOURCES = [t[1] for t in TESTS_AND_SOURCES]
......@@ -66,6 +70,10 @@ def parse_args():
metavar='TESTS',
default=[],
help='select a set of tests to exclude')
parser.add_argument(
'--ignore-distributed-blocklist',
action='store_true',
help='always run blocklisted distributed tests')
return parser.parse_args()
......@@ -95,6 +103,9 @@ def main():
base_command = get_base_command(args)
tests, sources = get_selected_tests(args, TESTS, SOURCES)
for i, test in enumerate(tests):
if (test in DISTRIBUTED_BLOCKLIST and
not args.ignore_distributed_blocklist):
continue
command = base_command[:]
if args.coverage:
if sources[i]:
......
......@@ -292,26 +292,26 @@ class TestModules(OpTestCase):
self.assertEqual(y, result)
def test_dropout(self):
prob = 0.
p = 0.
data = uniform((2, 3))
x = new_tensor(data)
m = torch.nn.Dropout(p=prob, inplace=True)
m = torch.nn.Dropout(p, inplace=True)
y, _ = m(x), repr(m)
self.assertEqual(y, data)
def test_drop_block2d(self):
keep_prob = 1.
p = 0.
data = uniform((2, 3, 4, 4))
x = new_tensor(data)
m = torch.nn.DropBlock2d(kp=keep_prob, block_size=2, inplace=True)
m = torch.nn.DropBlock2d(p, block_size=2, inplace=True)
y, _ = m(x), repr(m)
self.assertEqual(y, data)
def test_drop_path(self):
prob = 0.
p = 0.
data = uniform((2, 3))
x = new_tensor(data)
m = torch.nn.DropPath(p=prob, inplace=True)
m = torch.nn.DropPath(p, inplace=True)
y, _ = m(x), repr(m)
self.assertEqual(y, data)
......
......@@ -427,16 +427,16 @@ def dropout(input, p=0.5, training=True, inplace=False):
The **Dropout** function is defined as:
.. math:: \text{Dropout}(x) = x * \text{Bernoulli}(p=1 - prob)
.. math:: \text{Dropout}(x) = x * (r \sim \mathcal{B}(1, 1 - p))
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
p : float, optional, default=0.5
The dropping prob.
The dropping ratio.
training : bool, optional, default=True
The training flag.
Apply dropping if **True**.
inplace : bool, optional, default=False
Whether to do the operation in-place.
......@@ -453,50 +453,40 @@ def dropout(input, p=0.5, training=True, inplace=False):
if not training:
return input
return _functions.Dropout \
.instantiate(
input.device,
p=p,
).apply(input, inplace=inplace)
.instantiate(input.device) \
.apply(input, p, inplace=inplace)
def drop_block2d(
input,
kp=0.9,
p=0.1,
block_size=7,
alpha=1.,
decrement=0.,
training=True,
inplace=False,
slot=None,
):
r"""Set the spatial blocks over input to zero randomly.
The **DropBlock** function is defined as:
.. math::
\text{DropBlock}(x) = x \cdot \text{Bernoulli}(\alpha\cdot\gamma) \\
\quad \\ \text{where}\quad \gamma =
\frac{keep\_prob}{block\_size^{n}}
\frac{feat\_size^{n}}{(feat\_size - block\_size + 1)^n}
\text{DropBlock}(x_{ijk}) =
x_{ijk} * (r_{ik} \sim \mathcal{B}(1, 1 - \gamma)) \\ \quad \\
\text{where}\quad \gamma =
\frac{p}{\text{block\_size}^{n}}
\frac{\text{feat\_size}^{n}}{(\text{feat\_size} - \text{block\_size} + 1)^n}
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
kp : float, optional, default=0.9
The keeping prob.
p : float, optional, default=0.1
The dropping ratio.
block_size : int, optional, default=7
The size of a spatial block.
alpha : float, optional, default=1.
The scale factor to :math:`\gamma`.
decrement : float, optional, default=0.
The decrement value to ``kp``.
The spatial block size.
training : bool, optional, default=True
The training flag.
Apply dropping if **True**.
inplace : bool, optional, default=False
Whether to do the operation in-place.
slot : int, optional
The optional slot index.
Returns
-------
......@@ -513,28 +503,17 @@ def drop_block2d(
return _functions.DropBlock2d \
.instantiate(
input.device,
keep_prob=kp,
block_size=block_size,
alpha=alpha,
decrement=decrement,
slot=slot,
).apply(input, inplace=inplace)
).apply(input, p, inplace=inplace)
def drop_path(
input,
p=0.2,
increment=0.,
training=True,
inplace=False,
slot=None,
):
def drop_path(input, p=0.2, training=True, inplace=False):
r"""Set the examples over input to zero randomly.
`[Larsson et.al, 2016] <https://arxiv.org/abs/1605.07648>`_.
The **DropPath** function is defined as:
.. math:: \text{DropPath}(x) = x * \text{Bernoulli}(p=1 - prob)
.. math:: \text{DropPath}(x_{ij}) = x_{ij} * (r_{i} \sim \mathcal{B}(1, 1 - p))
Parameters
----------
......@@ -542,14 +521,10 @@ def drop_path(
The input tensor.
p : float, optional, default=0.2
The dropping prob.
increment : float, optional, default=0.
The increment value to ``p``.
training : bool, optional, default=True
The training flag.
Apply dropping if **True**.
inplace : bool, optional, default=False
Whether to do the operation in-place.
slot : int, optional
The optional slot index.
Returns
-------
......@@ -564,12 +539,8 @@ def drop_path(
if not training:
return input
return _functions.DropPath \
.instantiate(
input.device,
p=p,
increment=increment,
slot=slot,
).apply(input, inplace=inplace)
.instantiate(input.device) \
.apply(input, p, inplace=inplace)
def elu(input, alpha=1., inplace=False):
......
......@@ -169,55 +169,56 @@ class DepthwiseConv2d(_ConvNd):
super(DepthwiseConv2d, self).__init__(key, dev, **kwargs)
class DropBlock2d(_Activation):
class Dropout(function.Function):
"""Dropout function."""
def __init__(self, key, dev, **kwargs):
super(Dropout, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'Dropout',
'arguments': {'ratio_desc': '${HANDLE}/ratio'},
}
def feed(self, ws, handle, ratio):
self.feed_arg(ws, '{}/ratio'.format(handle), ratio, 'float32')
def forward(self, input, ratio, inplace=False):
out = input if inplace else self.alloc()
return self.dispatch([input], [out],
callback=lambda ws, handle:
self.feed(ws, handle, ratio))
class DropBlock2d(Dropout):
"""DropBlock2d function."""
def __init__(self, key, dev, **kwargs):
super(DropBlock2d, self).__init__(key, dev, **kwargs)
self.block_size = kwargs.get('block_size', 7)
self.keep_prob = kwargs.get('keep_prob', 0.9)
self.alpha = kwargs.get('alpha', 1.)
self.decrement = kwargs.get('decrement', 0.)
def attributes(self):
return {
'op_type': 'DropBlock2d',
'arguments': {
'ratio_desc': '${HANDLE}/ratio',
'block_size': self.block_size,
'keep_prob': self.keep_prob,
'alpha': self.alpha,
'decrement': self.decrement,
'data_format': 'NCHW',
}
}
class Dropout(_Activation):
"""Dropout function."""
def __init__(self, key, dev, **kwargs):
super(Dropout, self).__init__(key, dev, **kwargs)
self.p = kwargs.get('p', 0.5)
def attributes(self):
return {'op_type': 'Dropout', 'arguments': {'prob': self.p}}
class DropPath(_Activation):
class DropPath(Dropout):
"""DropPath function."""
def __init__(self, key, dev, **kwargs):
super(DropPath, self).__init__(key, dev, **kwargs)
self.p = kwargs.get('p', 0.2)
self.increment = kwargs.get('increment', 0.)
def attributes(self):
return {
'op_type': 'DropPath',
'arguments': {
'prob': self.p,
'increment': self.increment,
}
'arguments': {'ratio_desc': '${HANDLE}/ratio'},
}
......
......@@ -25,10 +25,10 @@ class DropBlock2d(Module):
The **DropBlock** function is defined as:
.. math::
\text{DropBlock}(x_{ijk} =
x_{ijk} * (r_{ik} \sim \mathcal{B}(1, \alpha\gamma)) \\ \quad \\
\text{DropBlock}(x_{ijk}) =
x_{ijk} * (r_{ik} \sim \mathcal{B}(1, 1 - \gamma)) \\ \quad \\
\text{where}\quad \gamma =
\frac{\text{keep\_prob}}{\text{block\_size}^{n}}
\frac{p}{\text{block\_size}^{n}}
\frac{\text{feat\_size}^{n}}{(\text{feat\_size} - \text{block\_size} + 1)^n}
Examples:
......@@ -45,56 +45,35 @@ class DropBlock2d(Module):
"""
# Store the global unique slot index
_DEFAULT_UNIQUE_SLOT_ID = 0
def __init__(
self,
kp=0.9,
block_size=7,
alpha=1.,
decrement=0.,
inplace=False,
):
def __init__(self, p=0.1, block_size=7, inplace=False):
r"""Create a ``DropBlock2d`` module.
Parameters
----------
kp : float, optional, default=0.9
The keeping prob.
p : float, optional, default=0.1
The dropping ratio.
block_size : int, optional, default=7
The size of a spatial block.
alpha : float, optional, default=1.
The scale factor to :math:`\gamma`.
decrement : float, optional, default=0.
The decrement value to ``kp``.
inplace : bool, optional, default=False
Whether to do the operation in-place.
"""
super(DropBlock2d, self).__init__()
self.kp = kp
self.p = p
self.block_size = block_size
self.alpha = alpha
self.decrement = decrement
self.inplace = inplace
DropBlock2d._DEFAULT_UNIQUE_SLOT_ID += 1
self.slot = DropBlock2d._DEFAULT_UNIQUE_SLOT_ID
def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
return 'block_size={}, kp={}{}'.format(self.block_size, self.kp, inplace_str)
return 'p={}, block_size={}{}' \
.format(self.p, self.block_size, inplace_str)
def forward(self, input):
return F.drop_block2d(
input,
kp=self.kp,
input, self.p,
block_size=self.block_size,
alpha=self.alpha,
decrement=self.decrement,
training=self.training,
inplace=self.inplace,
slot=self.slot,
)
......@@ -104,7 +83,7 @@ class Dropout(Module):
The **Dropout** function is defined as:
.. math:: \text{Dropout}(x) = x * (r \sim \mathcal{B}(1, 1 - \text{prob}))
.. math:: \text{Dropout}(x) = x * (r \sim \mathcal{B}(1, 1 - p))
Examples:
......@@ -131,7 +110,7 @@ class Dropout(Module):
Parameters
----------
p : float, optional, default=0.5
The dropping prob.
The dropping ratio.
inplace : bool, optional, default=False
Whether to do the operation in-place.
......@@ -154,7 +133,7 @@ class DropPath(Module):
The **DropPath** function is defined as:
.. math:: \text{DropPath}(x_{ij}) = x_{ij} * (r_{i} \sim \mathcal{B}(1, 1 - \text{prob}))
.. math:: \text{DropPath}(x_{ij}) = x_{ij} * (r_{i} \sim \mathcal{B}(1, 1 - p))
Examples:
......@@ -170,39 +149,24 @@ class DropPath(Module):
"""
# Store the global unique slot index
_DEFAULT_UNIQUE_SLOT_ID = 0
def __init__(self, p=0.2, increment=0., inplace=False):
def __init__(self, p=0.2, inplace=False):
"""Create a ``DropPath`` module.
Parameters
----------
p : float, optional, default=0.2
The dropping prob.
increment : float, optional, default=0.
The increment value to ``p``.
The dropping ratio.
inplace : bool, optional, default=False
Whether to do the operation in-place.
"""
super(DropPath, self).__init__()
self.p = p
self.increment = increment
self.inplace = inplace
DropPath._DEFAULT_UNIQUE_SLOT_ID += 1
self.slot = DropPath._DEFAULT_UNIQUE_SLOT_ID
def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
return 'p={}{}'.format(self.p, inplace_str)
def forward(self, input):
return F.drop_path(
input,
p=self.p,
increment=self.increment,
training=self.training,
inplace=self.inplace,
slot=self.slot,
)
return F.drop_path(input, self.p, self.training, self.inplace)
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!