Commit ac051717 by Ting PAN

Fix cuBLAS FP32 downcast issue on Ampere devices

Summary:
This commit removes the default cuBLAS tensor core math mode when CUDA >= 11.0,
so FP32 GEMMs are no longer silently downcast to TF32 on Ampere devices unless TF32 math is explicitly allowed.
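The relevant hunk is in the CUDAObjects handle setup below. As a minimal sketch (not the project's exact code), the per-handle configuration now amounts to choosing the math mode from an explicit TF32 flag instead of always enabling tensor-op math; the `allow_tf32` parameter here stands in for the internal `cudnn_allow_tf32_` member, and the device-capability check of the old path is omitted:

```cpp
#include <cuda.h>       // CUDA_VERSION
#include <cublas_v2.h>  // cublasSetMathMode, math-mode enums

// Sketch: configure one cuBLAS handle so FP32 GEMMs are not silently
// downcast to TF32 on Ampere unless the caller opts in.
void ConfigureMathMode(cublasHandle_t handle, bool allow_tf32) {
#if CUDA_VERSION >= 11000
  // Use TF32 tensor cores only when explicitly allowed; otherwise keep
  // full-precision FP32 math.
  cublasSetMathMode(
      handle, allow_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH);
#elif CUDA_VERSION >= 9000
  // Pre-CUDA-11 toolkits: tensor-op math only affects FP16 GEMMs, so
  // enabling it does not change FP32 results.
  cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
#endif
}
```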
1 parent b7e2298f
Showing with 1326 additions and 686 deletions
......@@ -19,6 +19,7 @@ from dragon.vm.dali.core.ops.builtin_ops import ExternalSource
from dragon.vm.dali.core.ops.decoder_ops import ImageDecoder
from dragon.vm.dali.core.ops.decoder_ops import ImageDecoderRandomCrop
from dragon.vm.dali.core.ops.generic_ops import Cast
from dragon.vm.dali.core.ops.generic_ops import Erase
from dragon.vm.dali.core.ops.generic_ops import Pad
from dragon.vm.dali.core.ops.generic_ops import Reshape
from dragon.vm.dali.core.ops.generic_ops import Slice
......@@ -31,6 +32,8 @@ from dragon.vm.dali.core.ops.image_ops import Paste
from dragon.vm.dali.core.ops.image_ops import RandomBBoxCrop
from dragon.vm.dali.core.ops.image_ops import RandomResizedCrop
from dragon.vm.dali.core.ops.image_ops import Resize
from dragon.vm.dali.core.ops.image_ops import Rotate
from dragon.vm.dali.core.ops.image_ops import WarpAffine
from dragon.vm.dali.core.ops.random_ops import CoinFlip
from dragon.vm.dali.core.ops.random_ops import Uniform
from dragon.vm.dali.core.ops.reader_ops import KPLRecordReader
......
......@@ -48,7 +48,7 @@ class Iterator(object):
with self._api_scope():
self._pipe.build()
# Enforce the correct device of current process
# to initialize cuda handles instead of device:0
# to initialize cuda handles instead of device 0.
cuda.set_device(self._pipe.device_id)
self._pipe.schedule_run()
self._copies = None
......@@ -91,12 +91,9 @@ class Iterator(object):
shape=tensor.shape(),
dtype=str(types.np_dtype(tensor.dtype())),
device=self.new_device(
device_type='cuda' if isinstance(
tensor, TensorGPU) else 'cpu',
device_index=self._pipe.device_id,
)
)
)
device_type=('cuda' if isinstance(tensor, TensorGPU)
else 'cpu'),
device_index=self._pipe.device_id)))
# Transfer the data: DALI => Storage
for i, tensor in enumerate(tensors):
self._transfer_tensor(tensor, self._copies[i])
......@@ -160,7 +157,7 @@ class Iterator(object):
def __next__(self):
"""Return the next batch of data."""
# Return and reset the first batch if necessary
# Return and reset the first batch if necessary.
if self._first_batch is not None:
outputs = self._first_batch
self._first_batch = None
......
......@@ -34,6 +34,8 @@ try:
num_threads=1,
seed=3,
prefetch_queue_depth=2,
py_num_workers=1,
**kwargs
):
"""Create a ``Pipeline``.
......@@ -47,6 +49,8 @@ try:
The seed for random generator.
prefetch_queue_depth : int, optional, default=2
The number of prefetch queues.
py_num_workers : int, optional, default=1
The number of workers to process the external source.
"""
device_id = context.get_device()['device_index']
......@@ -56,6 +60,7 @@ try:
device_id=device_id,
seed=seed,
prefetch_queue_depth=prefetch_queue_depth,
**kwargs
)
@property
......@@ -68,7 +73,7 @@ try:
The batch size.
"""
return self._batch_size
return self._max_batch_size
@property
def device_id(self):
......@@ -83,6 +88,18 @@ try:
return self._device_id
@property
def max_batch_size(self):
"""Return the maximum batch size of pipeline.
Returns
-------
int
The maximum batch size.
"""
return self._max_batch_size
@property
def num_threads(self):
"""Return the number of threads to execute pipeline.
......@@ -94,26 +111,24 @@ try:
"""
return self._num_threads
def build(self):
"""Build the pipeline."""
super(Pipeline, self).build()
def define_graph(self):
"""Define the symbolic operations for pipeline."""
super(Pipeline, self).define_graph()
def feed_input(self, ref, data):
"""Bind an array to the edge reference.
def build(self, define_graph=None):
"""Build the pipeline.
Parameters
----------
ref : _EdgeReference
The reference of an edge.
data : numpy.ndarray
The array data.
define_graph : callable, optional
The defined function to use instead.
"""
super(Pipeline, self).feed_input(ref, data)
super(Pipeline, self).build(define_graph)
def define_graph(self):
"""Define the symbolic operations for pipeline."""
super(Pipeline, self).define_graph()
def feed_input(self, *args, **kwargs):
"""Bind an array to the edge reference."""
super(Pipeline, self).feed_input(*args, **kwargs)
except ImportError:
......@@ -134,6 +149,8 @@ except ImportError:
num_threads=1,
seed=3,
prefetch_queue_depth=2,
py_num_workers=1,
**kwargs
):
"""Create a ``Pipeline``
......@@ -147,9 +164,11 @@ except ImportError:
The seed for random generator.
prefetch_queue_depth : int, optional, default=2
The number of prefetch queues.
py_num_workers : int, optional, default=1
The number of workers to process the external source.
"""
self._batch_size = batch_size
self._max_batch_size = batch_size
self._num_threads = num_threads
self._seed = seed
self._prefetch_queue_depth = prefetch_queue_depth
......@@ -164,7 +183,7 @@ except ImportError:
The batch size.
"""
return self._batch_size
return self._max_batch_size
@property
def device_id(self):
......@@ -179,6 +198,18 @@ except ImportError:
return 0
@property
def max_batch_size(self):
"""Return the maximum batch size of pipeline.
Returns
-------
int
The maximum batch size.
"""
return self._max_batch_size
@property
def num_threads(self):
"""Return the number of threads to execute pipeline.
......@@ -190,23 +221,21 @@ except ImportError:
"""
return self._num_threads
def build(self):
"""Build the pipeline."""
def build(self, define_graph=None):
"""Build the pipeline.
Parameters
----------
define_graph : callable, optional
The defined function to use instead.
"""
pass
def define_graph(self):
"""Define the symbolic operations for pipeline."""
pass
def feed_input(self, ref, data):
"""Bind an array to the edge reference.
Parameters
----------
ref : _EdgeReference
The reference of an edge.
data : numpy.ndarray
The array data.
"""
def feed_input(self, *args, **kwargs):
"""Bind an array to the edge reference."""
pass
......@@ -60,6 +60,60 @@ class Cast(object):
)
class Erase(object):
"""Erase regions from the input.
Examples:
```python
erase = dali.ops.Erase(
# The axes to erase
axes=[0, 1],
# The fill value
fill_value=0.,
)
y = erase(inputs['x'], anchor=(0, 0), shape=(100, 100))
```
"""
def __new__(
cls,
axes=(0, 1),
fill_value=0,
normalized_anchor=True,
normalized_shape=True,
**kwargs
):
"""Create an ``Erase`` operator.
Parameters
----------
axes : Sequence[int], optional
The axes to erase along.
fill_value : Union[number, Sequence[float]], optional
The value to fill the erased regions.
normalized_anchor : bool, optional, default=True
Whether the provided anchor is normalized.
normalized_shape : bool, optional, default=True
Whether the provided shape is normalized.
Returns
-------
nvidia.dali.ops.Erase
The operator.
"""
return ops.Erase(
axes=axes,
fill_value=fill_value,
normalized_anchor=normalized_anchor,
normalized_shape=normalized_shape,
device=context.get_device_type(),
**kwargs
)
class Pad(object):
"""Pad input to have the same dimensions.
......@@ -77,14 +131,14 @@ class Pad(object):
"""
def __new__(cls, axes=(0, 1), fill_value=0., align=None, **kwargs):
def __new__(cls, axes=(0, 1), fill_value=0, align=None, **kwargs):
"""Create a ``Pad`` operator.
Parameters
----------
axes : Sequence[int], optional
The padding axes.
fill_value : number, optional, default=0.
fill_value : number, optional, default=0
The constant padding value.
align : Union[int, Sequence[int]], optional
The size to align the padding shape.
......
......@@ -487,3 +487,88 @@ class Resize(object):
device=context.get_device_type(),
**kwargs
)
class Rotate(object):
"""Rotate the image.
Examples:
```python
rotate = dali.ops.Rotate()
y = rotate(inputs['x'], angle=30)
```
"""
def __new__(
cls,
fill_value=0,
interp_type='linear',
keep_size=True,
**kwargs
):
"""Create a ``Rotate`` operator.
Parameters
----------
fill_value : number, optional
The value to fill the empty regions.
interp_type : str, optional, default='linear'
The interpolation method.
keep_size : bool, optional, default=True
Whether to keep the original image size.
Returns
-------
nvidia.dali.ops.Rotate
The operator.
"""
if isinstance(interp_type, six.string_types):
interp_type = getattr(types, 'INTERP_' + interp_type.upper())
return ops.Rotate(
fill_value=fill_value,
interp_type=interp_type,
keep_size=keep_size,
device=context.get_device_type(),
**kwargs
)
class WarpAffine(object):
"""Apply an affine transformation to the image.
Examples:
```python
warp_affine = dali.ops.WarpAffine()
y = warp_affine(inputs['x'], matrix=[1, 0, 0, 0, 1, 0])
```
"""
def __new__(cls, fill_value=0, interp_type='linear', **kwargs):
"""Create a ``WarpAffine`` operator.
Parameters
----------
fill_value : number, optional
The value to fill the empty regions.
interp_type : str, optional, default='linear'
The interpolation method.
Returns
-------
nvidia.dali.ops.WarpAffine
The operator.
"""
if isinstance(interp_type, six.string_types):
interp_type = getattr(types, 'INTERP_' + interp_type.upper())
return ops.WarpAffine(
fill_value=fill_value,
interp_type=interp_type,
device=context.get_device_type(),
**kwargs
)
......@@ -10,14 +10,14 @@ __init__
Properties
----------
batch_size
##########
.. autoattribute:: dragon.vm.dali.Pipeline.batch_size
device_id
#########
.. autoattribute:: dragon.vm.dali.Pipeline.device_id
max_batch_size
##############
.. autoattribute:: dragon.vm.dali.Pipeline.max_batch_size
num_threads
###########
.. autoattribute:: dragon.vm.dali.Pipeline.num_threads
......
......@@ -30,6 +30,9 @@ vm.dali.ops
`class CropMirrorNormalize <ops/CropMirrorNormalize.html>`_
: Crop and normalize image with the horizontal flip.
`class Erase <ops/Erase.html>`_
: Erase regions from the input.
`class ExternalSource <ops/Cast.html>`_
: Create a placeholder providing data from feeding.
......@@ -60,6 +63,9 @@ vm.dali.ops
`class Resize <ops/Resize.html>`_
: Resize the image.
`class Rotate <ops/Rotate.html>`_
: Rotate the image.
`class Slice <ops/Slice.html>`_
: Select an interval of elements from input.
......@@ -72,6 +78,9 @@ vm.dali.ops
`class Uniform <ops/Uniform.html>`_
: Sample values from a uniform distribution.
`class WarpAffine <ops/WarpAffine.html>`_
: Apply an affine transformation to the image.
.. toctree::
:hidden:
......@@ -83,6 +92,7 @@ vm.dali.ops
ops/CoinFlip
ops/Contrast
ops/CropMirrorNormalize
ops/Erase
ops/ExternalSource
ops/Hsv
ops/ImageDecoder
......@@ -93,10 +103,12 @@ vm.dali.ops
ops/RandomResizedCrop
ops/Reshape
ops/Resize
ops/Rotate
ops/Slice
ops/KPLRecordReader
ops/TFRecordReader
ops/Uniform
ops/WarpAffine
.. raw:: html
......
Erase
=====
.. autoclass:: dragon.vm.dali.ops.Erase
__new__
--------
.. automethod:: dragon.vm.dali.ops.Erase.__new__
.. raw:: html
<style>
h1:before {
content: "dali.ops.";
color: #103d3e;
}
</style>
Rotate
======
.. autoclass:: dragon.vm.dali.ops.Rotate
__new__
--------
.. automethod:: dragon.vm.dali.ops.Rotate.__new__
.. raw:: html
<style>
h1:before {
content: "dali.ops.";
color: #103d3e;
}
</style>
WarpAffine
==========
.. autoclass:: dragon.vm.dali.ops.WarpAffine
__new__
--------
.. automethod:: dragon.vm.dali.ops.WarpAffine.__new__
.. raw:: html
<style>
h1:before {
content: "dali.ops.";
color: #103d3e;
}
</style>
......@@ -65,6 +65,10 @@ glorot_uniform
##############
.. automethod:: dragon.Tensor.glorot_uniform
item
####
.. automethod:: dragon.Tensor.item
normal
######
.. automethod:: dragon.Tensor.normal
......@@ -77,6 +81,10 @@ reshape
#######
.. automethod:: dragon.Tensor.reshape
tolist
######
.. automethod:: dragon.Tensor.tolist
truncated_normal
################
.. automethod:: dragon.Tensor.truncated_normal
......
......@@ -305,8 +305,12 @@ is_floating_point
#################
.. automethod:: dragon.vm.torch.Tensor.is_floating_point
item
####
.. automethod:: dragon.vm.torch.Tensor.item
le
###
##
.. automethod:: dragon.vm.torch.Tensor.le
log
......@@ -577,6 +581,10 @@ to
##
.. automethod:: dragon.vm.torch.Tensor.to
tolist
######
.. automethod:: dragon.vm.torch.Tensor.tolist
topk
####
.. automethod:: dragon.vm.torch.Tensor.topk
......
......@@ -56,7 +56,13 @@ class CUDAObjects {
auto& handle = handles[stream_id];
CUBLAS_CHECK(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasSetStream(handle, stream(device_id, stream_id)));
#if CUDA_VERSION >= 9000
#if CUDA_VERSION >= 11000
if (cudnn_allow_tf32_) {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
} else {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#elif CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
}
......
......@@ -97,11 +97,7 @@ __global__ void _MaskBlock2dNHWC(
math::Set(NxCxHxW, uint8_t(1), mask, ctx); \
math::Random(num_seeds, r, ctx); \
if (data_format == "NCHW") { \
_MaskBlock2dNCHW<<< \
CUDA_2D_BLOCKS(num_seeds), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
_MaskBlock2dNCHW<<<num_seeds, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
C, \
H, \
W, \
......@@ -113,11 +109,7 @@ __global__ void _MaskBlock2dNHWC(
r, \
mask); \
} else if (data_format == "NHWC") { \
_MaskBlock2dNHWC<<< \
CUDA_2D_BLOCKS(num_seeds), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
_MaskBlock2dNHWC<<<num_seeds, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
C, \
H, \
W, \
......
......@@ -211,7 +211,7 @@ __global__ void _PReluWGrad(
_PReluWGrad, \
math::ScalarType<T>::type, \
math::AccmulatorType<T>::type, \
CUDA_2D_BLOCKS(C), \
C, \
CUDA_THREADS, \
NxS, \
S, \
......
......@@ -30,14 +30,26 @@ void _Softmax(const int N, const int S, const int C, const T* x, T* y) {
}
}
template <>
void _Softmax<float16>(
const int N,
const int S,
const int C,
const float16* x,
float16* y) {
CPU_FP16_NOT_SUPPORTED;
template <typename T>
void _LogSoftmax(const int N, const int S, const int C, const T* x, T* y) {
if (S == 1) {
ConstEigenArrayMap<T> X(x, C, N);
EigenArrayMap<T> Y(y, C, N);
Y = X.rowwise() - X.colwise().maxCoeff();
Y = Y.rowwise() - Y.exp().colwise().sum().log();
return;
}
for (int i = 0; i < N; ++i) {
const auto offset = i * C * S;
for (int j = 0; j < S; ++j) {
ConstEigenStridedVectorArrayMap<T> X_vec(
x + offset + j, 1, C, EigenInnerStride(S));
EigenStridedVectorArrayMap<T> Y_vec(
y + offset + j, 1, C, EigenInnerStride(S));
Y_vec = X_vec - X_vec.maxCoeff();
Y_vec -= std::log(Y_vec.exp().sum());
}
}
}
template <typename T>
......@@ -69,36 +81,90 @@ void _SoftmaxGrad(
}
}
template <>
void _SoftmaxGrad<float16>(
template <typename T>
void _LogSoftmaxGrad(
const int N,
const int S,
const int C,
const float16* dy,
const float16* y,
float16* dx) {
CPU_FP16_NOT_SUPPORTED;
} // SoftmaxGrad
const T* dy,
const T* y,
T* dx) {
if (S == 1) {
ConstEigenArrayMap<T> dY(dy, C, N);
ConstEigenArrayMap<T> Y(y, C, N);
EigenArrayMap<T> dX(dx, C, N);
dX = dY - Y.exp().rowwise() * dY.colwise().sum();
return;
}
for (int i = 0; i < N; ++i) {
const auto offset = i * C * S;
for (int j = 0; j < S; ++j) {
ConstEigenStridedVectorArrayMap<T> dY_vec(
dy + offset + j, 1, C, EigenInnerStride(S));
ConstEigenStridedVectorArrayMap<T> Y_vec(
y + offset + j, 1, C, EigenInnerStride(S));
EigenStridedVectorArrayMap<T> dX_vec(
dx + offset + j, 1, C, EigenInnerStride(S));
dX_vec = dY_vec - Y_vec.exp() * dY_vec.sum();
}
}
}
} // namespace
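For reference only (not part of the patch), the CPU kernels above and the CUDA kernels further down implement the usual numerically stable forms. With m = max_k x_k taken over each slice of length C, the forward and backward passes being coded are:

```latex
\text{Softmax:}\quad    y_j = \frac{e^{x_j - m}}{\sum_k e^{x_k - m}},
\qquad
\text{LogSoftmax:}\quad y_j = (x_j - m) - \log\sum_k e^{x_k - m}
\\[6pt]
dx_j^{\text{Softmax}}    = y_j\Big(dy_j - \sum_k dy_k\, y_k\Big),
\qquad
dx_j^{\text{LogSoftmax}} = dy_j - e^{y_j}\sum_k dy_k
```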
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_##name(N, S, C, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const T* dy, \
const T* y, \
T* dx, \
CPUContext* ctx) { \
_##name(N, S, C, dy, y, dx); \
}
DEFINE_KERNEL_LAUNCHER(Softmax, float);
DEFINE_KERNEL_LAUNCHER(Softmax, double);
DEFINE_KERNEL_LAUNCHER(LogSoftmax, float);
DEFINE_KERNEL_LAUNCHER(LogSoftmax, double);
DEFINE_GRAD_KERNEL_LAUNCHER(SoftmaxGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(SoftmaxGrad, double);
DEFINE_GRAD_KERNEL_LAUNCHER(LogSoftmaxGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(LogSoftmaxGrad, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void Softmax<T, CPUContext>( \
void name<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_Softmax(N, S, C, x, y); \
CPU_FP16_NOT_SUPPORTED; \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \
void SoftmaxGrad<T, CPUContext>( \
void name<T, CPUContext>( \
const int N, \
const int S, \
const int C, \
......@@ -106,15 +172,13 @@ void _SoftmaxGrad<float16>(
const T* y, \
T* dx, \
CPUContext* ctx) { \
_SoftmaxGrad(N, S, C, dy, y, dx); \
CPU_FP16_NOT_SUPPORTED; \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
DEFINE_KERNEL_LAUNCHER(Softmax, float16);
DEFINE_KERNEL_LAUNCHER(LogSoftmax, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(SoftmaxGrad, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(LogSoftmaxGrad, float16);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
......
......@@ -11,12 +11,12 @@ namespace kernels {
namespace {
#define LDG2(x, i) convert::To<AccT>(__ldg(x + i))
#define LDG(x, i) convert::To<AccT>(__ldg(x + i))
template <typename T, typename AccT>
__global__ void
_Softmax(const int NxS, const int S, const int C, const T* x, T* y) {
__shared__ AccT block_val;
__shared__ AccT block_max, block_sum;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, NxS) {
const int offset = (i / S) * C * S + (i % S);
......@@ -25,28 +25,58 @@ _Softmax(const int NxS, const int S, const int C, const T* x, T* y) {
AccT val = convert::To<AccT>(__ldg(offset_x));
CUDA_2D_KERNEL_LOOP2(j, C) {
val = max(val, LDG2(offset_x, j * S));
val = max(val, LDG(offset_x, j * S));
}
val = BlockReduce<AccT>(storage).Reduce(val, cub::Max());
if (threadIdx.x == 0) block_val = val;
if (threadIdx.x == 0) block_max = val;
__syncthreads();
val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, C) {
val += exp(LDG(offset_x, j * S) - block_max);
}
val = BlockReduce<AccT>(storage).Sum(val);
if (threadIdx.x == 0) block_sum = val;
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, C) {
const int k = j * S;
offset_y[k] = convert::To<T>(exp(LDG2(offset_x, k) - block_val));
val = exp(LDG(offset_x, k) - block_max);
offset_y[k] = convert::To<T>(val / block_sum);
}
}
}
template <typename T, typename AccT>
__global__ void
_LogSoftmax(const int NxS, const int S, const int C, const T* x, T* y) {
__shared__ AccT block_max, block_sum;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, NxS) {
const int offset = (i / S) * C * S + (i % S);
auto* offset_x = x + offset;
auto* offset_y = y + offset;
AccT val = convert::To<AccT>(__ldg(offset_x));
CUDA_2D_KERNEL_LOOP2(j, C) {
val = max(val, LDG(offset_x, j * S));
}
val = BlockReduce<AccT>(storage).Reduce(val, cub::Max());
if (threadIdx.x == 0) block_max = val;
__syncthreads();
val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, C) {
val += convert::To<AccT>(offset_y[j * S]);
val += exp(LDG(offset_x, j * S) - block_max);
}
val = BlockReduce<AccT>(storage).Sum(val);
if (threadIdx.x == 0) block_val = val;
if (threadIdx.x == 0) block_sum = val;
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, C) {
const int k = j * S;
offset_y[k] = convert::To<T>(convert::To<AccT>(offset_y[k]) / block_val);
val = LDG(offset_x, k) - block_max;
offset_y[k] = convert::To<T>(val - log(block_sum));
}
}
}
......@@ -70,7 +100,39 @@ __global__ void _SoftmaxGrad(
AccT val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, C) {
const int k = j * S;
val += LDG2(offset_dy, k) * LDG2(offset_y, k);
val += LDG(offset_dy, k) * LDG(offset_y, k);
}
val = BlockReduce<AccT>(storage).Sum(val);
if (threadIdx.x == 0) block_val = val;
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, C) {
const int k = j * S;
val = LDG(offset_dy, k) - block_val;
offset_dx[k] = convert::To<T>(val * LDG(offset_y, k));
}
}
}
template <typename T, typename AccT>
__global__ void _LogSoftmaxGrad(
const int NxS,
const int S,
const int C,
const T* dy,
const T* y,
T* dx) {
__shared__ AccT block_val;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, NxS) {
const int offset = (i / S) * C * S + (i % S);
auto* offset_dy = dy + offset;
auto* offset_y = y + offset;
auto* offset_dx = dx + offset;
AccT val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, C) {
val += LDG(offset_dy, j * S);
}
val = BlockReduce<AccT>(storage).Sum(val);
if (threadIdx.x == 0) block_val = val;
......@@ -78,21 +140,21 @@ __global__ void _SoftmaxGrad(
CUDA_2D_KERNEL_LOOP2(j, C) {
const int k = j * S;
offset_dx[k] =
convert::To<T>((LDG2(offset_dy, k) - block_val) * LDG2(offset_y, k));
val = exp(convert::To<AccT>(offset_y[k])) * block_val;
offset_dx[k] = convert::To<T>(LDG(offset_dy, k) - val);
}
}
}
#undef LDG2
#undef LDG
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void Softmax<T, CUDAContext>( \
void name<T, CUDAContext>( \
const int N, \
const int S, \
const int C, \
......@@ -100,8 +162,8 @@ __global__ void _SoftmaxGrad(
T* y, \
CUDAContext* ctx) { \
const auto NxS = N * S; \
_Softmax<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<CUDA_2D_BLOCKS(NxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
_##name<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<NxS, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxS, \
S, \
C, \
......@@ -109,9 +171,9 @@ __global__ void _SoftmaxGrad(
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \
void SoftmaxGrad<T, CUDAContext>( \
void name<T, CUDAContext>( \
const int N, \
const int S, \
const int C, \
......@@ -120,8 +182,8 @@ __global__ void _SoftmaxGrad(
T* dx, \
CUDAContext* ctx) { \
const auto NxS = N * S; \
_SoftmaxGrad<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<CUDA_2D_BLOCKS(NxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
_##name<math::ScalarType<T>::type, math::AccmulatorType<T>::type> \
<<<NxS, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxS, \
S, \
C, \
......@@ -130,12 +192,18 @@ __global__ void _SoftmaxGrad(
reinterpret_cast<math::ScalarType<T>::type*>(dx)); \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
DEFINE_GRAD_KERNEL_LAUNCHER(float16);
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
DEFINE_KERNEL_LAUNCHER(Softmax, float16);
DEFINE_KERNEL_LAUNCHER(Softmax, float);
DEFINE_KERNEL_LAUNCHER(Softmax, double);
DEFINE_KERNEL_LAUNCHER(LogSoftmax, float16);
DEFINE_KERNEL_LAUNCHER(LogSoftmax, float);
DEFINE_KERNEL_LAUNCHER(LogSoftmax, double);
DEFINE_GRAD_KERNEL_LAUNCHER(SoftmaxGrad, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(SoftmaxGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(SoftmaxGrad, double);
DEFINE_GRAD_KERNEL_LAUNCHER(LogSoftmaxGrad, float16);
DEFINE_GRAD_KERNEL_LAUNCHER(LogSoftmaxGrad, float);
DEFINE_GRAD_KERNEL_LAUNCHER(LogSoftmaxGrad, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DEFINE_GRAD_KERNEL_LAUNCHER
......
......@@ -64,7 +64,7 @@ __global__ void _ArgReduce(
CUDAContext* ctx) { \
using ScalarT = math::ScalarType<T>::type; \
const auto NxS = N * S; \
_ArgReduce<<<CUDA_2D_BLOCKS(NxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
_ArgReduce<<<NxS, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxS, \
S, \
C, \
......
......@@ -136,7 +136,7 @@ __global__ void _GetTopK(
#define DISPATCH_BLOCKSORT_KERNEL(T, kItemsPerThread) \
_BlockSort<T, kItemsPerThread> \
<<<CUDA_2D_BLOCKS(NxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
<<<NxS, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxS, \
S, \
C, \
......
......@@ -17,9 +17,8 @@ void _Transpose(
const auto N =
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t index(num_dims, 0);
int64_t xi;
for (int yi = 0; yi < N; ++yi) {
xi = 0;
int64_t xi = 0;
for (int d = num_dims - 1; d >= 0; --d) {
xi += index[d] * x_strides[d];
}
......@@ -28,27 +27,6 @@ void _Transpose(
}
}
template <typename T>
void _TransposeGrad(
const int num_dims,
const int64_t* x_strides,
const int64_t* y_dims,
const T* dy,
T* dx) {
const auto N =
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t index(num_dims, 0);
int64_t xi;
for (int yi = 0; yi < N; ++yi) {
xi = 0;
for (int d = num_dims - 1; d >= 0; --d) {
xi += index[d] * x_strides[d];
}
dx[xi] = dy[yi];
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
......@@ -73,9 +51,6 @@ DEFINE_KERNEL_LAUNCHER(Transpose, int64_t);
DEFINE_KERNEL_LAUNCHER(Transpose, float16);
DEFINE_KERNEL_LAUNCHER(Transpose, float);
DEFINE_KERNEL_LAUNCHER(Transpose, double);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, float16);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, float);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
......
......@@ -13,14 +13,14 @@ namespace {
template <typename T, int D>
__global__ void _Transpose(
const int N,
const int num_dims,
const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_dims,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, N) {
int xi = 0, tmp = yi;
for (int d = num_dims - 1; d >= 0; --d) {
#pragma unroll
for (int d = D - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], tmp, &tmp, &r);
xi += r * X_strides.data[d];
......@@ -30,31 +30,29 @@ __global__ void _Transpose(
}
template <typename T, int D>
__global__ void _TransposeGrad(
void _TransposeImpl(
const int N,
const int num_dims,
const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_dims,
const T* dy,
T* dx) {
CUDA_1D_KERNEL_LOOP(yi, N) {
int xi = 0, tmp = yi;
for (int d = num_dims - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], tmp, &tmp, &r);
xi += r * X_strides.data[d];
}
dx[xi] = dy[yi];
const int64_t* x_strides,
const int64_t* y_dims,
const T* x,
T* y,
CUDAContext* ctx) {
SimpleArray<int, D> X_strides, Y_dims;
for (int i = 0; i < D; ++i) {
X_strides.data[i] = x_strides[i];
Y_dims.data[i] = y_dims[i];
}
_Transpose<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
N, X_strides, Y_dims, x, y);
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, T) \
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void name<T, CUDAContext>( \
void Transpose<T, CUDAContext>( \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
......@@ -62,28 +60,46 @@ __global__ void _TransposeGrad(
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides, Y_dims; \
const auto N = std::accumulate( \
y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \
X_strides.data[i] = x_strides[i]; \
Y_dims.data[i] = y_dims[i]; \
switch (num_dims) { \
case 1: \
_TransposeImpl<T, 1>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 2: \
_TransposeImpl<T, 2>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 3: \
_TransposeImpl<T, 3>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 4: \
_TransposeImpl<T, 4>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 5: \
_TransposeImpl<T, 5>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 6: \
_TransposeImpl<T, 6>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 7: \
_TransposeImpl<T, 7>(N, x_strides, y_dims, x, y, ctx); \
break; \
case 8: \
_TransposeImpl<T, 8>(N, x_strides, y_dims, x, y, ctx); \
break; \
default: \
break; \
} \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, num_dims, X_strides, Y_dims, x, y); \
}
DEFINE_KERNEL_LAUNCHER(Transpose, bool);
DEFINE_KERNEL_LAUNCHER(Transpose, uint8_t);
DEFINE_KERNEL_LAUNCHER(Transpose, int8_t);
DEFINE_KERNEL_LAUNCHER(Transpose, int);
DEFINE_KERNEL_LAUNCHER(Transpose, int64_t);
DEFINE_KERNEL_LAUNCHER(Transpose, float16);
DEFINE_KERNEL_LAUNCHER(Transpose, float);
DEFINE_KERNEL_LAUNCHER(Transpose, double);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, float16);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, float);
DEFINE_KERNEL_LAUNCHER(TransposeGrad, double);
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
......
#include "dragon/utils/device/common_openmp.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -16,18 +15,14 @@ void _RowwiseMoments(
AccT* mean,
AccT* var) {
const AccT scale = AccT(1) / AccT(rows);
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(cols))
#endif
for (int i = 0; i < cols; ++i) {
AccT x_val, m_val = AccT(0), v_val = AccT(0);
AccT m_val = AccT(0), v_val = AccT(0);
for (int j = 0; j < rows; ++j) {
x_val = convert::To<AccT>(x[j * cols + i]);
m_val += x_val;
v_val += x_val * x_val;
const AccT val = convert::To<AccT>(x[j * cols + i]);
m_val += val;
v_val += val * val;
}
m_val *= scale;
mean[i] = m_val;
mean[i] = m_val = m_val * scale;
var[i] = v_val * scale - m_val * m_val;
}
}
......@@ -40,18 +35,15 @@ void _ColwiseMoments(
AccT* mean,
AccT* var) {
const AccT scale = AccT(1) / AccT(cols);
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(rows))
#endif
for (int i = 0; i < rows; ++i) {
AccT x_val, m_val = AccT(0), v_val = AccT(0);
const int offset = i * cols;
AccT m_val = AccT(0), v_val = AccT(0);
for (int j = 0; j < cols; ++j) {
x_val = convert::To<AccT>(x[i * cols + j]);
m_val += x_val;
v_val += x_val * x_val;
const AccT val = convert::To<AccT>(x[offset + j]);
m_val += val;
v_val += val * val;
}
m_val *= scale;
mean[i] = m_val;
mean[i] = m_val = m_val * scale;
var[i] = v_val * scale - m_val * m_val;
}
}
......@@ -67,25 +59,20 @@ void _GenericMoments(
AccT* mean,
AccT* var) {
const AccT scale = AccT(1) / AccT(cols);
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(rows))
#endif
for (int i = 0; i < rows; ++i) {
AccT x_val, m_val = AccT(0), v_val = AccT(0);
int xi, c, r;
const int offset = i * cols;
AccT m_val = AccT(0), v_val = AccT(0);
for (int j = 0; j < cols; ++j) {
xi = 0;
c = i * cols + j;
int xi = 0, c = offset + j, r;
for (int d = num_dims - 1; d >= 0; --d) {
FIXED_DIVISOR_DIV_MOD(x_dims[d], c, &c, &r);
xi += r * x_strides[d];
}
x_val = convert::To<AccT>(x[xi]);
m_val += x_val;
v_val += x_val * x_val;
const AccT val = convert::To<AccT>(x[xi]);
m_val += val;
v_val += val * val;
}
m_val *= scale;
mean[i] = m_val;
mean[i] = m_val = m_val * scale;
var[i] = v_val * scale - m_val * m_val;
}
}
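All three moments kernels (rowwise, colwise, generic) above, and their CUDA counterparts below, accumulate a sum and a sum of squares in a single pass and then apply the standard identity, stated here only for readability:

```latex
\mu = \frac{1}{n}\sum_{i=1}^{n} x_i,
\qquad
\sigma^2 = \frac{1}{n}\sum_{i=1}^{n} x_i^2 \;-\; \mu^2
```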
......
......@@ -11,9 +11,6 @@ namespace kernels {
namespace {
#define LDG(x, i) __ldg(x + i)
#define LDG2(x, i) convert::To<AccT>(__ldg(x + i))
template <typename T, typename AccT>
__global__ void _RowwiseMoments(
const int rows,
......@@ -27,16 +24,15 @@ __global__ void _RowwiseMoments(
CUDA_2D_KERNEL_LOOP1(i, cols) {
AccT m_val = AccT(0), v_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, rows) {
const int xi = j * cols + i;
m_val += LDG2(x, xi);
v_val += math::utils::Square(LDG2(x, xi));
const AccT val = convert::To<AccT>(x[j * cols + i]);
m_val += val;
v_val += val * val;
}
m_val = BlockReduce<AccT>(m_storage).Sum(m_val);
v_val = BlockReduce<AccT>(v_storage).Sum(v_val);
if (threadIdx.x == 0) {
const AccT mu = m_val * scale;
mean[i] = mu;
var[i] = v_val * scale - mu * mu;
mean[i] = m_val = m_val * scale;
var[i] = v_val * scale - m_val * m_val;
}
}
}
......@@ -54,16 +50,15 @@ __global__ void _ColwiseMoments(
CUDA_2D_KERNEL_LOOP1(i, rows) {
AccT m_val = AccT(0), v_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, cols) {
const int xi = i * cols + j;
m_val += LDG2(x, xi);
v_val += math::utils::Square(LDG2(x, xi));
const AccT val = convert::To<AccT>(x[i * cols + j]);
m_val += val;
v_val += val * val;
}
m_val = BlockReduce<AccT>(m_storage).Sum(m_val);
v_val = BlockReduce<AccT>(v_storage).Sum(v_val);
if (threadIdx.x == 0) {
const AccT mu = m_val * scale;
mean[i] = mu;
var[i] = v_val * scale - mu * mu;
mean[i] = m_val = m_val * scale;
var[i] = v_val * scale - m_val * m_val;
}
}
}
......@@ -90,15 +85,15 @@ __global__ void _GenericMoments(
FIXED_DIVISOR_DIV_MOD(X_dims.data[d], c, &c, &r);
xi += r * X_strides.data[d];
}
m_val += LDG2(x, xi);
v_val += math::utils::Square(LDG2(x, xi));
const AccT val = convert::To<AccT>(x[xi]);
m_val += val;
v_val += val * val;
}
m_val = BlockReduce<AccT>(m_storage).Sum(m_val);
v_val = BlockReduce<AccT>(v_storage).Sum(v_val);
if (threadIdx.x == 0) {
const AccT mu = m_val * scale;
mean[i] = mu;
var[i] = v_val * scale - mu * mu;
mean[i] = m_val = m_val * scale;
var[i] = v_val * scale - m_val * m_val;
}
}
}
......@@ -120,20 +115,14 @@ void _Moments(
}
if (math::utils::IsRowwiseReduce(
num_dims, dims, out_dims.data(), &rows, &cols)) {
_RowwiseMoments<<<
CUDA_2D_BLOCKS(cols),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(rows, cols, x, mean, var);
_RowwiseMoments<<<cols, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
rows, cols, x, mean, var);
return;
}
if (math::utils::IsColwiseReduce(
num_dims, dims, out_dims.data(), &rows, &cols)) {
_ColwiseMoments<<<
CUDA_2D_BLOCKS(rows),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(rows, cols, x, mean, var);
_ColwiseMoments<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
rows, cols, x, mean, var);
return;
}
CUDA_TENSOR_DIMS_CHECK(num_dims);
......@@ -155,17 +144,10 @@ void _Moments(
for (int i = 0; i < num_dims; ++i) {
transpose_dims.data[i] = dims[transpose_axes.data[i]];
}
_GenericMoments<<<
CUDA_2D_BLOCKS(rows),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
_GenericMoments<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>(
rows, cols, num_dims, transpose_dims, transpose_strides, x, mean, var);
}
#undef LDG
#undef LDG2
} // namespace
/* ------------------- Launcher Separator ------------------- */
......
......@@ -218,7 +218,7 @@ __global__ void _BatchNormInferenceGrad(
_BatchNormExpectation, \
math::ScalarType<T>::type, \
AccT, \
CUDA_2D_BLOCKS(C), \
C, \
CUDA_THREADS, \
N, \
C, \
......@@ -245,7 +245,7 @@ __global__ void _BatchNormInferenceGrad(
_BatchNormWGrad, \
math::ScalarType<T>::type, \
AccT, \
CUDA_2D_BLOCKS(C), \
C, \
CUDA_THREADS, \
N, \
C, \
......@@ -314,7 +314,7 @@ __global__ void _BatchNormInferenceGrad(
_BatchNormWGrad, \
math::ScalarType<T>::type, \
AccT, \
CUDA_2D_BLOCKS(C), \
C, \
CUDA_THREADS, \
N, \
C, \
......
......@@ -230,7 +230,7 @@ __global__ void _GroupNormGrad(
_GroupNormWGrad, \
math::ScalarType<T>::type, \
AccT, \
CUDA_2D_BLOCKS(G* D), \
G* D, \
CUDA_THREADS, \
N, \
G, \
......@@ -246,7 +246,7 @@ __global__ void _GroupNormGrad(
_GroupNormInternalGrad, \
math::ScalarType<T>::type, \
AccT, \
CUDA_2D_BLOCKS(N* G), \
N* G, \
CUDA_THREADS, \
N, \
G, \
......
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T, typename AccT>
void _LayerNorm(
const int N,
const int C,
const AccT epsilon,
const T* x,
const AccT* gamma,
const AccT* beta,
AccT* mu,
AccT* rsig,
T* y) {
const AccT scale = AccT(1) / AccT(C);
for (int i = 0; i < N; ++i) {
const int offset = i * C;
AccT m_val = AccT(0), v_val = AccT(0);
for (int j = 0; j < C; ++j) {
const AccT val = convert::To<AccT>(x[offset + j]);
m_val += val;
v_val += val * val;
}
mu[i] = m_val = m_val * scale;
v_val = std::sqrt(v_val * scale - m_val * m_val + epsilon);
rsig[i] = v_val = AccT(1) / v_val;
for (int j = 0; j < C; ++j) {
AccT val = convert::To<AccT>(x[offset + j]);
val = (val - m_val) * v_val;
y[offset + j] = convert::To<T>(val * gamma[j] + beta[j]);
}
}
}
} // namespace
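The new CPU LayerNorm kernel above (and the CUDA one below) computes, per row i of length C, the same one-pass moments followed by the affine transform; in formula form (added here for readability, not part of the patch):

```latex
\mu_i = \frac{1}{C}\sum_j x_{ij},
\qquad
r_i = \frac{1}{\sqrt{\tfrac{1}{C}\sum_j x_{ij}^2 - \mu_i^2 + \epsilon}},
\qquad
y_{ij} = \gamma_j\,(x_{ij} - \mu_i)\,r_i + \beta_j
```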
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T, AccT) \
template <> \
void LayerNorm<T, AccT, CPUContext>( \
const int N, \
const int C, \
const float epsilon, \
const T* x, \
const AccT* gamma, \
const AccT* beta, \
AccT* mu, \
AccT* rsig, \
T* y, \
CPUContext* ctx) { \
_LayerNorm(N, C, AccT(epsilon), x, gamma, beta, mu, rsig, y); \
}
DEFINE_KERNEL_LAUNCHER(float16, float);
DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T, typename AccT>
__global__ void _LayerNorm(
const int N,
const int C,
const AccT epsilon,
const T* x,
const AccT* gamma,
const AccT* beta,
AccT* mu,
AccT* rsig,
T* y) {
__shared__ AccT block_mu, block_rsig;
__shared__ typename BlockReduce<AccT>::TempStorage m_storage;
__shared__ typename BlockReduce<AccT>::TempStorage v_storage;
const AccT scale = AccT(1) / AccT(C);
CUDA_2D_KERNEL_LOOP1(i, N) {
AccT m_val = AccT(0), v_val = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, C) {
const AccT val = convert::To<AccT>(__ldg(x + i * C + j));
m_val += val;
v_val += val * val;
}
m_val = BlockReduce<AccT>(m_storage).Sum(m_val);
v_val = BlockReduce<AccT>(v_storage).Sum(v_val);
if (threadIdx.x == 0) {
mu[i] = block_mu = m_val = m_val * scale;
rsig[i] = block_rsig = rsqrt(v_val * scale - m_val * m_val + epsilon);
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, C) {
const int index = i * C + j;
m_val = convert::To<AccT>(__ldg(x + index));
m_val = (m_val - block_mu) * block_rsig;
y[index] = convert::To<T>(fma(m_val, __ldg(gamma + j), __ldg(beta + j)));
}
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T, AccT) \
template <> \
void LayerNorm<T, AccT, CUDAContext>( \
const int N, \
const int C, \
const float epsilon, \
const T* x, \
const AccT* gamma, \
const AccT* beta, \
AccT* mu, \
AccT* rsig, \
T* y, \
CUDAContext* ctx) { \
_LayerNorm<<<N, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
C, \
AccT(epsilon), \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
gamma, \
beta, \
mu, \
rsig, \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
}
DEFINE_KERNEL_LAUNCHER(float16, float);
DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(double, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
......@@ -160,7 +160,7 @@ __global__ void _L2NormalizeGrad(
CUDAContext* ctx) { \
const auto NxS = N * S; \
_##name<math::ScalarType<T>::type, AccT> \
<<<CUDA_2D_BLOCKS(NxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
<<<NxS, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxS, \
S, \
C, \
......@@ -184,7 +184,7 @@ __global__ void _L2NormalizeGrad(
CUDAContext* ctx) { \
const auto NxS = N * S; \
_##name<math::ScalarType<T>::type, AccT> \
<<<CUDA_2D_BLOCKS(NxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
<<<NxS, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxS, \
S, \
C, \
......
......@@ -415,7 +415,7 @@ DEFINE_KERNEL_LAUNCHER(Col2Im2d, true, double);
_Im2ColNd, \
kTransposed, \
math::ScalarType<T>::type, \
CUDA_2D_BLOCKS(outer_dim), \
outer_dim, \
CUDA_THREADS, \
channels, \
kernel_dim, \
......
......@@ -56,18 +56,18 @@ void _RoiAlign(
const T* x,
const float* rois,
T* y) {
auto x_inner_dim = H * W;
auto y_inner_dim = out_h * out_w;
auto x_cols = C * x_inner_dim;
auto y_cols = C * y_inner_dim;
const auto HxW = H * W;
const auto HoxWo = out_h * out_w;
const auto CxHxW = C * HxW;
const auto CxHoxWo = C * HoxWo;
for (int n = 0; n < num_rois; ++n) {
auto* roi = rois + n * 5;
int batch_ind = (int)roi[0];
auto* offset_y = y + n * y_cols;
auto* offset_y = y + n * CxHoxWo;
if (batch_ind < 0) {
memset(offset_y, 0, sizeof(T) * y_cols);
memset(offset_y, 0, sizeof(T) * CxHoxWo);
continue;
}
......@@ -78,19 +78,21 @@ void _RoiAlign(
const float roi_w = std::max(roi_wend - roi_wstart, 1.f);
const float roi_h = std::max(roi_hend - roi_hstart, 1.f);
const float bin_h = roi_h / (float)out_h;
const float bin_w = roi_w / (float)out_w;
const int grid_h =
sampling_ratio > 0 ? sampling_ratio : (int)std::ceil(roi_h / out_h);
const int grid_w =
sampling_ratio > 0 ? sampling_ratio : (int)std::ceil(roi_w / out_w);
const float bin_h = roi_h / float(out_h);
const float bin_w = roi_w / float(out_w);
const int grid_h = sampling_ratio > 0
? sampling_ratio
: int(std::ceil(roi_h / float(out_h)));
const int grid_w = sampling_ratio > 0
? sampling_ratio
: int(std::ceil(roi_w / float(out_w)));
const T num_grids = T(grid_h * grid_w);
int yi;
T val;
float hstart, wstart, h, w;
const T* offset_x = x + batch_ind * x_cols;
const T* offset_x = x + batch_ind * CxHxW;
for (int c = 0; c < C; ++c) {
yi = 0;
......@@ -109,8 +111,8 @@ void _RoiAlign(
offset_y[yi++] = val / num_grids;
}
} // End h_out && w_out
offset_x += x_inner_dim;
offset_y += y_inner_dim;
offset_x += HxW;
offset_y += HoxWo;
} // End c
} // End n
}
......
......@@ -123,16 +123,16 @@ __global__ void _RoiAlign(
const float roi_w = max(roi_wend - roi_wstart, 1.f);
const float roi_h = max(roi_hend - roi_hstart, 1.f);
const float bin_h = roi_h / (float)out_h;
const float bin_w = roi_w / (float)out_w;
const float bin_h = roi_h / float(out_h);
const float bin_w = roi_w / float(out_w);
const float hstart = roi_hstart + h_out * bin_h;
const float wstart = roi_wstart + w_out * bin_w;
const int grid_h =
sampling_ratio > 0 ? sampling_ratio : ceil(roi_h / out_h);
sampling_ratio > 0 ? sampling_ratio : int(ceil(roi_h / float(out_h)));
const int grid_w =
sampling_ratio > 0 ? sampling_ratio : ceil(roi_w / out_w);
sampling_ratio > 0 ? sampling_ratio : int(ceil(roi_w / float(out_w)));
const T* offset_x = x + (batch_ind * C + c) * H * W;
AccT val = AccT(0);
......@@ -178,16 +178,16 @@ __global__ void _RoiAlignGrad(
const float roi_w = max(roi_wend - roi_wstart, 1.f);
const float roi_h = max(roi_hend - roi_hstart, 1.f);
const float bin_h = roi_h / (float)out_h;
const float bin_w = roi_w / (float)out_w;
const float bin_h = roi_h / float(out_h);
const float bin_w = roi_w / float(out_w);
const float hstart = roi_hstart + h_out * bin_h;
const float wstart = roi_wstart + w_out * bin_w;
const int grid_h =
sampling_ratio > 0 ? sampling_ratio : ceil(roi_h / out_h);
sampling_ratio > 0 ? sampling_ratio : int(ceil(roi_h / float(out_h)));
const int grid_w =
sampling_ratio > 0 ? sampling_ratio : ceil(roi_w / out_w);
sampling_ratio > 0 ? sampling_ratio : int(ceil(roi_w / float(out_w)));
const float dyi = convert::To<float>(dy[yi]) / float(grid_h * grid_w);
float* offset_dx = dx + (batch_ind * C + c) * H * W;
......
......@@ -19,21 +19,21 @@ void _RoiPool(
const float* rois,
int* mask,
T* y) {
auto x_inner_dim = H * W;
auto y_inner_dim = out_h * out_w;
auto x_cols = C * x_inner_dim;
auto y_cols = C * y_inner_dim;
const auto HxW = H * W;
const auto HoxWo = out_h * out_w;
const auto CxHxW = C * HxW;
const auto CxHoxWo = C * HoxWo;
for (int n = 0; n < num_rois; ++n) {
auto* roi = rois + n * 5;
auto* offset_y = y + n * y_cols;
auto* offset_mask = mask + n * y_cols;
auto* offset_y = y + n * CxHoxWo;
auto* offset_mask = mask + n * CxHoxWo;
const int batch_ind = (int)roi[0];
if (batch_ind < 0) {
memset(offset_y, 0, sizeof(T) * y_cols);
memset(offset_mask, -1, sizeof(int) * y_cols);
memset(offset_y, 0, sizeof(T) * CxHoxWo);
memset(offset_mask, -1, sizeof(int) * CxHoxWo);
continue;
}
......@@ -44,14 +44,14 @@ void _RoiPool(
const int roi_w = std::max(roi_wend - roi_wstart + 1, 1);
const int roi_h = std::max(roi_hend - roi_hstart + 1, 1);
const float bin_h = (float)roi_h / (float)out_h;
const float bin_w = (float)roi_w / (float)out_w;
const float bin_h = float(roi_h) / float(out_h);
const float bin_w = float(roi_w) / float(out_w);
T val;
bool empty;
int xi, yi, mask_val;
int hstart, wstart, hend, wend;
const T* offset_x = x + batch_ind * x_cols;
const T* offset_x = x + batch_ind * CxHxW;
for (int c = 0; c < C; ++c) {
yi = 0;
......@@ -82,9 +82,9 @@ void _RoiPool(
offset_mask[yi++] = mask_val;
}
} // End h_out && w_out
offset_x += x_inner_dim;
offset_y += y_inner_dim;
offset_mask += y_inner_dim;
offset_x += HxW;
offset_y += HoxWo;
offset_mask += HoxWo;
} // End c
} // End n
}
......
#include "dragon/operators/activation/log_softmax_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void LogSoftmaxOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
kernels::LogSoftmax(
X.count(0, axis),
X.count(axis + 1),
X.dim(axis),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
template <typename T>
void LogSoftmaxGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
GET_OP_AXIS_ARG(axis, Y.ndim(), -1);
kernels::LogSoftmaxGrad(
Y.count(0, axis),
Y.count(axis + 1),
Y.dim(axis),
dY.template data<T, Context>(),
Y.template data<T, Context>(),
dX->ReshapeLike(Y)->template mutable_data<T, Context>(),
ctx());
}
DEPLOY_CPU_OPERATOR(LogSoftmax);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(LogSoftmax);
#endif
DEPLOY_CPU_OPERATOR(LogSoftmaxGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(LogSoftmaxGradient);
#endif
OPERATOR_SCHEMA(LogSoftmax)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(LogSoftmaxGradient)
/* Y, dY */
.NumInputs(2)
/* dX */
.NumOutputs(1)
/* dY => dX */
.AllowInplace({{1, 0}});
REGISTER_GRADIENT(LogSoftmax, InplaceGradientMaker);
} // namespace dragon
......@@ -10,37 +10,41 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_VISION_DEPTH_TO_SPACE_OP_H_
#define DRAGON_OPERATORS_VISION_DEPTH_TO_SPACE_OP_H_
#ifndef DRAGON_OPERATORS_ACTIVATION_LOG_SOFTMAX_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_LOG_SOFTMAX_OP_H_
#include "dragon/operators/array/transpose_op.h"
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class DepthToSpaceOp final : public Operator<Context> {
class LogSoftmaxOp : public Operator<Context> {
public:
DepthToSpaceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
block_size_(OP_SINGLE_ARG(int, "block_size", 2)) {}
SIMPLE_CTOR_DTOR(LogSoftmaxOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
int64_t block_size_;
};
template <class Context>
class DepthToSpaceGradientOp final : public TransposeGradientOp<Context> {
class LogSoftmaxGradientOp : public Operator<Context> {
public:
DepthToSpaceGradientOp(const OperatorDef& def, Workspace* ws)
: TransposeGradientOp<Context>(def, ws) {}
SIMPLE_CTOR_DTOR(LogSoftmaxGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_VISION_DEPTH_TO_SPACE_OP_H_
#endif // DRAGON_OPERATORS_ACTIVATION_LOG_SOFTMAX_OP_H_
......@@ -18,11 +18,6 @@ void SoftmaxOp<Context>::DoRunWithType() {
}
template <class Context>
void SoftmaxOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void SoftmaxGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
......@@ -37,11 +32,6 @@ void SoftmaxGradientOp<Context>::DoRunWithType() {
ctx());
}
template <class Context>
void SoftmaxGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Softmax);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Softmax);
......
......@@ -23,7 +23,9 @@ class SoftmaxOp : public Operator<Context> {
SIMPLE_CTOR_DTOR(SoftmaxOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -35,10 +37,12 @@ class SoftmaxGradientOp : public Operator<Context> {
SIMPLE_CTOR_DTOR(SoftmaxGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
void RunOnDevice() override;
};
#ifdef USE_CUDNN
......
......@@ -14,69 +14,60 @@ void RepeatOp<Context>::DoRunWithType() {
// Determine the repeat scheme
// 1) Repeat to a flattened vector if axis is not specified
// 2) Repeat along the specified axis
int64_t outer_dim, axis_dim, inner_dim;
int64_t N, C, S;
int64_t reps = repeats();
if (axis == INT_MAX) {
outer_dim = inner_dim = 1;
axis_dim = X.count();
Y->Reshape({axis_dim * repeats()});
N = S = 1;
C = X.count();
Y->Reshape({C * reps});
} else {
axis_dim = X.dim(axis);
outer_dim = X.count(0, axis);
inner_dim = X.count(axis + 1);
C = X.dim(axis);
N = X.count(0, axis);
S = X.count(axis + 1);
auto Y_dims = X.dims();
Y_dims[axis] *= repeats();
Y_dims[axis] *= reps;
Y->Reshape(Y_dims);
}
// Dispatch the repeat kernel
kernels::Repeat(
outer_dim,
inner_dim,
axis_dim,
repeats(),
N,
S,
C,
reps,
X.template data<T, Context>(),
Y->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void RepeatOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void RepeatGradientOp<Context>::DoRunWithType() {
auto &X = INPUT_SPEC(0), &dY = Input(0), *dX = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), INT_MAX);
// Determine the repeat scheme
int64_t outer_dim, axis_dim, inner_dim;
int64_t N, C, S;
if (axis == INT_MAX) {
outer_dim = inner_dim = 1;
axis_dim = X.count();
N = S = 1;
C = X.count();
} else {
outer_dim = X.count(0, axis);
axis_dim = X.dim(axis);
inner_dim = X.count(axis + 1);
N = X.count(0, axis);
C = X.dim(axis);
S = X.count(axis + 1);
}
// Reduce the gradient along the axis
kernels::RepeatGrad(
outer_dim,
inner_dim,
axis_dim,
N,
S,
C,
repeats(),
dY.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void RepeatGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Repeat);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Repeat);
......
......@@ -20,12 +20,15 @@ namespace dragon {
template <class Context>
class RepeatOp final : public Operator<Context> {
public:
RepeatOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
explicit RepeatOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
INITIALIZE_OP_SINGLE_ARG(int64_t, repeats, 1);
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -43,7 +46,9 @@ class RepeatGradientOp final : public Operator<Context> {
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......
......@@ -17,16 +17,22 @@ void TransposeOp<Context>::DoRunWithType() {
<< "\nProviding " << num_axes << " dimensions to permute, "
<< "while Tensor(" << X.name() << ")'s dims are " << X.DimString();
vec64_t new_axes(num_dims);
for (int i = 0; i < num_dims; ++i) {
auto axis = num_axes > 0 ? perm(i) : num_dims - i - 1;
X_strides[i] = X.stride(axis);
Y_dims[i] = X.dim(axis);
new_axes[i] = num_axes > 0 ? perm(i) : num_dims - i - 1;
}
// Store for the gradient calculation
SET_INPUT_SPEC(0);
Buffer("X_strides")->template CopyFrom<int64_t>(X_strides);
Buffer("Y_dims")->template CopyFrom<int64_t>(Y_dims);
if (def().type() == "TransposeGradient") {
auto old_axes(new_axes);
for (int i = 0; i < num_dims; ++i) {
new_axes[old_axes[i]] = i;
}
}
for (int i = 0; i < num_dims; ++i) {
X_strides[i] = X.stride(new_axes[i]);
Y_dims[i] = X.dim(new_axes[i]);
}
kernels::Transpose(
num_dims,
......@@ -37,43 +43,11 @@ void TransposeOp<Context>::DoRunWithType() {
ctx());
}
template <class Context>
void TransposeOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void TransposeGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0);
dX->ReshapeLike(INPUT_SPEC(0));
vec64_t X_strides, Y_dims;
Buffer("X_strides")->template CopyTo<int64_t>(X_strides);
Buffer("Y_dims")->template CopyTo<int64_t>(Y_dims);
kernels::TransposeGrad(
X_strides.size(),
X_strides.data(),
Y_dims.data(),
dY.template data<T, Context>(),
dX->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void TransposeGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Transpose);
REGISTER_CPU_OPERATOR(TransposeGradient, TransposeOp<CPUContext>);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Transpose);
#endif
DEPLOY_CPU_OPERATOR(TransposeGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(TransposeGradient);
REGISTER_CUDA_OPERATOR(TransposeGradient, TransposeOp<CUDAContext>);
#endif
OPERATOR_SCHEMA(Transpose)
......
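The gradient path above simply re-registers `TransposeOp` for `TransposeGradient` and inverts the permutation in place (`new_axes[old_axes[i]] = i`). A standalone sketch of that inversion, using a hypothetical `InvertPerm` helper that is not in the codebase:

```cpp
#include <cstdint>
#include <vector>

// Sketch: transposing with `perm` and then with `InvertPerm(perm)` restores
// the original layout, so the backward pass is just another transpose.
std::vector<int64_t> InvertPerm(const std::vector<int64_t>& perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    inv[perm[i]] = static_cast<int64_t>(i);  // output axis perm[i] came from input axis i
  }
  return inv;
}

// Example: perm = {2, 0, 1} maps (H, W, C) -> (C, H, W); InvertPerm returns
// {1, 2, 0}, which maps it back.
```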
......@@ -26,7 +26,9 @@ class TransposeOp final : public Operator<Context> {
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -35,18 +37,6 @@ class TransposeOp final : public Operator<Context> {
DECLARE_OP_REPEATED_ARG(int64_t, perm);
};
template <class Context>
class TransposeGradientOp : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(TransposeGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
};
DEFINE_OP_REPEATED_ARG(int64_t, TransposeOp, perm);
} // namespace dragon
......
......@@ -55,6 +55,7 @@ DISPATCH_WITH_TENSOR_TYPES(GreaterEqual, dtypes::Generic, Input(0));
ctx()); \
}
DEFINE_SIMPLE_UNARY_OP_IMPL(Log, T);
DEFINE_SIMPLE_UNARY_OP_IMPL(Sin, T);
DEFINE_SIMPLE_UNARY_OP_IMPL(Cos, T);
DEFINE_SIMPLE_UNARY_OP_IMPL(Square, T);
......@@ -83,7 +84,6 @@ DEFINE_INPLACE_UNARY_OP_IMPL(Sign, T);
DEFINE_INPLACE_UNARY_OP_IMPL(Sqrt, T);
DEFINE_INPLACE_UNARY_OP_IMPL(Rsqrt, T);
DEFINE_INPLACE_UNARY_OP_IMPL(Exp, T);
DEFINE_INPLACE_UNARY_OP_IMPL(Log, T);
DEFINE_INPLACE_UNARY_OP_IMPL(BitwiseNot, T);
#undef DEFINE_INPLACE_UNARY_OP_IMPL
......
......@@ -3,36 +3,35 @@
namespace dragon {
template <class Context>
template <typename LogitT, typename TargetT>
template <typename InputT, typename TargetT>
void AccuracyOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
auto &X = Input(0), &Y = Input(1), *R = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
auto outer_dim = X.count(0, axis);
auto axis_dim = X.dim(axis);
auto inner_dim = X.count(axis + 1);
const auto C = X.dim(axis);
const auto N = X.count(0, axis);
const auto S = X.count(axis + 1);
const auto NxS = N * S;
const auto CxS = C * S;
CHECK_EQ(Y.count(), NxS) << "\nNumel of X and Y must be matched.";
CHECK_EQ(outer_dim * inner_dim, Input(1).count())
<< "\nNumber of preds must match the number of targets.";
auto* input = X.template data<InputT, CPUContext>();
auto* target = Y.template data<TargetT, CPUContext>();
int64_t acc = 0, count = 0;
int64_t cols = X.count() / outer_dim;
auto* logit = X.template data<LogitT, CPUContext>();
auto* target = Input(1).template data<TargetT, CPUContext>();
for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < inner_dim; ++j) {
const int label = target[i * inner_dim + j];
for (int i = 0; i < N; ++i) {
for (int j = 0; j < S; ++j) {
const int label = target[i * S + j];
if (label == ignore_index_) continue;
vector<pair<LogitT, int>> vec;
for (int k = 0; k < axis_dim; k++)
vec.push_back(std::make_pair(logit[i * cols + k * inner_dim + j], k));
vector<pair<InputT, int>> vec;
for (int k = 0; k < C; ++k) {
vec.push_back(std::make_pair(input[i * CxS + k * S + j], k));
}
std::partial_sort(
vec.begin(),
vec.begin() + top_k_,
vec.end(),
std::greater<pair<LogitT, int>>());
std::greater<pair<InputT, int>>());
for (int k = 0; k < top_k_; k++) {
if (vec[k].second == label) {
acc++;
......@@ -40,10 +39,10 @@ void AccuracyOp<Context>::DoRunWithType() {
}
}
count++;
} // End inner_dim
} // End outer_dim
}
}
Y->Reshape({})->template mutable_data<float, CPUContext>()[0] =
R->Reshape({})->template mutable_data<float, CPUContext>()[0] =
(float)acc / (float)count;
}
......@@ -79,9 +78,9 @@ DEPLOY_CUDA_OPERATOR(Accuracy);
#endif
OPERATOR_SCHEMA(Accuracy)
/* X, T */
/* X, Y */
.NumInputs(2)
/* Y */
/* R */
.NumOutputs(1);
NO_GRADIENT(Accuracy);
......
......@@ -284,13 +284,13 @@ DEPLOY_CUDA_OPERATOR(BatchNormGradient);
#endif
OPERATOR_SCHEMA(BatchNorm)
/* X, W, B, RunningMean, RunningVar */
/* X, W, B, RM, RV */
.NumInputs(5)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(BatchNormGradient)
/* X, W, RunningMean, RunningVar, dY */
/* X, W, RM, RV, dY */
.NumInputs(5)
/* dX, dW, dB */
.NumOutputs(3);
......
......@@ -9,12 +9,16 @@ template <class Context>
template <typename T>
void GroupNormOp<Context>::DoRunWithType() {
using ParamT = typename math::AccmulatorType<T>::type;
INITIALIZE_TENSOR_VIA_SPEC(Input(1), vec64_t({C_}), ParamT);
INITIALIZE_TENSOR_VIA_SPEC(Input(2), vec64_t({C_}), ParamT);
auto &X = Input(0), *Y = Output(0);
auto &W = Input(1), &B = Input(2);
GetBaseArguments();
INITIALIZE_TENSOR_VIA_SPEC(W, vec64_t({C_}), ParamT);
INITIALIZE_TENSOR_VIA_SPEC(B, vec64_t({C_}), ParamT);
auto* X_mu = Buffer("X_mu")->Reshape({N_, G_});
auto* X_rsig = Buffer("X_rsig")->Reshape({N_, G_});
auto* x = Input(0).template data<T, Context>();
auto* x = X.template data<T, Context>();
auto* mu = X_mu->template mutable_data<ParamT, Context>();
auto* rsig = X_rsig->template mutable_data<ParamT, Context>();
......@@ -29,10 +33,10 @@ void GroupNormOp<Context>::DoRunWithType() {
kernels::Moments(4, dims.data(), 2, axes.data(), x, mu, rsig, ctx());
}
// Inverse stddev from variance
// Inverse stddev from variance.
math::InvStd(N_ * G_, epsilon_, rsig, rsig, ctx());
// Fuse parameters to compute affine transformation
// Fuse parameters to compute affine transformation.
auto* scratch =
ctx()->workspace()->template data<ParamT, Context>({2 * N_ * C_})[0];
kernels::GroupNorm(
......@@ -44,29 +48,24 @@ void GroupNormOp<Context>::DoRunWithType() {
x,
mu,
rsig,
Input(1).template data<ParamT, Context>(), // gamma
Input(2).template data<ParamT, Context>(), // beta
W.template data<ParamT, Context>(),
B.template data<ParamT, Context>(),
scratch,
scratch + N_ * C_,
Output(0)->template mutable_data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void GroupNormOp<Context>::RunOnDevice() {
GetBaseArguments();
Output(0)->ReshapeLike(Input(0));
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void GroupNormGradientOp<Context>::DoRunWithType() {
using ParamT = typename math::AccmulatorType<T>::type;
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig");
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
GetBaseArguments();
// Gradient w.r.t. gamma, beta and input
// Gradient w.r.t. gamma, beta and input.
auto* scratch =
ctx()->workspace()->template data<ParamT, Context>({2 * N_ * G_})[0];
kernels::GroupNormGrad(
......@@ -75,26 +74,19 @@ void GroupNormGradientOp<Context>::DoRunWithType() {
D_,
S_,
data_format(),
Input(0).template data<T, Context>(), // x
X.template data<T, Context>(),
X_mu->template data<ParamT, Context>(),
X_rsig->template data<ParamT, Context>(),
Input(1).template data<ParamT, Context>(), // gamma
Input(2).template data<T, Context>(), // dy
W.template data<ParamT, Context>(),
dY.template data<T, Context>(),
scratch,
scratch + N_ * G_,
dW->Reshape({C_})->template mutable_data<ParamT, Context>(),
dB->Reshape({C_})->template mutable_data<ParamT, Context>(),
dX->template mutable_data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void GroupNormGradientOp<Context>::RunOnDevice() {
GetBaseArguments();
Output(0)->ReshapeLike(Input(0));
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(GroupNorm);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(GroupNorm);
......
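Editor's note: as a sanity reference for the fused GroupNorm path above, a minimal NumPy sketch of the forward normalization, assuming NCHW layout; the shapes and the epsilon value are illustrative.

# Illustrative sketch only; not part of this commit.
import numpy as np

def group_norm_nchw(x, gamma, beta, groups, epsilon=1e-5):
    """Normalize x of shape (N, C, H, W) over each group of channels."""
    n, c, h, w = x.shape
    xg = x.reshape(n, groups, c // groups, h, w)
    mu = xg.mean(axis=(2, 3, 4), keepdims=True)
    rsig = 1.0 / np.sqrt(xg.var(axis=(2, 3, 4), keepdims=True) + epsilon)
    y = ((xg - mu) * rsig).reshape(n, c, h, w)
    return y * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)

x = np.random.randn(2, 4, 3, 3).astype('float32')
y = group_norm_nchw(x, np.ones(4, 'float32'), np.zeros(4, 'float32'), groups=2)
print(y.shape)  # (2, 4, 3, 3)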
......@@ -64,7 +64,9 @@ class GroupNormOp : public GroupNormOpBase<Context> {
USE_OPERATOR_FUNCTIONS;
USE_GROUPNORM_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -78,7 +80,9 @@ class GroupNormGradientOp : public GroupNormOpBase<Context> {
USE_OPERATOR_FUNCTIONS;
USE_GROUPNORM_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......
#include "dragon/operators/normalization/layer_norm_op.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void LayerNormOp<Context>::DoRunWithType() {
using ParamT = typename math::AccmulatorType<T>::type;
auto &X = Input(0), *Y = Output(0);
auto &W = Input(1), &B = Input(2);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
const auto N = X.count(0, axis);
const auto C = X.count(axis);
INITIALIZE_TENSOR_VIA_SPEC(W, vec64_t({C}), ParamT);
INITIALIZE_TENSOR_VIA_SPEC(B, vec64_t({C}), ParamT);
auto* X_mu = Buffer("X_mu")->Reshape({N});
auto* X_rsig = Buffer("X_rsig")->Reshape({N});
kernels::LayerNorm(
N,
C,
epsilon_,
X.template data<T, Context>(),
W.template data<ParamT, Context>(),
B.template data<ParamT, Context>(),
X_mu->template mutable_data<ParamT, Context>(),
X_rsig->template mutable_data<ParamT, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
DEPLOY_CPU_OPERATOR(LayerNorm);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(LayerNorm);
......
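Editor's note: the rewritten LayerNormOp reduces over the trailing axes directly instead of reusing the GroupNorm machinery. A minimal NumPy sketch of the same computation follows; the shapes and epsilon are illustrative.

# Illustrative sketch only; not part of this commit.
import numpy as np

def layer_norm(x, gamma, beta, axis=-1, epsilon=1e-5):
    """Normalize x over the trailing dimensions starting at `axis`."""
    n = int(np.prod(x.shape[:axis]))
    c = int(np.prod(x.shape[axis:]))
    x2 = x.reshape(n, c)
    mu = x2.mean(axis=1, keepdims=True)
    rsig = 1.0 / np.sqrt(x2.var(axis=1, keepdims=True) + epsilon)
    y = (x2 - mu) * rsig * gamma.reshape(1, c) + beta.reshape(1, c)
    return y.reshape(x.shape)

x = np.random.randn(2, 3, 8).astype('float32')
print(layer_norm(x, np.ones(8, 'float32'), np.zeros(8, 'float32')).shape)  # (2, 3, 8)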
......@@ -18,22 +18,22 @@
namespace dragon {
template <class Context>
class LayerNormOp final : public GroupNormOp<Context> {
class LayerNormOp final : public Operator<Context> {
public:
LayerNormOp(const OperatorDef& def, Workspace* ws)
: GroupNormOp<Context>(def, ws) {}
: Operator<Context>(def, ws),
epsilon_(OP_SINGLE_ARG(double, "epsilon", 1e-5)) {}
USE_OPERATOR_FUNCTIONS;
void GetBaseArguments() override {
auto& X = Input(0);
GET_OP_AXIS_ARG(axis, X.ndim(), -1);
// Set dimensions
this->N_ = X.count(0, axis);
this->C_ = this->D_ = X.count(axis);
this->G_ = this->S_ = 1;
// Set data format
this->data_format_ = "NHWC";
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
double epsilon_;
};
template <class Context>
......
......@@ -62,9 +62,6 @@ template <class Context>
template <typename T>
void ConvOpBase<Context>::Col2Im(const T* col, T* im) {
if (num_axes_ == 1 || num_axes_ == 2) {
// std::cout << conv_in_channels_ << std::endl;
// std::cout << in_shape_[0] << " " << in_shape_[1] << std::endl;
// std::cout << out_shape_[0] << " " << out_shape_[1] << std::endl;
kernels::Col2Im2d(
conv_in_channels_,
in_shape_[0],
......
#include "dragon/operators/vision/depth_to_space_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void DepthToSpaceOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
SET_INPUT_SPEC(0);
int start_axis, end_axis;
int num_dims = X.ndim(), num_axes = X.ndim() - 2;
CHECK_GT(num_dims, 2) << "\nExpected a spatial input"
<< " with number of dimensions >= 3.";
// Compute the reshape and transpose arguments
vec64_t perm(size_t(num_axes * 2 + 2), 0);
vec64_t in_dims, out_shape = X.dims();
if (data_format() == "NCHW") {
start_axis = 2, end_axis = num_dims;
out_shape[1] /= std::pow(block_size_, num_axes);
in_dims = out_shape;
perm[1] = num_axes + 1;
for (int i = 0; i < num_axes; i++) {
perm[i * 2 + 2] = num_axes + i + 2;
perm[i * 2 + 3] = i + 1;
in_dims.insert(in_dims.begin() + 1, block_size_);
out_shape[start_axis + i] *= block_size_;
}
} else if (data_format() == "NHWC") {
start_axis = 1, end_axis = num_dims - 1;
out_shape[end_axis] /= std::pow(block_size_, num_axes);
in_dims = out_shape;
for (int i = 0; i < num_axes; i++) {
perm[i * 2 + 1] = i + 1;
perm[i * 2 + 2] = num_axes + i + 1;
in_dims.insert(in_dims.begin() + num_axes + 1, block_size_);
out_shape[start_axis + i] *= block_size_;
}
perm.back() = perm.size() - 1;
} else {
LOG(FATAL) << "Unknown DataFormat: " << data_format();
}
// Now, handle it as the generic transpose operation
Tensor X_reshape(in_dims);
vec64_t x_strides(in_dims.size()), y_dims(in_dims.size());
CHECK_EQ(X_reshape.count(), X.count())
<< "\nCould not rearrange " << X.DimString() << " to "
<< X_reshape.DimString() << " with block size " << block_size_ << ".";
for (int i = 0; i < in_dims.size(); i++) {
x_strides[i] = X_reshape.stride(perm[i]);
y_dims[i] = X_reshape.dim(perm[i]);
}
// Store for the gradient calculation
Buffer("X_strides")->template CopyFrom<int64_t>(x_strides);
Buffer("Y_dims")->template CopyFrom<int64_t>(y_dims);
kernels::Transpose(
x_strides.size(),
x_strides.data(),
y_dims.data(),
X.template data<T, Context>(),
Y->Reshape(out_shape)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void DepthToSpaceOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(DepthToSpace);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(DepthToSpace);
#endif
DEPLOY_CPU_OPERATOR(DepthToSpaceGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(DepthToSpaceGradient);
#endif
OPERATOR_SCHEMA(DepthToSpace)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(DepthToSpaceGradient)
/* dY */
.NumInputs(1)
/* dX */
.NumOutputs(1);
REGISTER_GRADIENT(DepthToSpace, SimpleGradientMaker);
} // namespace dragon
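Editor's note: both DepthToSpace and SpaceToDepth reduce to a reshape followed by a generic transpose, as the comments above note. A minimal NumPy sketch of DepthToSpace for NCHW input follows; the block size and shapes are illustrative.

# Illustrative sketch only; not part of this commit.
import numpy as np

def depth_to_space_nchw(x, block_size):
    """Rearrange (N, C*b*b, H, W) into (N, C, H*b, W*b) via reshape + transpose."""
    n, c, h, w = x.shape
    b = block_size
    x = x.reshape(n, b, b, c // (b * b), h, w)
    x = x.transpose(0, 3, 4, 1, 5, 2)  # N, C/b^2, H, b, W, b
    return x.reshape(n, c // (b * b), h * b, w * b)

x = np.arange(16, dtype='float32').reshape(1, 4, 2, 2)
print(depth_to_space_nchw(x, 2).shape)  # (1, 1, 4, 4)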
#include "dragon/operators/vision/space_to_depth_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -8,7 +7,6 @@ template <class Context>
template <typename T>
void SpaceToDepthOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
SET_INPUT_SPEC(0);
int start_axis, end_axis, perm_count = 0;
int num_dims = X.ndim(), num_axes = X.ndim() - 2;
......@@ -16,9 +14,9 @@ void SpaceToDepthOp<Context>::DoRunWithType() {
CHECK_GT(num_dims, 2) << "\nExpected a spatial input"
<< " with number of dimensions >= 3.";
// Compute the reshape and transpose arguments
// Compute the reshape and transpose arguments.
vec64_t perm(size_t(num_axes * 2 + 2));
vec64_t in_dims, in_shape = Input(0).dims();
vec64_t in_dims, in_shape = X.dims();
vec64_t out_shape = in_shape;
if (data_format() == "NCHW") {
......@@ -55,59 +53,111 @@ void SpaceToDepthOp<Context>::DoRunWithType() {
}
}
// Now, handle it as the generic transpose operation
// Now, handle it as the generic transpose operation.
Tensor X_reshape(in_dims);
vec64_t x_strides(in_dims.size()), y_dims(in_dims.size());
CHECK_EQ(X_reshape.count(), X.count())
<< "\nCould not rearrange " << X.DimString() << " to "
<< X_reshape.DimString() << " with block size " << block_size_ << ".";
for (int i = 0; i < in_dims.size(); i++) {
x_strides[i] = X_reshape.stride(perm[i]);
y_dims[i] = X_reshape.dim(perm[i]);
vec64_t X_strides(in_dims.size());
vec64_t Y_dims(in_dims.size());
for (int i = 0; i < X_reshape.ndim(); i++) {
X_strides[i] = X_reshape.stride(perm[i]);
Y_dims[i] = X_reshape.dim(perm[i]);
}
// Store for the gradient calculation
Buffer("X_strides")->template CopyFrom<int64_t>(x_strides);
Buffer("Y_dims")->template CopyFrom<int64_t>(y_dims);
kernels::Transpose(
x_strides.size(),
x_strides.data(),
y_dims.data(),
X_strides.size(),
X_strides.data(),
Y_dims.data(),
X.template data<T, Context>(),
Y->Reshape(out_shape)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void SpaceToDepthOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
template <typename T>
void DepthToSpaceOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
int start_axis, end_axis;
int num_dims = X.ndim(), num_axes = X.ndim() - 2;
CHECK_GT(num_dims, 2) << "\nExpected a spatial input"
<< " with number of dimensions >= 3.";
// Compute the reshape and transpose arguments.
vec64_t perm(size_t(num_axes * 2 + 2), 0);
vec64_t in_dims, out_shape = X.dims();
if (data_format() == "NCHW") {
start_axis = 2, end_axis = num_dims;
out_shape[1] /= std::pow(block_size_, num_axes);
in_dims = out_shape;
perm[1] = num_axes + 1;
for (int i = 0; i < num_axes; i++) {
perm[i * 2 + 2] = num_axes + i + 2;
perm[i * 2 + 3] = i + 1;
in_dims.insert(in_dims.begin() + 1, block_size_);
out_shape[start_axis + i] *= block_size_;
}
} else if (data_format() == "NHWC") {
start_axis = 1, end_axis = num_dims - 1;
out_shape[end_axis] /= std::pow(block_size_, num_axes);
in_dims = out_shape;
for (int i = 0; i < num_axes; i++) {
perm[i * 2 + 1] = i + 1;
perm[i * 2 + 2] = num_axes + i + 1;
in_dims.insert(in_dims.begin() + num_axes + 1, block_size_);
out_shape[start_axis + i] *= block_size_;
}
perm.back() = perm.size() - 1;
} else {
LOG(FATAL) << "Unknown DataFormat: " << data_format();
}
// Now, handle it as the generic transpose operation.
Tensor X_reshape(in_dims);
CHECK_EQ(X_reshape.count(), X.count())
<< "\nCould not rearrange " << X.DimString() << " to "
<< X_reshape.DimString() << " with block size " << block_size_ << ".";
vec64_t X_strides(in_dims.size());
vec64_t Y_dims(in_dims.size());
for (int i = 0; i < in_dims.size(); i++) {
X_strides[i] = X_reshape.stride(perm[i]);
Y_dims[i] = X_reshape.dim(perm[i]);
}
kernels::Transpose(
X_strides.size(),
X_strides.data(),
Y_dims.data(),
X.template data<T, Context>(),
Y->Reshape(out_shape)->template mutable_data<T, Context>(),
ctx());
}
DEPLOY_CPU_OPERATOR(SpaceToDepth);
REGISTER_CPU_OPERATOR(SpaceToDepthGradient, DepthToSpaceOp<CPUContext>);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(SpaceToDepth);
REGISTER_CUDA_OPERATOR(SpaceToDepthGradient, DepthToSpaceOp<CUDAContext>);
#endif
DEPLOY_CPU_OPERATOR(SpaceToDepthGradient);
DEPLOY_CPU_OPERATOR(DepthToSpace);
REGISTER_CPU_OPERATOR(DepthToSpaceGradient, SpaceToDepthOp<CPUContext>);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(SpaceToDepthGradient);
DEPLOY_CUDA_OPERATOR(DepthToSpace);
REGISTER_CUDA_OPERATOR(DepthToSpaceGradient, SpaceToDepthOp<CUDAContext>);
#endif
OPERATOR_SCHEMA(SpaceToDepth)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(SpaceToDepthGradient)
/* dY */
.NumInputs(1)
/* dX */
.NumOutputs(1);
OPERATOR_SCHEMA(SpaceToDepth).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(SpaceToDepthGradient).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(DepthToSpace).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(DepthToSpaceGradient).NumInputs(1).NumOutputs(1);
REGISTER_GRADIENT(SpaceToDepth, SimpleGradientMaker);
REGISTER_GRADIENT(DepthToSpace, SimpleGradientMaker);
} // namespace dragon
......@@ -13,7 +13,7 @@
#ifndef DRAGON_OPERATORS_VISION_SPACE_TO_DEPTH_OP_H_
#define DRAGON_OPERATORS_VISION_SPACE_TO_DEPTH_OP_H_
#include "dragon/operators/array/transpose_op.h"
#include "dragon/core/operator.h"
namespace dragon {
......@@ -25,21 +25,34 @@ class SpaceToDepthOp final : public Operator<Context> {
block_size_(OP_SINGLE_ARG(int, "block_size", 2)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
int64_t block_size_;
Tensor X_, *X_strides_, *Y_dims_;
};
template <class Context>
class SpaceToDepthGradientOp final : public TransposeGradientOp<Context> {
class DepthToSpaceOp final : public Operator<Context> {
public:
SpaceToDepthGradientOp(const OperatorDef& def, Workspace* ws)
: TransposeGradientOp<Context>(def, ws) {}
DepthToSpaceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
block_size_(OP_SINGLE_ARG(int, "block_size", 2)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
int64_t block_size_;
};
} // namespace dragon
......
......@@ -547,7 +547,7 @@ def smooth_l1_loss_args(**kwargs):
}
@register('Softmax')
@register(['Softmax', 'LogSoftmax'])
def softmax_args(**kwargs):
return {'axis': kwargs.get('axis', -1)}
......
......@@ -283,6 +283,17 @@ class Tensor(types.TensorBase):
"""
def item(self):
"""Return the value as a python number.
Returns
-------
number
The value.
"""
return float(self) if 'float' in self.dtype else int(self)
def normal(self, mean=0, std=1):
r"""Fill self from a normal distribution.
......@@ -343,6 +354,17 @@ class Tensor(types.TensorBase):
"""
def tolist(self):
"""Return the value as a python list.
Returns
-------
list
The value.
"""
return self.numpy().tolist()
def truncated_normal(self, mean=0, std=1):
r"""Fill self from a truncated normal distribution.
......@@ -452,7 +474,7 @@ class Tensor(types.TensorBase):
"""
def __float__(self):
"""Return a float python scalar.
"""Return the value as a python number.
Returns
-------
......@@ -591,7 +613,7 @@ class Tensor(types.TensorBase):
"""
def __int__(self):
"""Return an integer python scalar.
"""Return the value as a python number.
Returns
-------
......
......@@ -129,7 +129,7 @@ class DataReader(multiprocessing.Process):
self._init_dataset()
# Persist a loop to read examples.
while True:
# Pop the depleted part if necessary
# Pop the depleted part if necessary.
if self._parts[0].start == self._parts[0].end:
self._parts.pop(0)
offset = 0
......@@ -145,10 +145,10 @@ class DataReader(multiprocessing.Process):
# Load and push back a new example into the buffer.
k = self._parts[-1].end % len(self._example_buffer)
self._example_buffer[k] = self.next_example()
# Increase the part boundaries
# Increase the part boundaries.
self._parts[-1].end += 1
self._parts[0].start += 1
# Reset the cursor if necessary
# Reset the cursor if necessary.
if self._cursor >= self._last:
self.reset()
......
......@@ -17,8 +17,6 @@ from __future__ import print_function
from dragon.core.autograph import context
from dragon.core.autograph.op_impl import OpLib
from dragon.core.autograph.op_impl import OpSchema
from dragon.core.ops import math_ops
from dragon.core.ops import array_ops
@OpSchema.num_inputs(1)
......@@ -353,7 +351,7 @@ def leaky_relu(inputs, alpha=0.2, inplace=False, **kwargs):
@OpSchema.num_inputs(1)
def log_softmax(inputs, axis=-1, **kwargs):
def log_softmax(inputs, axis=-1, inplace=False, **kwargs):
r"""Compute the composite of logarithm and softmax.
The **LogSoftmax** function is defined as:
......@@ -374,6 +372,8 @@ def log_softmax(inputs, axis=-1, **kwargs):
The input tensor.
axis : int, optional, default=-1
The axis to reduce.
inplace : bool, optional, default=False
Call in-place or return a new tensor.
Returns
-------
......@@ -381,11 +381,11 @@ def log_softmax(inputs, axis=-1, **kwargs):
The output tensor.
"""
return math_ops.sub(
[inputs, math_ops.log(
array_ops.sum(math_ops.exp(inputs, **kwargs),
axis=[axis], keepdims=True, **kwargs),
**kwargs)], **kwargs)
if context.executing_eagerly():
return OpLib.execute(
'LogSoftmax', inputs,
outputs=inputs if inplace else [None], axis=axis)
return OpLib.add('LogSoftmax', inputs, axis=axis, **kwargs)
@OpSchema.num_inputs(2)
......
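Editor's note: the composite formulation removed above, x - log(sum(exp(x))), is mathematically equivalent to the new dedicated LogSoftmax kernel. A minimal NumPy sketch of the numerically stable form; the axis argument is illustrative.

# Illustrative sketch only; not part of this commit.
import numpy as np

def log_softmax(x, axis=-1):
    """Numerically stable log-softmax: shift by the max before exponentiating."""
    x_shift = x - x.max(axis=axis, keepdims=True)
    return x_shift - np.log(np.exp(x_shift).sum(axis=axis, keepdims=True))

x = np.array([[1.0, 2.0, 3.0]], 'float32')
print(np.exp(log_softmax(x)).sum())  # ~1.0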
......@@ -69,7 +69,7 @@ def selu_exporter(op_def, context):
return node, const_tensors
@export_util.register('Softmax')
@export_util.register(['Softmax', 'LogSoftmax'])
def softmax_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
ndim = len(context.blob_shapes[op_def.input[0]])
......@@ -82,7 +82,7 @@ def softmax_exporter(op_def, context):
return node, const_tensors
@export_util.register('Softmax-13')
@export_util.register(['Softmax-13', 'LogSoftmax-13'])
def softmax_exporter_v13(op_def, context):
node, const_tensors = export_util.translate(**locals())
ndim = len(context.blob_shapes[op_def.input[0]])
......
......@@ -228,12 +228,11 @@ DEFINE_BROADCAST_1ST_FUNC(Div, double, /);
const int rows, const int cols, const T* a, const T* b, T* y) { \
if (a == y) { \
EigenArrayMap<T>(y, cols, rows).rowwise() Expr## = \
ConstEigenVectorArrayMap<T>(b, rows).transpose(); \
ConstEigenVectorArrayMap2<T>(b, rows); \
} else { \
EigenArrayMap<T>(y, cols, rows) = \
ConstEigenArrayMap<T>(a, cols, rows) \
.rowwise() Expr ConstEigenVectorArrayMap<T>(b, rows) \
.transpose(); \
.rowwise() Expr ConstEigenVectorArrayMap2<T>(b, rows); \
} \
}
......
......@@ -109,20 +109,14 @@ __global__ void _GenericReduce(
} \
if (math::utils::IsRowwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_RowwiseReduce<<< \
CUDA_2D_BLOCKS(cols), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(rows, cols, reducer, init, scale, x, y); \
_RowwiseReduce<<<cols, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, cols, reducer, init, scale, x, y); \
return; \
} \
if (math::utils::IsColwiseReduce( \
num_dims, dims, out_dims.data(), &rows, &cols)) { \
_ColwiseReduce<<< \
CUDA_2D_BLOCKS(rows), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(rows, cols, reducer, init, scale, x, y); \
_ColwiseReduce<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, cols, reducer, init, scale, x, y); \
return; \
} \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
......@@ -144,11 +138,7 @@ __global__ void _GenericReduce(
for (int i = 0; i < num_dims; ++i) { \
transpose_dims.data[i] = dims[transpose_axes.data[i]]; \
} \
_GenericReduce<<< \
CUDA_2D_BLOCKS(rows), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
_GenericReduce<<<rows, CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, \
cols, \
num_dims, \
......
......@@ -229,6 +229,25 @@ void SoftmaxGrad(
Context* ctx);
template <typename T, class Context>
void LogSoftmax(
const int N,
const int S,
const int C,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
void LogSoftmaxGrad(
const int N,
const int S,
const int C,
const T* dy,
const T* y,
T* dx,
Context* ctx);
template <typename T, class Context>
void Tanh(const int N, const T* x, T* y, Context* ctx);
template <typename T, class Context>
......@@ -586,15 +605,6 @@ void Transpose(
Context* ctx);
template <typename T, class Context>
void TransposeGrad(
const int num_dims,
const int64_t* x_strides,
const int64_t* y_dims,
const T* dy,
T* dx,
Context* ctx);
template <typename T, class Context>
void TopK(
const int N,
const int S,
......@@ -978,6 +988,19 @@ void L2NormalizeGrad(
T* dx,
Context* ctx);
template <typename T, typename AccT, class Context>
void LayerNorm(
const int N,
const int C,
const float epsilon,
const T* x,
const AccT* gamma,
const AccT* beta,
AccT* mu,
AccT* rsig,
T* y,
Context* ctx);
/*
* RecurrentOp Kernels
*/
......
......@@ -10,8 +10,8 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_STRING_H_
#define DRAGON_UTILS_STRING_H_
#ifndef DRAGON_UTILS_STRING_UTILS_H_
#define DRAGON_UTILS_STRING_UTILS_H_
#include <algorithm>
#include <cstdlib>
......@@ -100,4 +100,4 @@ inline std::string replace_all(
} // namespace dragon
#endif // DRAGON_UTILS_STRING_H_
#endif // DRAGON_UTILS_STRING_UTILS_H_
......@@ -115,6 +115,8 @@ class TestTensor(unittest.TestCase):
self.assertEqual(a.__repr__(), b.__repr__())
self.assertNotEqual(a.__repr__(), dragon.Tensor((), symbolic=True).__repr__())
self.assertEqual(float(int(a)), float(b))
self.assertEqual(dragon.constant([2]).item(), 2)
self.assertEqual(dragon.constant([2, 3]).tolist(), [2, 3])
try:
_ = dragon.Tensor(None)
except ValueError:
......
......@@ -76,7 +76,7 @@ class OpTestCase(unittest.TestCase):
second = inputs[num_first:len(inputs)] if num_second > 1 else inputs[num_first]
if isinstance(first, np.ndarray) and isinstance(second, np.ndarray):
super(OpTestCase, self).assertEqual(first.shape, second.shape)
if first.dtype == np.bool and second.dtype == np.bool:
if first.dtype == bool and second.dtype == bool:
diff = first ^ second
num_unique = len(np.unique(diff))
self.assertLessEqual(num_unique, 1, msg)
......
......@@ -61,7 +61,7 @@ class OpTestCase(unittest.TestCase):
second = inputs[num_first:len(inputs)] if num_second > 1 else inputs[num_first]
if isinstance(first, np.ndarray) and isinstance(second, np.ndarray):
super(OpTestCase, self).assertEqual(first.shape, second.shape)
if first.dtype == np.bool and second.dtype == np.bool:
if first.dtype == bool and second.dtype == bool:
diff = first ^ second
num_unique = len(np.unique(diff))
self.assertLessEqual(num_unique, 1, msg)
......
......@@ -59,7 +59,7 @@ class OpTestCase(unittest.TestCase):
second = inputs[num_first:len(inputs)] if num_second > 1 else inputs[num_first]
if isinstance(first, np.ndarray) and isinstance(second, np.ndarray):
super(OpTestCase, self).assertEqual(first.shape, second.shape)
if first.dtype == np.bool and second.dtype == np.bool:
if first.dtype == bool and second.dtype == bool:
diff = first ^ second
num_unique = len(np.unique(diff))
self.assertLessEqual(num_unique, 1, msg)
......@@ -243,6 +243,8 @@ class TestTensorOps(OpTestCase):
data = np.array([0., 1., 2.], 'float32')
x = new_tensor(data)
self.assertEqual(x.exp(), np.exp(data))
x.exp_()
self.assertEqual(x, np.exp(data))
def test_expand(self):
entries = [(2, 2, 3, 1),
......@@ -403,6 +405,8 @@ class TestTensorOps(OpTestCase):
data = np.array([1., 2., 3.], 'float32')
x = new_tensor(data)
self.assertEqual(x.log(), np.log(data))
x.log_()
self.assertEqual(x, np.log(data))
def test_logical_and(self):
for a_shape, b_shape in self.binary_test_shapes:
......
......@@ -52,6 +52,8 @@ class TestTensor(unittest.TestCase):
self.assertEqual(int(a.detach()), 0)
self.assertEqual(torch.Tensor([0]).dim(), 1)
self.assertEqual(float(torch.Tensor(1).one_()), 1.)
self.assertEqual(torch.tensor(2.333).item(), 2.333)
self.assertEqual(torch.tensor([2, 3]).tolist(), [2, 3])
self.assertEqual(torch.empty(2, 3).ndimension(), 2)
self.assertEqual(torch.empty(3).new_empty(2, 3).ndimension(), 2)
self.assertEqual(repr(torch.tensor(1)), '1')
......
......@@ -1303,7 +1303,7 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
size=size, alpha=float(alpha), beta=float(beta), bias=float(k))
def log_softmax(input, dim):
def log_softmax(input, dim, inplace=False):
r"""Apply the composite of logarithm and softmax to input.
The **LogSoftmax** function is defined as:
......@@ -1316,6 +1316,8 @@ def log_softmax(input, dim):
The input.
dim : int
The dimension to reduce.
inplace : bool, optional, default=False
Whether to do the operation in-place.
Returns
-------
......@@ -1327,7 +1329,9 @@ def log_softmax(input, dim):
`torch.nn.LogSoftmax(...)`_
"""
return input - input.logsumexp(dim, keepdim=True)
return FunctionLib.apply(
'LogSoftmax', input.device, [input],
outputs=[input if inplace else None], axis=dim)
def lstm_cell(input, cx):
......
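Editor's note: a short usage sketch of the updated functional interface; the import paths are assumed to follow the dragon.vm.torch layout used elsewhere in this change.

# Illustrative sketch only; not part of this commit.
from dragon.vm import torch
from dragon.vm.torch.nn import functional as F

x = torch.tensor([[1., 2., 3.]])
y = F.log_softmax(x, dim=1)                # returns a new tensor
_ = F.log_softmax(x, dim=1, inplace=True)  # writes the result into x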
......@@ -142,18 +142,17 @@ class GumbelSoftmax(Module):
self.tau = tau
self.dim = dim
self.inplace = inplace
if dim is None:
raise ValueError('Expected a valid dim, got None.')
def forward(self, logits=None, probs=None):
if probs is not None:
input = probs.log()
else:
input = logits - logits.logsumexp(dim=self.dim, keepdim=True)
u_dist = init_ops.rand(input.shape, dtype=input.dtype, device=input.device)
gumbels = -((-(u_dist.log())).log())
scores = (input + gumbels) / self.tau
return F.softmax(scores, self.dim, self.inplace)
def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
return 'dim={}{}'.format(self.dim, inplace_str)
def forward(self, input):
u_dist = init_ops.rand(input.shape, dtype=input.dtype,
device=input.device)
gumbel = -((-(u_dist.log())).log())
gumbel = (input + gumbel) / self.tau
return F.softmax(gumbel, self.dim, self.inplace)
class Hardsigmoid(Module):
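Editor's note: for context on the rewritten GumbelSoftmax forward above, a minimal NumPy sketch of Gumbel-softmax sampling from logits; the temperature and shapes are illustrative, and the probs branch of the module is omitted.

# Illustrative sketch only; not part of this commit.
import numpy as np

def gumbel_softmax(logits, tau=1.0, axis=-1):
    """Sample soft one-hot vectors by perturbing logits with Gumbel noise."""
    u = np.random.uniform(1e-9, 1.0, size=logits.shape)
    gumbels = -np.log(-np.log(u))
    scores = (logits + gumbels) / tau
    scores -= scores.max(axis=axis, keepdims=True)  # stabilize the softmax
    e = np.exp(scores)
    return e / e.sum(axis=axis, keepdims=True)

print(gumbel_softmax(np.array([[1.0, 2.0, 3.0]]), tau=0.5).sum())  # ~1.0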
......@@ -307,23 +306,27 @@ class LogSoftmax(Module):
"""
def __init__(self, dim):
def __init__(self, dim, inplace=False):
"""Create a ``LogSoftmax`` module.
Parameters
----------
dim : int
The dimension to reduce.
inplace : bool, optional, default=False
Whether to do the operation in-place.
"""
super(LogSoftmax, self).__init__()
self.dim = dim
self.inplace = inplace
def extra_repr(self):
return 'dim={dim}'.format(dim=self.dim)
inplace_str = ', inplace' if self.inplace else ''
return 'dim={}{}'.format(self.dim, inplace_str)
def forward(self, input):
return F.log_softmax(input, self.dim)
return F.log_softmax(input, self.dim, self.inplace)
class MultiheadAttention(Module):
......
......@@ -788,6 +788,24 @@ def exp(self):
return math_ops.exp(self)
def exp_(self):
r"""Set to the exponential of elements.
.. math:: \text{self} = \exp(\text{self})
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.exp(...)`_
"""
return math_ops.exp(self, self)
def expand(self, *sizes):
"""Return a tensor with elements broadcast.
......@@ -1234,6 +1252,24 @@ def log(self):
return math_ops.log(self)
def log_(self):
r"""Set to the natural logarithm of elements.
.. math:: \text{self} = \log(\text{self})
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.log(...)`_
"""
return math_ops.log(self, self)
def logical_and(self, other):
r"""Compute the element-wise AND logical operation.
......@@ -2916,6 +2952,7 @@ Tensor.double = double
Tensor.double_ = double_
Tensor.eq = eq
Tensor.exp = exp
Tensor.exp_ = exp_
Tensor.expand = expand
Tensor.fill_ = fill_
Tensor.flatten = flatten
......@@ -2941,6 +2978,7 @@ Tensor.le = le
Tensor.long = long
Tensor.long_ = long_
Tensor.log = log
Tensor.log_ = log_
Tensor.logical_and = logical_and
Tensor.logical_not = logical_not
Tensor.logical_or = logical_or
......
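Editor's note: a short usage sketch of the new in-place tensor methods registered above; the import path is assumed to follow the rest of this change.

# Illustrative sketch only; not part of this commit.
from dragon.vm import torch

x = torch.tensor([0., 1., 2.])
x.exp_()           # in-place: x becomes exp(x)
x.log_()           # in-place: x becomes log(exp(x)), i.e. the original values
print(x.tolist())  # ~[0.0, 1.0, 2.0]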
......@@ -972,6 +972,22 @@ class Tensor(object):
"""
def exp_(self):
r"""Set to the exponential of elements.
.. math:: \text{self} = \exp(\text{self})
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.exp(...)`_
"""
def expand(self, *sizes):
"""Return a tensor with elements broadcast.
......@@ -1326,6 +1342,17 @@ class Tensor(object):
"""
return 'float' in self.dtype
def item(self):
"""Return the value as a python number.
Returns
-------
number
The value.
"""
return float(self) if self.is_floating_point() else int(self)
def le(self, other):
r"""Compute the element-wise less-equal comparison.
......@@ -1363,6 +1390,22 @@ class Tensor(object):
"""
def log_(self):
r"""Set to the natural logarithm of elements.
.. math:: \text{self} = \log(\text{self})
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.log(...)`_
"""
def logical_and(self, other):
r"""Compute the element-wise AND logical operation.
......@@ -2676,6 +2719,17 @@ class Tensor(object):
return self.type(dtype)
return self
def tolist(self):
"""Return the value as a python list.
Returns
-------
list
The value.
"""
return self.numpy().tolist()
def topk(self, k, dim=-1, largest=True, sorted=True):
"""Return the top-K largest or smallest elements.
......@@ -3089,7 +3143,7 @@ class Tensor(object):
return self.eq(other)
def __float__(self):
"""Return a float python scalar.
"""Return the value as a python number.
Returns
-------
......@@ -3194,7 +3248,7 @@ class Tensor(object):
return self.mul_(other)
def __int__(self):
"""Return an integer python scalar.
"""Return the value as a python number.
Returns
-------
......@@ -3202,7 +3256,7 @@ class Tensor(object):
The integer value.
"""
return int(self.__float__())
return int(self.numpy())
def __invert__(self):
"""Compute the element-wise NOT bitwise operation.
......