Commit 43a82e77 by Ting PAN

Add scatter-gather elements operator

Summary:
This commit adds the scatter and gather elements operators
to remap elements along a given dimension using indices.
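A quick sketch of the new Python entry points (values mirror the docstring examples added below):

```python
import dragon

x = dragon.constant([[1, 2], [3, 4]])
index = dragon.constant([[0, 0], [0, 1]])
# Gather: out[i, j] = x[index[i, j], j] for axis=0.
print(dragon.gather_elements([x, index], axis=0))  # [[1, 2], [1, 4]]
# Scatter: out[index[i, j], j] = updates[i, j] for axis=0.
updates = dragon.constant([[5, 6], [7, 8]])
print(dragon.scatter_elements([x, index, updates], axis=0))  # [[7, 6], [3, 8]]
```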
1 parent f431756f
Showing with 2464 additions and 322 deletions
......@@ -57,7 +57,7 @@ class ImageDecoder(object):
Returns
-------
nvidia.dali.ops.ImageDecoder
nvidia.dali.ops.decoders.Image
The operator.
"""
......@@ -118,7 +118,7 @@ class ImageDecoderRandomCrop(object):
Returns
-------
nvidia.dali.ops.ImageDecoderRandomCrop
nvidia.dali.ops.decoders.ImageRandomCrop
The operator.
"""
......
......@@ -368,7 +368,7 @@ class RandomResizedCrop(object):
num_attempts=10,
**kwargs
):
"""Create a ``ImageDecoderRandomCrop`` operator.
"""Create a ``RandomResizedCrop`` operator.
Parameters
----------
......@@ -389,7 +389,7 @@ class RandomResizedCrop(object):
Returns
-------
nvidia.dali.ops.ImageDecoderRandomCrop
nvidia.dali.ops.RandomResizedCrop
The operator.
"""
......@@ -463,6 +463,11 @@ class Resize(object):
min_filter : str, optional, default='TRIANGULAR'
The interpolation for downsampling.
Returns
-------
nvidia.dali.ops.Resize
The operator.
"""
if isinstance(interp_type, six.string_types):
interp_type = getattr(types, 'INTERP_' + interp_type.upper())
......
......@@ -43,7 +43,7 @@ class CoinFlip(object):
Returns
-------
nvidia.dali.ops.CoinFlip
nvidia.dali.ops.random.CoinFlip
The operator.
"""
......@@ -72,7 +72,7 @@ class Uniform(object):
Returns
-------
nvidia.dali.ops.Uniform
nvidia.dali.ops.random.Uniform
The operator.
"""
......
......@@ -213,7 +213,7 @@ class TFRecordReader(object):
Returns
-------
nvidia.dali.ops.TFRecordReader
nvidia.dali.ops.readers.TFRecord
The reader instance.
"""
......
......@@ -34,7 +34,7 @@ extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.napoleon',
'sphinxcontrib.katex',
# 'sphinx_seeta_theme.ext.viewcode',
'sphinx_seeta_theme.ext.viewcode',
]
napoleon_use_rtype = False
......
......@@ -80,7 +80,10 @@ dragon
: Compile a function and return an executable.
`gather(...) <dragon/gather.html>`_
: Gather the elements along the given axis using index.
: Gather elements along the given axis using index.
`gather_elements(...) <dragon/gather_elements.html>`_
: Gather elements along the given axis of index.
`get_num_threads(...) <dragon/get_num_threads.html>`_
: Return the number of threads for cpu parallelism.
......@@ -133,6 +136,12 @@ dragon
`reshape(...) <dragon/reshape.html>`_
: Change the dimensions of input.
`scatter_add(...) <dragon/scatter_add.html>`_
: Add elements along the given axis of index.
`scatter_elements(...) <dragon/scatter_elements.html>`_
: Update elements along the given axis of index.
`set_num_threads(...) <dragon/set_num_threads.html>`_
: Set the number of threads for cpu parallelism.
......@@ -207,6 +216,7 @@ dragon
dragon/flatten
dragon/function
dragon/gather
dragon/gather_elements
dragon/get_num_threads
dragon/get_workspace
dragon/graph_mode
......@@ -224,6 +234,8 @@ dragon
dragon/repeat
dragon/reset_workspace
dragon/reshape
dragon/scatter_add
dragon/scatter_elements
dragon/set_num_threads
dragon/shape
dragon/slice
......
......@@ -96,6 +96,10 @@ __and__
#######
.. automethod:: dragon.Tensor.__and__
__eq__
######
.. automethod:: dragon.Tensor.__eq__
__float__
#########
.. automethod:: dragon.Tensor.__float__
......@@ -160,6 +164,10 @@ __mul__
#######
.. automethod:: dragon.Tensor.__mul__
__ne__
######
.. automethod:: dragon.Tensor.__ne__
__neg__
#######
.. automethod:: dragon.Tensor.__neg__
......@@ -222,12 +230,14 @@ __xor__
.. _dragon.identity(...): identity.html
.. _dragon.math.add(...): math/add.html
.. _dragon.math.div(...): math/div.html
.. _dragon.math.equal(...): math/equal.html
.. _dragon.math.greater(...): math/greater.html
.. _dragon.math.greater_equal(...): math/greater_equal.html
.. _dragon.math.less(...): math/less.html
.. _dragon.math.less_equal(...): math/less_equal.html
.. _dragon.math.mul(...): math/mul.html
.. _dragon.math.negative(...): math/negative.html
.. _dragon.math.not_equal(...): math/not_equal.html
.. _dragon.math.sub(...): math/sub.html
.. _dragon.random.glorot_normal(...): random/glorot_normal.html
.. _dragon.random.glorot_uniform(...): random/glorot_uniform.html
......
gather_elements
===============
.. autofunction:: dragon.gather_elements
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
......@@ -53,7 +53,7 @@ Name Supported Reference
`Compress`_
`Concat`_ |v| :func:`dragon.concat`
`ConcatFromSequence`_
`Constant`_
`Constant`_ |v| :func:`dragon.fill`
`ConstantOfShape`_
`Conv`_ |v| :func:`dragon.nn.conv`
`ConvInteger`_
......@@ -77,7 +77,7 @@ Name Supported Reference
`Floor`_ |v| :func:`dragon.math.floor`
`GRU`_ |v| :func:`dragon.nn.GRU`
`Gather`_ |v| :func:`dragon.gather`
`GatherElements`_
`GatherElements`_ |v| :func:`dragon.gather_elements`
`GatherND`_
`Gemm`_ |v| :func:`dragon.math.gemm`
`GlobalAveragePool`_ |v| :func:`dragon.nn.pool`
......@@ -146,8 +146,8 @@ Name Supported Reference
`RoiAlign`_ |v| :func:`dragon.vision.roi_align`
`Round`_ |v| :func:`dragon.math.round`
`Scan`_
`Scatter`_
`ScatterElements`_
`Scatter`_ |v| :func:`dragon.scatter_elements`
`ScatterElements`_ |v| :func:`dragon.scatter_elements`
`ScatterND`_
`Selu`_ |v| :func:`dragon.nn.selu`
`SequenceAt`_
......
scatter_add
===========
.. autofunction:: dragon.scatter_add
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
scatter_elements
================
.. autofunction:: dragon.scatter_elements
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
......@@ -119,7 +119,7 @@ vm.torch
: Compute the largest integer not greater than input.
`from_numpy(...) <torch/from_numpy.html>`_
: Create a tensor from the given numpy array.
: Create a tensor converted from the given numpy array.
`full(...) <torch/full.html>`_
: Return a tensor filled with a scalar.
......@@ -127,6 +127,9 @@ vm.torch
`full_like(...) <torch/full_like.html>`_
: Return a tensor filled with a scalar, with the same size as input.
`gather(...) <torch/gather.html>`_
: Gather elements along the given dimension of index.
`ge(...) <torch/ge.html>`_
: Compute the element-wise greater-equal comparison.
......@@ -134,7 +137,7 @@ vm.torch
: Compute the element-wise greater comparison.
`index_select(...) <torch/index_select.html>`_
: Select the elements along the given dim using index.
: Select elements along the given dimension using index.
`isinf(...) <torch/isinf.html>`_
: Check if the elements of input are infinite.
......@@ -247,6 +250,12 @@ vm.torch
`rsqrt(...) <torch/rsqrt.html>`_
: Compute the reciprocal square root of input.
`scatter(...) <torch/scatter.html>`_
: Update elements along the given dimension of index.
`scatter_add(...) <torch/scatter_add.html>`_
: Add elements along the given dimension of index.
`sign(...) <torch/sign.html>`_
: Compute the sign indication of input.
......@@ -275,7 +284,7 @@ vm.torch
: Compute the sum value of elements along the given dimension.
`tensor(...) <torch/tensor.html>`_
: Create a tensor initializing the content from data.
: Create a tensor initialized from the given data.
`tile(...) <torch/tile.html>`_
: Repeat elements along each dimension of input.
......@@ -347,6 +356,7 @@ vm.torch
torch/from_numpy
torch/full
torch/full_like
torch/gather
torch/ge
torch/gt
torch/index_select
......@@ -388,6 +398,8 @@ vm.torch
torch/reshape
torch/round
torch/rsqrt
torch/scatter
torch/scatter_add
torch/set_grad_enabled
torch/sign
torch/sin
......
......@@ -249,6 +249,10 @@ floor\_
#######
.. automethod:: dragon.vm.torch.Tensor.floor_
gather
######
.. automethod:: dragon.vm.torch.Tensor.gather
ge
###
.. automethod:: dragon.vm.torch.Tensor.ge
......@@ -397,10 +401,6 @@ neg\_
#####
.. automethod:: dragon.vm.torch.Tensor.neg_
new_ones
########
.. automethod:: dragon.vm.torch.Tensor.new_ones
new_empty
#########
.. automethod:: dragon.vm.torch.Tensor.new_empty
......@@ -409,6 +409,14 @@ new_full
########
.. automethod:: dragon.vm.torch.Tensor.new_full
new_ones
########
.. automethod:: dragon.vm.torch.Tensor.new_ones
new_tensor
##########
.. automethod:: dragon.vm.torch.Tensor.new_tensor
new_zeros
#########
.. automethod:: dragon.vm.torch.Tensor.new_zeros
......@@ -481,6 +489,22 @@ rsqrt\_
#######
.. automethod:: dragon.vm.torch.Tensor.rsqrt_
scatter
#######
.. automethod:: dragon.vm.torch.Tensor.scatter
scatter\_
#########
.. automethod:: dragon.vm.torch.Tensor.scatter_
scatter_add
###########
.. automethod:: dragon.vm.torch.Tensor.scatter_add
scatter_add\_
#############
.. automethod:: dragon.vm.torch.Tensor.scatter_add_
sign
####
.. automethod:: dragon.vm.torch.Tensor.sign
......@@ -624,6 +648,7 @@ zero\_
.. _torch.flatten(...): flatten.html
.. _torch.floor(...): floor.html
.. _torch.full(...): full.html
.. _torch.gather(...): gather.html
.. _torch.ge(...): ge.html
.. _torch.gt(...): gt.html
.. _torch.isinf(...): isinf.html
......@@ -652,6 +677,8 @@ zero\_
.. _torch.reshape(...): reshape.html
.. _torch.round(...): round.html
.. _torch.rsqrt(...): rsqrt.html
.. _torch.scatter(...): scatter.html
.. _torch.scatter_add(...): scatter_add.html
.. _torch.sign(...): sign.html
.. _torch.sin(...): sin.html
.. _torch.sort(...): sort.html
......@@ -660,6 +687,7 @@ zero\_
.. _torch.squeeze(...): squeeze.html
.. _torch.sub(...): sub.html
.. _torch.sum(...): sum.html
.. _torch.tensor(...): tensor.html
.. _torch.topk(...): topk.html
.. _torch.transpose(...): transpose.html
.. _torch.tril(...): tril.html
......
gather
======
.. autofunction:: dragon.vm.torch.gather
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
scatter
=======
.. autofunction:: dragon.vm.torch.scatter
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
scatter_add
===========
.. autofunction:: dragon.vm.torch.scatter_add
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
......@@ -52,12 +52,12 @@ void _ChannelNormalize(
_ChannelNormalize(axis, num_dims, x_strides, y_dims, x, mean, std, y); \
}
DEFINE_KERNEL_LAUNCHER(int8_t, float16);
DEFINE_KERNEL_LAUNCHER(int8_t, float);
DEFINE_KERNEL_LAUNCHER(int8_t, double);
DEFINE_KERNEL_LAUNCHER(uint8_t, float16);
DEFINE_KERNEL_LAUNCHER(uint8_t, float);
DEFINE_KERNEL_LAUNCHER(uint8_t, double);
DEFINE_KERNEL_LAUNCHER(int8_t, float16);
DEFINE_KERNEL_LAUNCHER(int8_t, float);
DEFINE_KERNEL_LAUNCHER(int8_t, double);
DEFINE_KERNEL_LAUNCHER(int, float16);
DEFINE_KERNEL_LAUNCHER(int, float);
DEFINE_KERNEL_LAUNCHER(int, double);
......
......@@ -66,12 +66,12 @@ __global__ void _ChannelNormalize(
N, axis, num_dims, X_strides, Y_dims, x, mean, std, y); \
}
DEFINE_KERNEL_LAUNCHER(int8_t, float16);
DEFINE_KERNEL_LAUNCHER(int8_t, float);
DEFINE_KERNEL_LAUNCHER(int8_t, double);
DEFINE_KERNEL_LAUNCHER(uint8_t, float16);
DEFINE_KERNEL_LAUNCHER(uint8_t, float);
DEFINE_KERNEL_LAUNCHER(uint8_t, double);
DEFINE_KERNEL_LAUNCHER(int8_t, float16);
DEFINE_KERNEL_LAUNCHER(int8_t, float);
DEFINE_KERNEL_LAUNCHER(int8_t, double);
DEFINE_KERNEL_LAUNCHER(int, float16);
DEFINE_KERNEL_LAUNCHER(int, float);
DEFINE_KERNEL_LAUNCHER(int, double);
......
......@@ -51,6 +51,28 @@ void _GatherGrad(
}
}
template <typename T>
void _GatherElements(
const int axis,
const int num_dims,
const int64_t* x_strides,
const int64_t* y_dims,
const int64_t* index,
const T* x,
T* y) {
const auto N =
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t dim_index(num_dims, 0);
for (int yi = 0; yi < N; ++yi) {
int64_t xi = 0;
for (int d = num_dims - 1; d >= 0; --d) {
xi += (d == axis ? index[yi] : dim_index[d]) * x_strides[d];
}
y[yi] = x[xi];
math::utils::IncreaseIndexInDims(num_dims, y_dims, dim_index.data());
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
......@@ -82,6 +104,30 @@ DEFINE_KERNEL_LAUNCHER(GatherGrad, float, float); // GatherGrad
DEFINE_KERNEL_LAUNCHER(GatherGrad, double, float); // GatherGrad
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void GatherElements<T, CPUContext>( \
const int axis, \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const int64_t* index, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_GatherElements(axis, num_dims, x_strides, y_dims, index, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
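Both the CPU kernel above and the CUDA kernel below compute the source offset by walking output coordinates and substituting the index value on `axis`. A rough NumPy sketch of the same addressing rule (`gather_elements_ref` is an illustrative helper, not part of the codebase):

```python
import numpy as np

def gather_elements_ref(x, index, axis):
    # Illustrative reference: out[c] = x[c with c[axis] replaced by
    # index[c]] for every coordinate c of `index`.
    out = np.empty(index.shape, dtype=x.dtype)
    for c in np.ndindex(*index.shape):
        src = list(c)
        src[axis] = index[c]
        out[c] = x[tuple(src)]
    return out
```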
......@@ -47,6 +47,27 @@ __global__ void _GatherGrad(
}
}
template <typename T, int D>
__global__ void _GatherElements(
const int N,
const int axis,
const int num_dims,
const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_dims,
const int64_t* index,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, N) {
int xi = 0, tmp = yi;
for (int d = num_dims - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], tmp, &tmp, &r);
xi += (d == axis ? index[yi] : r) * X_strides.data[d];
}
y[yi] = x[xi];
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
......@@ -86,6 +107,40 @@ DEFINE_KERNEL_LAUNCHER(GatherGrad, float, float); // GatherGrad
DEFINE_KERNEL_LAUNCHER(GatherGrad, double, float); // GatherGrad
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CUDAContext>( \
const int axis, \
const int num_dims, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const int64_t* index, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> Y_dims; \
const auto N = std::accumulate( \
y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \
X_strides.data[i] = x_strides[i]; \
Y_dims.data[i] = y_dims[i]; \
} \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, axis, num_dims, X_strides, Y_dims, index, x, y); \
}
DEFINE_KERNEL_LAUNCHER(GatherElements, bool);
DEFINE_KERNEL_LAUNCHER(GatherElements, uint8_t);
DEFINE_KERNEL_LAUNCHER(GatherElements, int8_t);
DEFINE_KERNEL_LAUNCHER(GatherElements, int);
DEFINE_KERNEL_LAUNCHER(GatherElements, int64_t);
DEFINE_KERNEL_LAUNCHER(GatherElements, float16);
DEFINE_KERNEL_LAUNCHER(GatherElements, float);
DEFINE_KERNEL_LAUNCHER(GatherElements, double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
......
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
void _ScatterElements(
const int axis,
const int num_dims,
const T value,
const int64_t* dims,
const int64_t* y_strides,
const int64_t* index,
T* y) {
const auto N =
std::accumulate(dims, dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t dim_index(num_dims, 0);
for (int i = 0; i < N; ++i) {
int64_t yi = 0;
for (int d = num_dims - 1; d >= 0; --d) {
yi += (d == axis ? index[i] : dim_index[d]) * y_strides[d];
}
y[yi] = value;
math::utils::IncreaseIndexInDims(num_dims, dims, dim_index.data());
}
}
template <typename T>
void _ScatterElements(
const int axis,
const int num_dims,
const int64_t* dims,
const int64_t* x_strides,
const int64_t* y_strides,
const int64_t* index,
const T* x,
T* y) {
const auto N =
std::accumulate(dims, dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t dim_index(num_dims, 0);
for (int i = 0; i < N; ++i) {
int64_t xi = 0, yi = 0;
for (int d = num_dims - 1; d >= 0; --d) {
xi += dim_index[d] * x_strides[d];
yi += (d == axis ? index[i] : dim_index[d]) * y_strides[d];
}
y[yi] = x[xi];
math::utils::IncreaseIndexInDims(num_dims, dims, dim_index.data());
}
}
template <typename T, typename AccT>
void _ScatterAdd(
const int axis,
const int num_dims,
const int64_t* dims,
const int64_t* x_strides,
const int64_t* y_strides,
const int64_t* index,
const T* x,
AccT* y) {
const auto N =
std::accumulate(dims, dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t dim_index(num_dims, 0);
for (int i = 0; i < N; ++i) {
int64_t xi = 0, yi = 0;
for (int d = num_dims - 1; d >= 0; --d) {
xi += dim_index[d] * x_strides[d];
yi += (d == axis ? index[i] : dim_index[d]) * y_strides[d];
}
y[yi] += convert::To<AccT>(x[xi]);
math::utils::IncreaseIndexInDims(num_dims, dims, dim_index.data());
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int axis, \
const int num_dims, \
const T value, \
const int64_t* dims, \
const int64_t* y_strides, \
const int64_t* index, \
T* y, \
CPUContext* ctx) { \
_##name(axis, num_dims, value, dims, y_strides, index, y); \
}
DEFINE_KERNEL_LAUNCHER(ScatterElements, bool);
DEFINE_KERNEL_LAUNCHER(ScatterElements, uint8_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int8_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int64_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, float16);
DEFINE_KERNEL_LAUNCHER(ScatterElements, float);
DEFINE_KERNEL_LAUNCHER(ScatterElements, double);
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int axis, \
const int num_dims, \
const int64_t* dims, \
const int64_t* x_strides, \
const int64_t* y_strides, \
const int64_t* index, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_##name(axis, num_dims, dims, x_strides, y_strides, index, x, y); \
}
DEFINE_KERNEL_LAUNCHER(ScatterElements, bool);
DEFINE_KERNEL_LAUNCHER(ScatterElements, uint8_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int8_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int64_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, float16);
DEFINE_KERNEL_LAUNCHER(ScatterElements, float);
DEFINE_KERNEL_LAUNCHER(ScatterElements, double);
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T, AccT) \
template <> \
void name<T, AccT, CPUContext>( \
const int axis, \
const int num_dims, \
const int64_t* dims, \
const int64_t* x_strides, \
const int64_t* y_strides, \
const int64_t* index, \
const T* x, \
AccT* y, \
CPUContext* ctx) { \
_##name(axis, num_dims, dims, x_strides, y_strides, index, x, y); \
}
DEFINE_KERNEL_LAUNCHER(ScatterAdd, uint8_t, uint8_t);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, int8_t, int8_t);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, int, int);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, int64_t, int64_t);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, float16, float);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, float, float);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, double, float);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
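`_ScatterAdd` accumulates rather than overwrites, so colliding indices sum their contributions; that is why the CUDA variant below uses `math::utils::AtomicAdd` where the serial CPU loop can simply `+=`. A hedged NumPy reference (`scatter_add_ref` is illustrative only):

```python
import numpy as np

def scatter_add_ref(y, index, updates, axis):
    # Illustrative reference: y[c with c[axis] replaced by index[c]]
    # += updates[c]; duplicate destinations accumulate.
    for c in np.ndindex(*index.shape):
        dst = list(c)
        dst[axis] = index[c]
        y[tuple(dst)] += updates[c]
    return y
```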
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T, int D>
__global__ void _ScatterElements(
const int N,
const int axis,
const int num_dims,
const T value,
const SimpleArray<int, D> X_dims,
const SimpleArray<int, D> Y_strides,
const int64_t* index,
T* y) {
CUDA_1D_KERNEL_LOOP(i, N) {
int yi = 0, tmp = i;
for (int d = num_dims - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(X_dims.data[d], tmp, &tmp, &r);
yi += (d == axis ? index[i] : r) * Y_strides.data[d];
}
y[yi] = value;
}
}
template <typename T, int D>
__global__ void _ScatterElements(
const int N,
const int axis,
const int num_dims,
const SimpleArray<int, D> X_dims,
const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_strides,
const int64_t* index,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(i, N) {
int xi = 0, yi = 0, tmp = i;
for (int d = num_dims - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(X_dims.data[d], tmp, &tmp, &r);
xi += r * X_strides.data[d];
yi += (d == axis ? index[i] : r) * Y_strides.data[d];
}
y[yi] = x[xi];
}
}
template <typename T, typename AccT, int D>
__global__ void _ScatterAdd(
const int N,
const int axis,
const int num_dims,
const SimpleArray<int, D> X_dims,
const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_strides,
const int64_t* index,
const T* x,
AccT* y) {
CUDA_1D_KERNEL_LOOP(i, N) {
int xi = 0, yi = 0, tmp = i;
for (int d = num_dims - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(X_dims.data[d], tmp, &tmp, &r);
xi += r * X_strides.data[d];
yi += (d == axis ? index[i] : r) * Y_strides.data[d];
}
math::utils::AtomicAdd(y + yi, convert::To<AccT>(x[xi]));
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CUDAContext>( \
const int axis, \
const int num_dims, \
const T value, \
const int64_t* dims, \
const int64_t* y_strides, \
const int64_t* index, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_dims; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> Y_strides; \
const auto N = \
std::accumulate(dims, dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \
X_dims.data[i] = dims[i]; \
Y_strides.data[i] = y_strides[i]; \
} \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, axis, num_dims, value, X_dims, Y_strides, index, y); \
}
DEFINE_KERNEL_LAUNCHER(ScatterElements, bool);
DEFINE_KERNEL_LAUNCHER(ScatterElements, uint8_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int8_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int64_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, float16);
DEFINE_KERNEL_LAUNCHER(ScatterElements, float);
DEFINE_KERNEL_LAUNCHER(ScatterElements, double);
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CUDAContext>( \
const int axis, \
const int num_dims, \
const int64_t* dims, \
const int64_t* x_strides, \
const int64_t* y_strides, \
const int64_t* index, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_dims; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> Y_strides; \
const auto N = \
std::accumulate(dims, dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \
X_dims.data[i] = dims[i]; \
X_strides.data[i] = x_strides[i]; \
Y_strides.data[i] = y_strides[i]; \
} \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, axis, num_dims, X_dims, X_strides, Y_strides, index, x, y); \
}
DEFINE_KERNEL_LAUNCHER(ScatterElements, bool);
DEFINE_KERNEL_LAUNCHER(ScatterElements, uint8_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int8_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int);
DEFINE_KERNEL_LAUNCHER(ScatterElements, int64_t);
DEFINE_KERNEL_LAUNCHER(ScatterElements, float16);
DEFINE_KERNEL_LAUNCHER(ScatterElements, float);
DEFINE_KERNEL_LAUNCHER(ScatterElements, double);
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T, AccT) \
template <> \
void name<T, AccT, CUDAContext>( \
const int axis, \
const int num_dims, \
const int64_t* dims, \
const int64_t* x_strides, \
const int64_t* y_strides, \
const int64_t* index, \
const T* x, \
AccT* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_dims; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> Y_strides; \
const auto N = \
std::accumulate(dims, dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \
X_dims.data[i] = dims[i]; \
X_strides.data[i] = x_strides[i]; \
Y_strides.data[i] = y_strides[i]; \
} \
_##name<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, \
axis, \
num_dims, \
X_dims, \
X_strides, \
Y_strides, \
index, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
y); \
}
DEFINE_KERNEL_LAUNCHER(ScatterAdd, uint8_t, uint8_t);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, int8_t, int8_t);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, int, int);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, int64_t, int64_t);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, float16, float);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, float, float);
DEFINE_KERNEL_LAUNCHER(ScatterAdd, double, float);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
......@@ -9,21 +9,19 @@ namespace {
template <typename T>
void _L1Normalize(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const T normalizer,
const T epsilon,
const T* x,
T* y) {
const auto dim = reduce_dim * inner_dim;
for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < inner_dim; ++j) {
auto offset = i * dim + j;
auto X = ConstEigenStridedVectorMap<T>(
x + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
EigenStridedVectorMap<T>(
y + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) =
const auto CxS = C * S;
for (int i = 0; i < N; ++i) {
for (int j = 0; j < S; ++j) {
auto offset = i * CxS + j;
ConstEigenStridedVectorMap<T> X(x + offset, 1, C, EigenInnerStride(S));
EigenStridedVectorMap<T>(y + offset, 1, C, EigenInnerStride(S)) =
X / std::max(X.template lpNorm<1>() / normalizer, epsilon);
}
}
......@@ -31,21 +29,19 @@ void _L1Normalize(
template <typename T>
void _L2Normalize(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const T normalizer,
const T epsilon,
const T* x,
T* y) {
const auto dim = reduce_dim * inner_dim;
for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < inner_dim; ++j) {
auto offset = i * dim + j;
auto X = ConstEigenStridedVectorMap<T>(
x + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
EigenStridedVectorMap<T>(
y + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) =
const auto CxS = C * S;
for (int i = 0; i < N; ++i) {
for (int j = 0; j < S; ++j) {
auto offset = i * CxS + j;
ConstEigenStridedVectorMap<T> X(x + offset, 1, C, EigenInnerStride(S));
EigenStridedVectorMap<T>(y + offset, 1, C, EigenInnerStride(S)) =
X / std::max(std::sqrt(X.squaredNorm() / normalizer), epsilon);
}
}
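After the renaming, `N` is the outer dimension, `C` the reduced dimension, and `S` the inner stride, so each normalized vector consists of the `C` elements spaced `S` apart. For an array viewed as (N, C, S), the L2 branch reduces to the following NumPy sketch (illustrative only):

```python
import numpy as np

def l2_normalize_ref(x, normalizer, epsilon):
    # x has shape (N, C, S); normalize over C, matching
    # X / std::max(std::sqrt(X.squaredNorm() / normalizer), epsilon).
    norm = np.sqrt(np.square(x).sum(axis=1, keepdims=True) / normalizer)
    return x / np.maximum(norm, epsilon)
```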
......@@ -53,26 +49,23 @@ void _L2Normalize(
template <typename T>
void _L1NormalizeGrad(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const T normalizer,
const T epsilon,
const T* dy,
const T* x,
T* dx) {
const auto dim = reduce_dim * inner_dim;
for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < inner_dim; ++j) {
auto offset = i * dim + j;
auto dY = ConstEigenStridedVectorMap<T>(
dy + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
auto X = ConstEigenStridedVectorMap<T>(
x + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
const auto CxS = C * S;
for (int i = 0; i < N; ++i) {
for (int j = 0; j < S; ++j) {
auto offset = i * CxS + j;
ConstEigenStridedVectorMap<T> dY(dy + offset, 1, C, EigenInnerStride(S));
ConstEigenStridedVectorMap<T> X(x + offset, 1, C, EigenInnerStride(S));
auto norm = std::max(X.template lpNorm<1>() / normalizer, epsilon);
auto norm2 = std::pow(norm, T(2));
EigenStridedVectorMap<T>(
dx + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) =
EigenStridedVectorMap<T>(dx + offset, 1, C, EigenInnerStride(S)) =
(dY / norm) -
(X.array().sign().matrix() / norm2) * dY.dot(X) / normalizer;
}
......@@ -81,26 +74,23 @@ void _L1NormalizeGrad(
template <typename T>
void _L2NormalizeGrad(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const T normalizer,
const T epsilon,
const T* dy,
const T* x,
T* dx) {
const auto dim = reduce_dim * inner_dim;
for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < inner_dim; ++j) {
auto offset = i * dim + j;
auto dY = ConstEigenStridedVectorMap<T>(
dy + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
auto X = ConstEigenStridedVectorMap<T>(
x + offset, 1, reduce_dim, EigenInnerStride(inner_dim));
const auto CxS = C * S;
for (int i = 0; i < N; ++i) {
for (int j = 0; j < S; ++j) {
auto offset = i * CxS + j;
ConstEigenStridedVectorMap<T> dY(dy + offset, 1, C, EigenInnerStride(S));
ConstEigenStridedVectorMap<T> X(x + offset, 1, C, EigenInnerStride(S));
auto norm = std::max(std::sqrt(X.squaredNorm() / normalizer), epsilon);
auto norm3 = std::pow(norm, T(3));
EigenStridedVectorMap<T>(
dx + offset, 1, reduce_dim, EigenInnerStride(inner_dim)) =
EigenStridedVectorMap<T>(dx + offset, 1, C, EigenInnerStride(S)) =
(dY / norm) - ((X / norm3) * dY.dot(X) / normalizer);
}
}
......@@ -112,9 +102,9 @@ void _L2NormalizeGrad(
template <>
void L1Normalize<float16, CPUContext>(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const float normalizer,
const float epsilon,
const float16* x,
......@@ -125,9 +115,9 @@ void L1Normalize<float16, CPUContext>(
template <>
void L2Normalize<float16, CPUContext>(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const float normalizer,
const float epsilon,
const float16* x,
......@@ -138,9 +128,9 @@ void L2Normalize<float16, CPUContext>(
template <>
void L1NormalizeGrad<float16, CPUContext>(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const float normalizer,
const float epsilon,
const float16* dy,
......@@ -152,9 +142,9 @@ void L1NormalizeGrad<float16, CPUContext>(
template <>
void L2NormalizeGrad<float16, CPUContext>(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const float normalizer,
const float epsilon,
const float16* dy,
......@@ -167,30 +157,30 @@ void L2NormalizeGrad<float16, CPUContext>(
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int outer_dim, \
const int inner_dim, \
const int reduce_dim, \
const int N, \
const int S, \
const int C, \
const float normalizer, \
const float eps, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_##name<T>(outer_dim, inner_dim, reduce_dim, normalizer, eps, x, y); \
_##name<T>(N, S, C, normalizer, eps, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int outer_dim, \
const int inner_dim, \
const int reduce_dim, \
const int N, \
const int S, \
const int C, \
const float normalizer, \
const float eps, \
const T* dy, \
const T* x, \
T* dx, \
CPUContext* ctx) { \
_##name<T>(outer_dim, inner_dim, reduce_dim, normalizer, eps, dy, x, dx); \
_##name<T>(N, S, C, normalizer, eps, dy, x, dx); \
}
DEFINE_KERNEL_LAUNCHER(L1Normalize, float);
......
......@@ -13,67 +13,67 @@ namespace {
template <typename T, typename AccT>
__global__ void _L1Normalize(
const int kBlocks,
const int inner_dim,
const int reduce_dim,
const int NxS,
const int S,
const int C,
const AccT normalizer,
const AccT epsilon,
const T* x,
T* y) {
__shared__ AccT norm;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, kBlocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
CUDA_2D_KERNEL_LOOP1(i, NxS) {
auto offset = i / S * C * S + i % S;
AccT sum = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
sum += abs(convert::To<AccT>(x[offset + j * inner_dim]));
CUDA_2D_KERNEL_LOOP2(j, C) {
sum += abs(convert::To<AccT>(x[offset + j * S]));
}
sum = BlockReduce<AccT>(storage).Sum(sum);
if (threadIdx.x == 0) {
norm = max(sum / normalizer, epsilon);
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
y[idx] = convert::To<T>(convert::To<AccT>(x[idx]) / norm);
CUDA_2D_KERNEL_LOOP2(j, C) {
auto index = offset + j * S;
y[index] = convert::To<T>(convert::To<AccT>(x[index]) / norm);
}
}
}
template <typename T, typename AccT>
__global__ void _L2Normalize(
const int kBlocks,
const int inner_dim,
const int reduce_dim,
const int NxS,
const int S,
const int C,
const AccT normalizer,
const AccT epsilon,
const T* x,
T* y) {
__shared__ AccT norm;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, kBlocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
CUDA_2D_KERNEL_LOOP1(i, NxS) {
auto offset = i / S * C * S + i % S;
AccT sum = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
sum += math::utils::Square(convert::To<AccT>(x[offset + j * inner_dim]));
CUDA_2D_KERNEL_LOOP2(j, C) {
sum += math::utils::Square(convert::To<AccT>(x[offset + j * S]));
}
sum = BlockReduce<AccT>(storage).Sum(sum);
if (threadIdx.x == 0) {
norm = max(sqrt(sum / normalizer), epsilon);
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
y[idx] = convert::To<T>(convert::To<AccT>(x[idx]) / norm);
CUDA_2D_KERNEL_LOOP2(j, C) {
auto index = offset + j * S;
y[index] = convert::To<T>(convert::To<AccT>(x[index]) / norm);
}
}
}
template <typename T, typename AccT>
__global__ void _L1NormalizeGrad(
const int kBlocks,
const int inner_dim,
const int reduce_dim,
const int NxS,
const int S,
const int C,
const AccT normalizer,
const AccT epsilon,
const T* dy,
......@@ -81,13 +81,13 @@ __global__ void _L1NormalizeGrad(
T* dx) {
__shared__ AccT norm, norm2, sum;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, kBlocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
CUDA_2D_KERNEL_LOOP1(i, NxS) {
auto offset = i / S * C * S + i % S;
AccT val1 = AccT(0), val2 = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
val1 += abs(convert::To<AccT>(x[idx]));
val2 += convert::To<AccT>(dy[idx]) * convert::To<AccT>(x[idx]);
CUDA_2D_KERNEL_LOOP2(j, C) {
auto index = offset + j * S;
val1 += abs(convert::To<AccT>(x[index]));
val2 += convert::To<AccT>(dy[index]) * convert::To<AccT>(x[index]);
}
val1 = BlockReduce<AccT>(storage).Sum(val1);
val2 = BlockReduce<AccT>(storage).Sum(val2);
......@@ -97,20 +97,20 @@ __global__ void _L1NormalizeGrad(
sum = val2 / normalizer;
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
dx[idx] = convert::To<T>(
(convert::To<AccT>(dy[idx]) / norm) -
((math::utils::Sign(convert::To<AccT>(x[idx])) / norm2) * sum));
CUDA_2D_KERNEL_LOOP2(j, C) {
auto index = offset + j * S;
dx[index] = convert::To<T>(
(convert::To<AccT>(dy[index]) / norm) -
((math::utils::Sign(convert::To<AccT>(x[index])) / norm2) * sum));
}
}
}
template <typename T, typename AccT>
__global__ void _L2NormalizeGrad(
const int kBlocks,
const int inner_dim,
const int reduce_dim,
const int NxS,
const int S,
const int C,
const AccT normalizer,
const AccT epsilon,
const T* dy,
......@@ -118,13 +118,13 @@ __global__ void _L2NormalizeGrad(
T* dx) {
__shared__ AccT norm, norm3, sum;
__shared__ typename BlockReduce<AccT>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, kBlocks) {
auto offset = (i / inner_dim) * reduce_dim * inner_dim + (i % inner_dim);
CUDA_2D_KERNEL_LOOP1(i, NxS) {
auto offset = i / S * C * S + i % S;
AccT val1 = AccT(0), val2 = AccT(0);
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
val1 += math::utils::Square(convert::To<AccT>(x[idx]));
val2 += convert::To<AccT>(dy[idx]) * convert::To<AccT>(x[idx]);
CUDA_2D_KERNEL_LOOP2(j, C) {
auto index = offset + j * S;
val1 += math::utils::Square(convert::To<AccT>(x[index]));
val2 += convert::To<AccT>(dy[index]) * convert::To<AccT>(x[index]);
}
val1 = BlockReduce<AccT>(storage).Sum(val1);
val2 = BlockReduce<AccT>(storage).Sum(val2);
......@@ -134,11 +134,11 @@ __global__ void _L2NormalizeGrad(
sum = val2 / normalizer;
}
__syncthreads();
CUDA_2D_KERNEL_LOOP2(j, reduce_dim) {
auto idx = offset + j * inner_dim;
dx[idx] = convert::To<T>(
(convert::To<AccT>(dy[idx]) / norm) -
((convert::To<AccT>(x[idx]) / norm3) * sum));
CUDA_2D_KERNEL_LOOP2(j, C) {
auto index = offset + j * S;
dx[index] = convert::To<T>(
(convert::To<AccT>(dy[index]) / norm) -
((convert::To<AccT>(x[index]) / norm3) * sum));
}
}
}
......@@ -150,20 +150,20 @@ __global__ void _L2NormalizeGrad(
#define DEFINE_KERNEL_LAUNCHER(name, T, AccT) \
template <> \
void name<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int reduce_dim, \
const int N, \
const int S, \
const int C, \
const float normalizer, \
const float epsilon, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
const auto kBlocks = outer_dim * inner_dim; \
const auto NxS = N * S; \
_##name<math::ScalarType<T>::type, AccT> \
<<<CUDA_2D_BLOCKS(kBlocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
kBlocks, \
inner_dim, \
reduce_dim, \
<<<CUDA_2D_BLOCKS(NxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxS, \
S, \
C, \
AccT(normalizer), \
AccT(epsilon), \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
......@@ -173,21 +173,21 @@ __global__ void _L2NormalizeGrad(
#define DEFINE_GRAD_KERNEL_LAUNCHER(name, T, AccT) \
template <> \
void name<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int reduce_dim, \
const int N, \
const int S, \
const int C, \
const float normalizer, \
const float epsilon, \
const T* dy, \
const T* x, \
T* dx, \
CUDAContext* ctx) { \
const auto kBlocks = outer_dim * inner_dim; \
const auto NxS = N * S; \
_##name<math::ScalarType<T>::type, AccT> \
<<<CUDA_2D_BLOCKS(kBlocks), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
kBlocks, \
inner_dim, \
reduce_dim, \
<<<CUDA_2D_BLOCKS(NxS), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
NxS, \
S, \
C, \
AccT(normalizer), \
AccT(epsilon), \
reinterpret_cast<const math::ScalarType<T>::type*>(dy), \
......
#include "dragon/core/workspace.h"
#include "dragon/operators/array/gather_ops.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void GatherElementsOp<Context>::DoRunWithType() {
SET_INPUT_SPEC(0);
auto &X = Input(0), &X_index = Input(1), *Y = Output(0);
GET_OP_AXIS_ARG(axis, X.ndim(), 0);
CHECK_GT(X_index.count(), 0) << "\nLength of index must be > 0.";
CHECK_EQ(X.ndim(), X_index.ndim())
<< "\nMismatched number of dimensions between input and index.";
for (int i = 0; i < X.ndim(); ++i) {
if (i != axis) CHECK_EQ(X_index.dim(i), X.dim(i));
}
kernels::GatherElements(
axis,
X.ndim(),
X.strides().data(),
X_index.dims().data(),
X_index.template data<int64_t, Context>(),
X.template data<T, Context>(),
Y->ReshapeLike(X_index)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
template <typename T>
void GatherElementsGradientOp<Context>::DoRunWithType() {
auto &X_index = Input(0), &dY = Input(1);
auto &X_ref = INPUT_SPEC(0), *dX = Output(0);
GET_OP_AXIS_ARG(axis, X_ref.ndim(), 0);
auto* dx = dX->ReshapeLike(X_ref)->template mutable_data<T, Context>();
auto* dx_acc = (TypeMeta::Id<T>() == TypeMeta::Id<float>())
? (float*)nullptr
: ctx()->workspace()->template data<float, Context>({dX->count()})[0];
// Zero-initialize the gradient
math::Set(
dX->count(),
0.f,
dx_acc != nullptr ? dx_acc : reinterpret_cast<float*>(dx),
ctx());
// Scatter and accumulate to dX
kernels::ScatterAdd(
axis,
X_ref.ndim(),
X_index.dims().data(),
X_index.strides().data(),
X_ref.strides().data(),
X_index.template data<int64_t, Context>(),
dY.template data<T, Context>(),
dx_acc != nullptr ? dx_acc : reinterpret_cast<float*>(dx),
ctx());
// Convert to dX if necessary
if (dx_acc != nullptr) {
math::Cast(dX->count(), dx_acc, dx, ctx());
}
}
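The gradient of a gather is thus a scatter-add: each output element pulled from `X` pushes its incoming gradient back to the position it was read from, and positions gathered more than once accumulate, which is why `kernels::ScatterAdd` is reused here. A NumPy sketch of the rule (illustrative helper; the float32 accumulation path is elided):

```python
import numpy as np

def gather_elements_grad_ref(dy, index, x_shape, axis):
    # dX starts at zero; duplicated indices sum their contributions.
    dx = np.zeros(x_shape, dtype=dy.dtype)
    for c in np.ndindex(*index.shape):
        src = list(c)
        src[axis] = index[c]
        dx[tuple(src)] += dy[c]
    return dx
```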
DEPLOY_CPU_OPERATOR(GatherElements);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(GatherElements);
#endif
DEPLOY_CPU_OPERATOR(GatherElementsGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(GatherElementsGradient);
#endif
OPERATOR_SCHEMA(GatherElements)
/* X, X_index */
.NumInputs(2)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(GatherElementsGradient)
/* X_index, dY */
.NumInputs(2)
/* dX */
.NumOutputs(1);
namespace {
class GradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
void CreateGradientDefs() override {
AddGradientDef(
def().type() + "Gradient",
"",
vector<string>({I(1), GO(0)}),
vector<string>({GI(0)}));
}
};
} // namespace
REGISTER_GRADIENT(GatherElements, GradientMaker);
} // namespace dragon
......@@ -14,8 +14,6 @@ void GatherOp<Context>::DoRunWithType() {
SET_INPUT_SPEC(0);
CHECK_GT(X_index.count(), 0) << "\nLength of index must be > 0.";
CHECK(X_index.template IsType<int64_t>()) << "\nExpected int64 index.";
vec64_t X_dims(X.dims());
vec64_t Y_dims(X_dims.begin(), X_dims.begin() + axis);
Y_dims.insert(Y_dims.end(), X_index.dims().begin(), X_index.dims().end());
......
......@@ -41,6 +41,34 @@ class GatherGradientOp final : public Operator<Context> {
void DoRunWithType();
};
template <class Context>
class GatherElementsOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(GatherElementsOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
};
template <class Context>
class GatherElementsGradientOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(GatherElementsGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(1));
}
template <typename T>
void DoRunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_GATHER_OPS_H_
#include "dragon/core/workspace.h"
#include "dragon/operators/array/scatter_ops.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void ScatterAddOp<Context>::DoRunWithType() {
SET_INPUT_SPEC(2);
auto &X = Input(0), *Y = Output(0);
auto &X_index = Input(1), &X_value = Input(2);
GET_OP_AXIS_ARG(axis, X.ndim(), 0);
CHECK_GT(X_index.count(), 0) << "\nLength of index must be > 0.";
CHECK_EQ(X.ndim(), X_index.ndim())
<< "\nMismatched number of dimensions between input and index.";
CHECK_EQ(X_index.ndim(), X_value.ndim())
<< "\nMismatched number of dimensions between index and value.";
for (int i = 0; i < X.ndim(); ++i) {
CHECK_LE(X_index.dim(i), X_value.dim(i));
if (i != axis) CHECK_LE(X_index.dim(i), X.dim(i));
}
// Copy the input data.
Y->ReshapeLike(X)->CopyFrom(X, ctx());
// Add the new data.
kernels::ScatterAdd(
axis,
X.ndim(),
X_index.dims().data(),
X_value.strides().data(),
X.strides().data(),
X_index.template data<int64_t, Context>(),
X_value.template data<T, Context>(),
Y->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
template <typename T>
void ScatterAddOp<Context>::DoRunWithTypeAndCast() {
SET_INPUT_SPEC(2);
auto &X = Input(0), *Y = Output(0);
auto &X_index = Input(1), &X_value = Input(2);
GET_OP_AXIS_ARG(axis, X.ndim(), 0);
CHECK_GT(X_index.count(), 0) << "\nLength of index must be > 0.";
CHECK_EQ(X.ndim(), X_index.ndim())
<< "\nMismatched number of dimensions between input and index.";
CHECK_EQ(X_index.ndim(), X_value.ndim())
<< "\nMismatched number of dimensions between index and value.";
for (int i = 0; i < X.ndim(); ++i) {
CHECK_LE(X_index.dim(i), X_value.dim(i));
if (i != axis) CHECK_LE(X_index.dim(i), X.dim(i));
}
// Copy the input data.
auto* y = ctx()->workspace()->template data<float, Context>({X.count()})[0];
math::Cast(X.count(), X.template data<T, Context>(), y, ctx());
// Add the new data.
kernels::ScatterAdd(
axis,
X.ndim(),
X_index.dims().data(),
X_value.strides().data(),
X.strides().data(),
X_index.template data<int64_t, Context>(),
X_value.template data<T, Context>(),
y,
ctx());
// Convert to Y.
math::Cast(
X.count(),
y,
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void ScatterAddOp<Context>::RunOnDevice() {
auto& X = Input(0);
if (X.template IsType<float16>()) {
DoRunWithTypeAndCast<float16>();
} else if (X.template IsType<double>()) {
DoRunWithTypeAndCast<double>();
} else {
using Types = dtypes::TypesBase<uint8_t, int8_t, int, int64_t, float>;
DispatchHelper<Types>::Call(this, X);
}
}
template <class Context>
template <typename T>
void ScatterAddGradientOp<Context>::DoRunWithType() {
auto &X_index = Input(0), &dY = Input(1);
auto *dX = Output(0), *dX_value = Output(1);
GET_OP_AXIS_ARG(axis, dY.ndim(), 0);
if (dX_value->has_name()) {
auto& X_value_ref = INPUT_SPEC(2);
for (int i = 0; i < X_index.ndim(); ++i) {
CHECK_EQ(X_index.dim(i), X_value_ref.dim(i));
if (i != axis) CHECK_EQ(X_index.dim(i), dY.dim(i));
}
kernels::GatherElements(
axis,
dY.ndim(),
dY.strides().data(),
X_index.dims().data(),
X_index.template data<int64_t, Context>(),
dY.template data<T, Context>(),
dX_value->ReshapeLike(X_index)->template mutable_data<T, Context>(),
ctx());
}
if (dX->has_name()) {
dX->ReshapeLike(dY)->CopyFrom(dY, ctx());
}
}
DEPLOY_CPU_OPERATOR(ScatterAdd);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ScatterAdd);
#endif
DEPLOY_CPU_OPERATOR(ScatterAddGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ScatterAddGradient);
#endif
OPERATOR_SCHEMA(ScatterAdd)
/* X, X_index, X_value */
.NumInputs(3)
/* Y */
.NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ScatterAddGradient)
/* X_index, dY */
.NumInputs(2)
/* dX, dX_value */
.NumOutputs(2)
/* dY => dX */
.AllowInplace({{1, 0}});
namespace {
class GradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
void CreateGradientDefs() override {
AddGradientDef(
def().type() + "Gradient",
"",
vector<string>({I(1), GO(0)}),
vector<string>({GI(0), GI(2)}));
}
};
} // namespace
REGISTER_GRADIENT(ScatterAdd, GradientMaker);
} // namespace dragon
#include "dragon/core/workspace.h"
#include "dragon/operators/array/scatter_ops.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void ScatterElementsOp<Context>::DoRunWithType() {
SET_INPUT_SPEC(2);
auto &X = Input(0), *Y = Output(0);
auto &X_index = Input(1), &X_value = Input(2);
GET_OP_AXIS_ARG(axis, X.ndim(), 0);
CHECK_GT(X_index.count(), 0) << "\nLength of index must be > 0.";
CHECK_EQ(X.ndim(), X_index.ndim())
<< "\nMismatched number of dimensions between input and index.";
CHECK_EQ(X_index.ndim(), X_value.ndim())
<< "\nMismatched number of dimensions between index and value.";
for (int i = 0; i < X.ndim(); ++i) {
CHECK_LE(X_index.dim(i), X_value.dim(i));
if (i != axis) CHECK_LE(X_index.dim(i), X.dim(i));
}
// Copy the input data
Y->ReshapeLike(X)->CopyFrom(X, ctx());
// Update with the new data
kernels::ScatterElements(
axis,
X.ndim(),
X_index.dims().data(),
X_value.strides().data(),
X.strides().data(),
X_index.template data<int64_t, Context>(),
X_value.template data<T, Context>(),
Y->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
template <typename T>
void ScatterElementsGradientOp<Context>::DoRunWithType() {
auto &X_index = Input(0), &dY = Input(1);
auto *dX = Output(0), *dX_value = Output(1);
GET_OP_AXIS_ARG(axis, dY.ndim(), 0);
if (dX_value->has_name()) {
auto& X_value_ref = INPUT_SPEC(2);
for (int i = 0; i < X_index.ndim(); ++i) {
CHECK_EQ(X_index.dim(i), X_value_ref.dim(i));
if (i != axis) CHECK_EQ(X_index.dim(i), dY.dim(i));
}
kernels::GatherElements(
axis,
dY.ndim(),
dY.strides().data(),
X_index.dims().data(),
X_index.template data<int64_t, Context>(),
dY.template data<T, Context>(),
dX_value->ReshapeLike(X_index)->template mutable_data<T, Context>(),
ctx());
}
if (dX->has_name()) {
dX->ReshapeLike(dY)->CopyFrom(dY, ctx());
kernels::ScatterElements(
axis,
dY.ndim(),
convert::To<T>(0.f),
X_index.dims().data(),
dY.strides().data(),
X_index.template data<int64_t, Context>(),
dX->template mutable_data<T, Context>(),
ctx());
}
}
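Conversely, the scatter gradient splits in two: the updates receive the gradient at the positions they wrote (a `GatherElements` on `dY`), while the input receives `dY` with those positions zeroed, since the original values there were overwritten. A NumPy sketch (illustrative helper only):

```python
import numpy as np

def scatter_elements_grad_ref(dy, index, axis):
    # dX_value reads dY at the written positions; dX is dY with
    # those positions zeroed (the inputs there did not reach Y).
    dvalue = np.empty(index.shape, dtype=dy.dtype)
    dx = dy.copy()
    for c in np.ndindex(*index.shape):
        dst = list(c)
        dst[axis] = index[c]
        dvalue[c] = dy[tuple(dst)]
        dx[tuple(dst)] = 0
    return dx, dvalue
```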
DEPLOY_CPU_OPERATOR(ScatterElements);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ScatterElements);
#endif
DEPLOY_CPU_OPERATOR(ScatterElementsGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(ScatterElementsGradient);
#endif
OPERATOR_SCHEMA(ScatterElements)
/* X, X_index, X_value */
.NumInputs(3)
/* Y */
.NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ScatterElementsGradient)
/* X_index, dY */
.NumInputs(2)
/* dX, dX_value */
.NumOutputs(2)
/* dY => dX */
.AllowInplace({{1, 0}});
namespace {
class GradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
void CreateGradientDefs() override {
AddGradientDef(
def().type() + "Gradient",
"",
vector<string>({I(1), GO(0)}),
vector<string>({GI(0), GI(2)}));
}
};
} // namespace
REGISTER_GRADIENT(ScatterElements, GradientMaker);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_SCATTER_OPS_H_
#define DRAGON_OPERATORS_ARRAY_SCATTER_OPS_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class ScatterElementsOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(ScatterElementsOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
};
template <class Context>
class ScatterElementsGradientOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(ScatterElementsGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(1));
}
template <typename T>
void DoRunWithType();
};
template <class Context>
class ScatterAddOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(ScatterAddOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
template <typename T>
void DoRunWithTypeAndCast();
};
template <class Context>
class ScatterAddGradientOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(ScatterAddGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(1));
}
template <typename T>
void DoRunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_SCATTER_OPS_H_
......@@ -67,6 +67,7 @@ from dragon.core.ops.array_ops import concat
from dragon.core.ops.array_ops import expand_dims
from dragon.core.ops.array_ops import flatten
from dragon.core.ops.array_ops import gather
from dragon.core.ops.array_ops import gather_elements
from dragon.core.ops.array_ops import identity
from dragon.core.ops.array_ops import linspace
from dragon.core.ops.array_ops import nonzero
......@@ -75,6 +76,8 @@ from dragon.core.ops.array_ops import pad
from dragon.core.ops.array_ops import range
from dragon.core.ops.array_ops import repeat
from dragon.core.ops.array_ops import reshape
from dragon.core.ops.array_ops import scatter_add
from dragon.core.ops.array_ops import scatter_elements
from dragon.core.ops.array_ops import shape
from dragon.core.ops.array_ops import slice
from dragon.core.ops.array_ops import sort
......
......@@ -498,6 +498,11 @@ def roi_pool_args(**kwargs):
}
@register(['ScatterElements', 'ScatterAdd', 'GatherElements'])
def scatter_gather_elements_args(**kwargs):
return {'axis': kwargs.get('axis', 0)}
@register('Selu')
def selu_args(**kwargs):
return {
......@@ -556,6 +561,16 @@ def stack_args(**kwargs):
return {'axis': kwargs.get('axis', 0)}
@register('SyncBatchNorm')
def sync_batch_norm_args(**kwargs):
return {**batch_norm_args(**kwargs), **{
'comm': kwargs.get('comm', 0),
'group': kwargs.get('group', 0),
'backend': kwargs.get('backend', 'MPI'),
'ranks': kwargs.get('ranks', None),
}}
@register('Tile')
def tile_args(**kwargs):
return {'repeats_desc': 'int64' if kwargs.get('ndim', 0) > 0 else None}
......
......@@ -97,7 +97,7 @@ class OpLib(object):
grad_tape = tape.get_tape()
# Add inputs.
enable_grad = False #
enable_grad = False
inputs = nest.flatten(inputs)
for input in inputs:
op_tape.add_source(input)
......
......@@ -414,6 +414,16 @@ def gather_spec(args, inputs, outputs):
return outputs
@register('GatherElements')
def gather_elements_spec(args, inputs, outputs):
outputs[0]._dtype = inputs[0].dtype
try:
outputs[0]._shape = inputs[1]._shape[:]
except TypeError:
pass
return outputs
@register(['IsInf', 'IsNaN', 'Not'])
def is_spec(args, inputs, outputs):
outputs[0]._dtype = 'bool'
......
......@@ -14,8 +14,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy
from dragon.core.framework import context
from dragon.core.framework import device_spec
from dragon.core.framework import types
......@@ -179,7 +177,7 @@ class Tensor(types.TensorBase):
if self._shape is None:
return 0
if None in self._shape:
return numpy.inf
return float('inf')
return math_util.prod(self._shape)
def astype(self, dtype, copy=True):
......@@ -434,6 +432,25 @@ class Tensor(types.TensorBase):
if self._is_variable and self._deleter:
self._deleter.release(self._impl.name)
def __eq__(self, other):
"""Compute element-wise equal comparison.
Parameters
----------
other : Union[dragon.Tensor, number]
The value to compare.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`dragon.math.equal(...)`_
"""
def __float__(self):
"""Return a float python scalar.
......@@ -494,6 +511,7 @@ class Tensor(types.TensorBase):
"""
def __hash__(self):
"""Return the hashable identity."""
return id(self)
def __iadd__(self, other):
......@@ -687,6 +705,25 @@ class Tensor(types.TensorBase):
"""
def __ne__(self, other):
"""Compute element-wise not-equal comparison.
Parameters
----------
other : Union[dragon.Tensor, number]
The value to compare.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`dragon.math.not_equal(...)`_
"""
def __or__(self, other):
"""Compute the element-wise OR bitwise operation.
......
......@@ -558,7 +558,7 @@ def flatten(inputs, axis=0, end_axis=-1, copy=True, **kwargs):
@OpSchema.num_inputs(2)
def gather(inputs, axis=0, end_axis=None, **kwargs):
"""Gather the elements along the given axis using index.
"""Gather elements along the given axis using index.
Index should be an ``int64`` tensor:
......@@ -595,6 +595,46 @@ def gather(inputs, axis=0, end_axis=None, **kwargs):
return OpLib.add('Gather', inputs, axis=axis, end_axis=end_axis, **kwargs)
@OpSchema.num_inputs(2)
def gather_elements(inputs, axis=0, **kwargs):
"""Gather elements along the given axis of index.
Number of dimensions of input and index should be the same.
For a 3-d input, the output is gathered as:
```python
out[i, j, k] = input[index[i, j, k], j, k]  # ``axis`` is 0
out[i, j, k] = input[i, index[i, j, k], k]  # ``axis`` is 1
out[i, j, k] = input[i, j, index[i, j, k]]  # ``axis`` is 2
```
Examples:
```python
x = dragon.constant([[1, 2], [3, 4]])
index = dragon.constant([[0, 0], [0, 1]])
print(dragon.gather_elements([x, index], axis=0)) # [[1, 2], [1, 4]]
print(dragon.gather_elements([x, index], axis=1)) # [[1, 1], [3, 4]]
```
Parameters
----------
inputs : Sequence[dragon.Tensor]
The input and index tensor.
axis : int, optional, default=0
The axis of index values.
Returns
-------
dragon.Tensor
The output tensor.
"""
if context.executing_eagerly():
return OpLib.execute('GatherElements', inputs, axis=axis)
return OpLib.add('GatherElements', inputs, axis=axis, **kwargs)
@OpSchema.num_inputs(1)
def identity(inputs, **kwargs):
"""Return a tensor copied from the input.
......@@ -1187,6 +1227,96 @@ def reshape(inputs, shape, copy=True, **kwargs):
return OpLib.add('Reshape', **args)
@OpSchema.num_inputs(3)
def scatter_add(inputs, axis=0, copy=True, **kwargs):
"""Add elements along the given axis of index.
Number of dimensions of input and index should be the same.
For a 3-d input, the output is updated as:
```python
out[index[i, j, k], j, k] += updates[i, j, k] # ``axis`` is 0
out[i, index[i, j, k], k] += updates[i, j, k] # ``axis`` is 1
out[i, j, index[i, j, k]] += updates[i, j, k] # ``axis`` is 2
```
Examples:
```python
y = dragon.constant([[1, 2], [3, 4]])
x = dragon.constant([[5, 6], [7, 8]])
index = dragon.constant([[0, 0], [0, 0]])
print(dragon.scatter_add([y, index, x], axis=0)) # [[13, 16], [3, 4]]
print(dragon.scatter_add([y, index, x], axis=1)) # [[12, 2], [18, 4]]
```
Parameters
----------
inputs : Sequence[dragon.Tensor]
The input, index and updates tensor.
axis : int, optional, default=0
The axis of index values.
copy : bool, optional, default=True
Return a new tensor or call in-place.
Returns
-------
dragon.Tensor
The output tensor.
"""
if context.executing_eagerly():
return OpLib.execute(
'ScatterAdd', inputs,
outputs=[None] if copy else [inputs[0]], axis=axis)
return OpLib.add('ScatterAdd', inputs, axis=axis, **kwargs)
@OpSchema.num_inputs(3)
def scatter_elements(inputs, axis=0, copy=True, **kwargs):
"""Update elements along the given axis of index.
Number of dimensions of input and index should be the same.
For a 3-d input, the output is updated as:
```python
out[index[i, j, k], j, k] = updates[i, j, k] # ``axis`` is 0
out[i, index[i, j, k], k] = updates[i, j, k] # ``axis`` is 1
out[i, j, index[i, j, k]] = updates[i, j, k] # ``axis`` is 2
```
Examples:
```python
y = dragon.constant([[1, 2], [3, 4]])
x = dragon.constant([[5, 6], [7, 8]])
index = dragon.constant([[0, 0], [0, 1]])
print(dragon.scatter_elements([y, index, x], axis=0)) # [[7, 6], [3, 8]]
print(dragon.scatter_elements([y, index, x], axis=1)) # [[6, 2], [7, 8]]
```
Parameters
----------
inputs : Sequence[dragon.Tensor]
The input, index and updates tensor.
axis : int, optional, default=0
The axis of index values.
copy : bool, optional, default=True
Return a new tensor or call in-place.
Returns
-------
dragon.Tensor
The output tensor.
"""
if context.executing_eagerly():
return OpLib.execute(
'ScatterElements', inputs,
outputs=[None] if copy else [inputs[0]], axis=axis)
return OpLib.add('ScatterElements', inputs, axis=axis, **kwargs)
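``scatter_elements`` likewise has a direct NumPy counterpart in ``put_along_axis`` (a reference sketch; note that with duplicate indices, which write wins is unspecified on parallel devices):
```python
import numpy as np

y = np.array([[1, 2], [3, 4]])
index = np.array([[0, 0], [0, 1]])
updates = np.array([[5, 6], [7, 8]])
out = y.copy()
np.put_along_axis(out, index, updates, axis=0)
print(out)  # [[7, 6], [3, 8]]
```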
@OpSchema.num_inputs(1)
def shape(inputs, **kwargs):
r"""Return the shape of input.
......
......@@ -283,21 +283,18 @@ def local_response_norm(
The output tensor.
"""
args = OpSchema.parse_args(locals())
if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: %s' % data_format)
args['alpha'] = float(alpha)
args['beta'] = float(beta)
args['bias'] = float(bias)
alpha, beta, bias = float(alpha), float(beta), float(bias)
if context.executing_eagerly():
op = op_lib.create(
'LRN', size=size, alpha=args['alpha'], beta=args['beta'],
bias=args['bias'], data_format=data_format)
return op.execute([inputs])
return op_lib.symbolize('LRN', **args)
return OpLib.execute(
'LRN', inputs, size=size, alpha=alpha, beta=beta,
bias=bias, data_format=data_format)
return OpLib.add('LRN', inputs, size=size, alpha=alpha, beta=beta,
bias=bias, data_format=data_format, **kwargs)
@OpSchema.num_inputs(min_num=5, max_num=5)
@OpSchema.num_inputs(5)
@OpSchema.convert_arg('momentum', as_target=False)
def sync_batch_norm(
inputs,
......@@ -350,9 +347,10 @@ def sync_batch_norm(
if process_group is None:
raise ValueError('<process_group> is required.')
if context.executing_eagerly():
op = op_lib.create(
'SyncBatchNorm', axis=axis, epsilon=args['epsilon'],
use_stats=use_stats, process_group=process_group)
return op.execute(inputs, momentum=args['momentum'])
return OpLib.execute(
'SyncBatchNorm', inputs, axis=axis, epsilon=args['epsilon'],
use_stats=use_stats, momentum=args['momentum'],
**process_group.arguments)
args.pop('process_group')
args.update(process_group.arguments)
return op_lib.symbolize('SyncBatchNorm', **args)
return OpLib.add('SyncBatchNorm', **args)
......@@ -126,6 +126,27 @@ def div(self, other):
return _apply_binary_op([self, other], 'Div')
def eq(self, other):
"""Compute element-wise equal comparison.
Parameters
----------
other : Union[dragon.Tensor, number]
The value to compare.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`dragon.math.equal(...)`_
"""
return _apply_binary_op([self, other], 'Equal')
def fill(self, value):
r"""Fill self with a scalar value.
......@@ -516,6 +537,27 @@ def mul(self, other):
return _apply_binary_op([self, other], 'Mul')
def ne(self, other):
"""Compute element-wise not-equal comparison.
Parameters
----------
other : Union[dragon.Tensor, number]
The value to compare.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`dragon.math.not_equal(...)`_
"""
return _apply_binary_op([self, other], 'NotEqual')
def neg(self):
"""Compute the element-wise negative.
......@@ -946,6 +988,7 @@ Tensor.truncated_normal = truncated_normal
Tensor.uniform = uniform
Tensor.__add__ = add
Tensor.__and__ = _and
Tensor.__eq__ = eq
Tensor.__ge__ = ge
Tensor.__getitem__ = getitem
Tensor.__gt__ = gt
......@@ -961,6 +1004,7 @@ Tensor.__ixor__ = ixor
Tensor.__le__ = le
Tensor.__lt__ = lt
Tensor.__mul__ = mul
Tensor.__ne__ = ne
Tensor.__neg__ = neg
Tensor.__or__ = _or
Tensor.__radd__ = radd
......
......@@ -142,6 +142,19 @@ def eye_exporter(op_def, context):
return node, const_tensors
@export_util.register('Fill')
def fill_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'Constant'
shape = list(context.blob_shapes[op_def.output[0]])
value = helper.from_array(
numpy.array(shape, 'int64'),
context.unique_name(op_def.output[0] + '/constant/value'))
helper.add_attribute(node, 'value', value)
node.ClearField('input')
return node, [value]
@export_util.register('Flatten')
def flatten_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
......@@ -174,6 +187,20 @@ def gather_exporter(op_def, context):
return node, const_tensors
@export_util.register('GatherElements')
def gather_elements_exporter(op_def, context):
raise RuntimeError('<GatherElements> is supported since opset 11.')
@export_util.register('GatherElements-11')
def gather_elements_exporter_v11(op_def, context):
node, const_tensors = export_util.translate(**locals())
for arg in op_def.arg:
if arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i)
return node, const_tensors
@export_util.register('Multinomial')
def multinomial_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
......@@ -324,6 +351,27 @@ def reshape_exporter(op_def, context):
return node, [shape]
@export_util.register('ScatterElements')
def scatter_elements_exporter_v8(op_def, context):
raise RuntimeError('<Scatter> is supported since opset 9.')
@export_util.register('ScatterElements-9')
def scatter_elements_exporter_v9(op_def, context):
node, const_tensors = scatter_elements_exporter_v11(**locals())
node.op_type = 'Scatter'
return node, const_tensors
@export_util.register('ScatterElements-11')
def scatter_elements_exporter_v11(op_def, context):
node, const_tensors = export_util.translate(**locals())
for arg in op_def.arg:
if arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i)
return node, const_tensors
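For reference, the two export paths above should emit nodes equivalent to the following hand-built ones (a sketch using the standard ``onnx.helper`` API; tensor names are illustrative):
```python
from onnx import helper

# Opset >= 11: the op is spelled ScatterElements and carries an ``axis`` attribute.
node_v11 = helper.make_node(
    'ScatterElements', inputs=['x', 'index', 'updates'], outputs=['y'], axis=0)

# Opset 9-10: identical semantics under the legacy name Scatter.
node_v9 = helper.make_node(
    'Scatter', inputs=['x', 'index', 'updates'], outputs=['y'], axis=0)
```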
def slice_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
in_shape = context.blob_shapes[op_def.input[0]]
......
......@@ -46,6 +46,24 @@ struct AtomicIntegerFunctor<T, MathFunctor, 1> {
};
template <typename T, class MathFunctor>
struct AtomicIntegerFunctor<T, MathFunctor, 8> {
#if defined(__CUDACC__)
inline __device__ void operator()(T* address, T val) {
unsigned long long* address_as_ui = (unsigned long long*)address;
unsigned long long old = *address_as_ui;
unsigned long long newval;
unsigned long long assumed;
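// Retry the compare-and-swap until no other thread has modified the word in between.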
do {
assumed = old;
newval = static_cast<unsigned long long>(math_functor_(val, old));
old = atomicCAS(address_as_ui, assumed, newval);
} while (assumed != old);
}
#endif
MathFunctor math_functor_;
};
template <typename T, class MathFunctor>
struct AtomicFloat16Functor {
#if defined(__CUDACC__)
inline __device__ void operator()(T* address, T val) {
......@@ -103,6 +121,21 @@ inline __device__ void AtomicAdd(T* address, T val) {
atomicAdd(address, val);
}
inline __device__ void AtomicAdd(uint8_t* address, uint8_t val) {
AtomicIntegerFunctor<uint8_t, PlusFunctor<uint8_t>, sizeof(uint8_t)>()(
address, val);
}
inline __device__ void AtomicAdd(int8_t* address, int8_t val) {
AtomicIntegerFunctor<int8_t, PlusFunctor<int8_t>, sizeof(int8_t)>()(
address, val);
}
inline __device__ void AtomicAdd(int64_t* address, int64_t val) {
AtomicIntegerFunctor<int64_t, PlusFunctor<int64_t>, sizeof(int64_t)>()(
address, val);
}
#if __CUDA_ARCH__ < 700
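// Devices before sm_70 lack a native half-precision atomicAdd; emulate it with a CAS loop.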
inline __device__ void AtomicAdd(half* address, half val) {
AtomicFloat16Functor<half, PlusFunctor<half>>()(address, val);
......
......@@ -132,7 +132,7 @@ DEFINE_UNARY_FUNC(Sign, double, [](double x) {
_SimpleUnaryFunc(N, Functor<InputT>(), x, y); \
}
DEFINE_UNARY_FUNC(BitwiseNot, bool, bool, std::bit_not);
DEFINE_UNARY_FUNC(BitwiseNot, bool, bool, std::logical_not);
DEFINE_UNARY_FUNC(BitwiseNot, uint8_t, uint8_t, std::bit_not);
DEFINE_UNARY_FUNC(BitwiseNot, int8_t, int8_t, std::bit_not);
DEFINE_UNARY_FUNC(BitwiseNot, int, int, std::bit_not);
......
......@@ -387,6 +387,17 @@ void GatherGrad(
Context* ctx);
template <typename T, class Context>
void GatherElements(
const int axis,
const int num_dims,
const int64_t* x_strides,
const int64_t* y_dims,
const int64_t* index,
const T* x,
T* y,
Context* ctx);
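The declaration above fixes the kernel contract: decompose each flat output position against ``y_dims``, swap in the index value along ``axis``, and recompose against ``x_strides``. A rough Python sketch of the presumed arithmetic (``gather_elements_flat`` is illustrative only, not the CUDA kernel itself):
```python
import numpy as np

def gather_elements_flat(axis, y_dims, x_strides, index, x_flat):
    """Sketch of the flat-index arithmetic: one output element per position."""
    y = np.empty(int(np.prod(y_dims)), dtype=x_flat.dtype)
    for yi in range(y.size):
        coords, rem = [], yi
        for d in reversed(y_dims):  # flat position -> coordinates
            coords.append(rem % d)
            rem //= d
        coords.reverse()
        coords[axis] = index[yi]  # remap along ``axis``
        xi = sum(c * s for c, s in zip(coords, x_strides))
        y[yi] = x_flat[xi]
    return y
```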
template <typename T, class Context>
void LinSpace(
const int N,
const int C,
......@@ -479,6 +490,41 @@ void RepeatGrad(
Context* ctx);
template <typename T, class Context>
void ScatterElements(
const int axis,
const int num_dims,
const T value,
const int64_t* dims,
const int64_t* y_strides,
const int64_t* index,
T* y,
Context* ctx);
template <typename T, class Context>
void ScatterElements(
const int axis,
const int num_dims,
const int64_t* dims,
const int64_t* x_strides,
const int64_t* y_strides,
const int64_t* index,
const T* x,
T* y,
Context* ctx);
template <typename T, typename AccT, class Context>
void ScatterAdd(
const int axis,
const int num_dims,
const int64_t* dims,
const int64_t* x_strides,
const int64_t* y_strides,
const int64_t* index,
const T* x,
AccT* y,
Context* ctx);
template <typename T, class Context>
void Slice(
const int num_dims,
const int64_t* x_strides,
......@@ -875,9 +921,9 @@ void GroupNormGrad(
template <typename T, class Context>
void L1Normalize(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const float normalizer,
const float epsilon,
const T* x,
......@@ -886,9 +932,9 @@ void L1Normalize(
template <typename T, class Context>
void L1NormalizeGrad(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const float normalizer,
const float epsilon,
const T* dy,
......@@ -898,9 +944,9 @@ void L1NormalizeGrad(
template <typename T, class Context>
void L2Normalize(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const float normalizer,
const float epsilon,
const T* x,
......@@ -909,9 +955,9 @@ void L2Normalize(
template <typename T, class Context>
void L2NormalizeGrad(
const int outer_dim,
const int inner_dim,
const int reduce_dim,
const int N,
const int S,
const int C,
const float normalizer,
const float epsilon,
const T* dy,
......
......@@ -249,6 +249,15 @@ class TestOpSpec(unittest.TestCase):
self.assertEqual(dragon.gather(
[self.sym3, self.sym2], axis=1).shape, (1, 1))
def test_gather_elements(self):
with dragon.graph_mode():
self.assertEqual(dragon.gather_elements(
[self.sym1, self.sym1]).shape, None)
self.assertEqual(dragon.gather_elements(
[self.sym1, self.sym2], axis=0).shape, self.sym2.shape)
self.assertEqual(dragon.gather_elements(
[self.sym1, self.sym3], axis=1).shape, self.sym3.shape)
def test_gemm(self):
w = dragon.Tensor((3, 2), symbolic=True)
with dragon.graph_mode():
......
......@@ -136,7 +136,7 @@ class TestTensor(unittest.TestCase):
def test_dlpack_converter_cuda(self):
data = np.array([0., 1., 2.], 'float32')
with dragon.device('cuda', 0):
x = dragon.constant(data, copy=True)
x = dragon.constant(data, copy=True) + 0
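# The "+ 0" executes an op under the device scope, so the result is CUDA-allocated.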
x_to_dlpack = dragon.dlpack.to_dlpack(x)
x_from_dlpack = dragon.dlpack.from_dlpack(x_to_dlpack)
self.assertEqual(x_from_dlpack.device.type, 'cuda')
......
......@@ -14,6 +14,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import math
import os
import unittest
......@@ -725,6 +726,34 @@ class TestArrayOps(OpTestCase):
with dragon.device('cuda'):
self.test_gather()
def test_gather_elements(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for axis in range(0, 2):
data1 = arange((2, 4))
data2 = np.array([[0, 1, 1, 0], [1, 1, 0, 0]])
x, index = new_tensor(data1), new_tensor(data2)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.gather_elements([x, index], axis=axis)
data3 = arange(y.shape, 100)
dy = new_tensor(data3)
dx, = tape.gradient(y, [x], output_gradients=[dy])
result, grad = np.zeros_like(data2), np.zeros_like(data1)
for i, j in itertools.product(*[range(d) for d in data2.shape]):
if axis == 0:
result[i, j] = data1[data2[i, j], j]
grad[data2[i, j], j] += data3[i, j]
else:
result[i, j] = data1[i, data2[i, j]]
grad[i, data2[i, j]] += data3[i, j]
self.assertEqual([y, dx], [result, grad])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_gather_elements_cuda(self):
with dragon.device('cuda'):
self.test_gather_elements()
def test_identity(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
......@@ -2261,6 +2290,68 @@ class TestMathOps(OpTestCase):
with dragon.device('cuda'):
self.test_rsqrt()
def test_scatter_add(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for axis in range(0, 2):
data1 = arange((4, 4))
data2 = np.array([[0, 0, 2, 3], [0, 0, 3, 0],
[2, 3, 0, 1], [3, 0, 1, 2]])
data3 = arange((4, 4), 100)
x, index = new_tensor(data1), new_tensor(data2)
v = new_tensor(data3)
with dragon.GradientTape() as tape:
tape.watch([x, v])
y = dragon.scatter_add([x, index, v], axis=axis)
dx, dv = tape.gradient(y, [x, v], output_gradients=[x])
result, grad1 = data1.copy(), data1.copy()
grad2 = np.zeros_like(data3)
for i, j in itertools.product(*[range(d) for d in data2.shape]):
if axis == 0:
result[data2[i, j], j] += data3[i, j]
grad2[i, j] = data1[data2[i, j], j]
else:
result[i, data2[i, j]] += data3[i, j]
grad2[i, j] = data1[i, data2[i, j]]
self.assertEqual([y, dx, dv], [result, grad1, grad2])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_scatter_add_cuda(self):
with dragon.device('cuda'):
self.test_scatter_add()
def test_scatter_elements(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for axis in range(0, 2):
data1 = arange((4, 4))
data2 = np.array([[0, 1, 2, 3], [1, 2, 3, 0],
[2, 3, 0, 1], [3, 0, 1, 2]])
data3 = arange((4, 4), 100)
x, index = new_tensor(data1), new_tensor(data2)
v = new_tensor(data3)
with dragon.GradientTape() as tape:
tape.watch([x, v])
y = dragon.scatter_elements([x, index, v], axis=axis)
dx, dv = tape.gradient(y, [x, v], output_gradients=[x])
result, grad1 = data1.copy(), data1.copy()
grad2 = np.zeros_like(data3)
for i, j in itertools.product(*[range(d) for d in data2.shape]):
if axis == 0:
result[data2[i, j], j] = data3[i, j]
grad1[data2[i, j], j] = 0
grad2[i, j] = data1[data2[i, j], j]
else:
result[i, data2[i, j]] = data3[i, j]
grad1[i, data2[i, j]] = 0
grad2[i, j] = data1[i, data2[i, j]]
self.assertEqual([y, dx, dv], [result, grad1, grad2])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_scatter_elements_cuda(self):
with dragon.device('cuda'):
self.test_scatter_elements()
def test_sign(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
......@@ -2832,6 +2923,14 @@ class TestTensorOps(OpTestCase):
a /= b
self.assertEqual(a, data1 / data2)
def test_eq(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for a_shape, b_shape in self.binary_test_shapes:
data1, data2 = uniform(a_shape), uniform(b_shape)
a, b = new_tensor(data1), new_tensor(data2)
self.assertEqual(a == b, data1 == data2)
def test_ge(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
......@@ -2922,6 +3021,14 @@ class TestTensorOps(OpTestCase):
a *= b
self.assertEqual(a, data1 * data2)
def test_ne(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for a_shape, b_shape in self.binary_test_shapes:
data1, data2 = uniform(a_shape), uniform(b_shape)
a, b = new_tensor(data1), new_tensor(data2)
self.assertEqual(a != b, data1 != data2)
def test_neg(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
......
......@@ -14,6 +14,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import math
import os
import unittest
......@@ -236,7 +237,7 @@ class TestTensorOps(OpTestCase):
data1 = uniform(a_shape)
data2 = dropout(data1, drop_ratio=0.5)
a, b = new_tensor(data1, False), new_tensor(data2, False)
self.assertEqual(a.eq(b), np.equal(data1, data2))
self.assertEqual(a == b, np.equal(data1, data2))
def test_exp(self):
data = np.array([0., 1., 2.], 'float32')
......@@ -294,6 +295,20 @@ class TestTensorOps(OpTestCase):
x.floor_()
self.assertEqual(x, np.floor(data))
def test_gather(self):
for axis in range(0, 2):
data1 = arange((2, 4))
data2 = np.array([[0, 1, 1, 0], [1, 1, 0, 0]])
x, index = new_tensor(data1), new_tensor(data2)
y = x.gather(axis, index)
result = np.zeros_like(data2)
for i, j in itertools.product(*[range(d) for d in data2.shape]):
if axis == 0:
result[i, j] = data1[data2[i, j], j]
else:
result[i, j] = data1[i, data2[i, j]]
self.assertEqual([y], [result])
def test_getitem(self):
data1, data2 = arange((2, 3)), arange((2,), dtype='int64')
x, index = new_tensor(data1), new_tensor(data2)
......@@ -518,7 +533,7 @@ class TestTensorOps(OpTestCase):
data1 = uniform(a_shape)
data2 = dropout(data1, drop_ratio=0.5)
a, b = new_tensor(data1, False), new_tensor(data2, False)
self.assertEqual(a.ne(b), np.not_equal(data1, data2))
self.assertEqual(a != b, np.not_equal(data1, data2))
def test_neg(self):
data = np.array([-1., 0., 1.], 'float32')
......@@ -605,6 +620,58 @@ class TestTensorOps(OpTestCase):
x.rsqrt_()
self.assertEqual(x, result)
def test_scatter(self):
for axis in range(0, 2):
data1 = arange((4, 4))
data2 = np.array([[0, 1, 2, 3], [1, 2, 3, 0],
[2, 3, 0, 1], [3, 0, 1, 2]])
data3 = arange((4, 4), 100)
x, index = new_tensor(data1), new_tensor(data2)
v = new_tensor(data3)
y = x.scatter(axis, index, v)
result = data1.copy()
for i, j in itertools.product(*[range(d) for d in data2.shape]):
if axis == 0:
result[data2[i, j], j] = data3[i, j]
else:
result[i, data2[i, j]] = data3[i, j]
self.assertEqual(y, result)
x.scatter_(axis, index, v)
self.assertEqual(x, result)
def test_scatter_add(self):
for axis in range(0, 2):
data1 = arange((4, 4))
data2 = np.array([[0, 0], [0, 0]])
data3 = arange((4, 4), 100)
x, index = new_tensor(data1), new_tensor(data2)
v = new_tensor(data3)
y = x.scatter_add(axis, index, v)
result = data1.copy()
for i, j in itertools.product(*[range(d) for d in data2.shape]):
if axis == 0:
result[data2[i, j], j] += data3[i, j]
else:
result[i, data2[i, j]] += data3[i, j]
self.assertEqual(y, result)
x.scatter_(axis, index, v, reduce='add')
self.assertEqual(x, result)
def test_scatter_mul(self):
for axis in range(0, 2):
data1 = arange((4, 4))
data2 = np.array([[0, 1, 2, 3], [1, 2, 3, 0],
[2, 3, 0, 1], [3, 0, 1, 2]])
x, index = new_tensor(data1), new_tensor(data2)
result = data1.copy()
for i, j in itertools.product(*[range(d) for d in data2.shape]):
if axis == 0:
result[data2[i, j], j] *= 2.33
else:
result[i, data2[i, j]] *= 2.33
x.scatter_(axis, index, 2.33, reduce='multiply')
self.assertEqual(x, result)
def test_setitem(self):
data = arange((2, 3))
x = new_tensor(data)
......
......@@ -55,6 +55,7 @@ class TestTensor(unittest.TestCase):
self.assertEqual(torch.empty(2, 3).ndimension(), 2)
self.assertEqual(torch.empty(3).new_empty(2, 3).ndimension(), 2)
self.assertEqual(repr(torch.tensor(1)), '1')
self.assertEqual(repr(torch.tensor(1).new_tensor(1)), '1')
self.assertNotEqual(a.__hash__(), b.__hash__())
self.assertNotEqual(a.__repr__(), b.__repr__())
self.assertEqual(torch.BoolTensor(1).dtype, 'bool')
......
......@@ -44,7 +44,6 @@ from dragon.vm.torch.core.tensor import LongTensor
from dragon.vm.torch.core.tensor import Tensor
# Functions
from dragon.vm.torch.core.cpp import from_numpy
from dragon.vm.torch.core.ops import tensor_ops as _
from dragon.vm.torch.core.ops.array_ops import argmax
from dragon.vm.torch.core.ops.array_ops import argmin
......@@ -57,6 +56,7 @@ from dragon.vm.torch.core.ops.array_ops import channel_shuffle
from dragon.vm.torch.core.ops.array_ops import chunk
from dragon.vm.torch.core.ops.array_ops import cumsum
from dragon.vm.torch.core.ops.array_ops import flatten
from dragon.vm.torch.core.ops.array_ops import gather
from dragon.vm.torch.core.ops.array_ops import index_select
from dragon.vm.torch.core.ops.array_ops import masked_select
from dragon.vm.torch.core.ops.array_ops import masked_fill
......@@ -69,6 +69,8 @@ from dragon.vm.torch.core.ops.array_ops import nonzero
from dragon.vm.torch.core.ops.array_ops import one_hot
from dragon.vm.torch.core.ops.array_ops import permute
from dragon.vm.torch.core.ops.array_ops import reshape
from dragon.vm.torch.core.ops.array_ops import scatter
from dragon.vm.torch.core.ops.array_ops import scatter_add
from dragon.vm.torch.core.ops.array_ops import sort
from dragon.vm.torch.core.ops.array_ops import split
from dragon.vm.torch.core.ops.array_ops import squeeze
......@@ -82,7 +84,10 @@ from dragon.vm.torch.core.ops.array_ops import triu
from dragon.vm.torch.core.ops.array_ops import unique
from dragon.vm.torch.core.ops.array_ops import unsqueeze
from dragon.vm.torch.core.ops.array_ops import where
from dragon.vm.torch.core.ops.constant_ops import from_numpy
from dragon.vm.torch.core.ops.constant_ops import tensor
from dragon.vm.torch.core.ops.init_ops import arange
from dragon.vm.torch.core.ops.init_ops import empty
from dragon.vm.torch.core.ops.init_ops import eye
from dragon.vm.torch.core.ops.init_ops import full
from dragon.vm.torch.core.ops.init_ops import full_like
......@@ -140,8 +145,6 @@ from dragon.vm.torch.core.ops.math_ops import sqrt
from dragon.vm.torch.core.ops.math_ops import sub
from dragon.vm.torch.core.serialization import load
from dragon.vm.torch.core.serialization import save
from dragon.vm.torch.core.tensor import empty
from dragon.vm.torch.core.tensor import tensor
# Aliases
bool = dtype('bool')
......
......@@ -13,11 +13,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy
from dragon.core.framework import proto_util
from dragon.core.util import math_util
from dragon.vm.torch.core import tensor as tensor_module
class Size(tuple):
......@@ -177,22 +174,3 @@ class dtype(str):
"""
super(dtype, self).__init__()
def from_numpy(array):
"""Create a tensor from the given numpy array.
Parameters
----------
array : numpy.ndarray
The numpy array data.
Return
------
dragon.vm.torch.Tensor
The torch tensor.
"""
if not isinstance(array, numpy.ndarray):
raise TypeError('The <array> should be a numpy ndarray.')
return tensor_module.Tensor(array, copy=False)
......@@ -17,6 +17,7 @@ from __future__ import print_function
from dragon.core.util import nest
from dragon.vm.torch.core.autograd.function_impl import FunctionLib
from dragon.vm.torch.core.ops import constant_ops
from dragon.vm.torch.core.ops import init_ops
from dragon.vm.torch.core.tensor import Tensor
......@@ -424,6 +425,44 @@ def flatten(input, start_dim=0, end_dim=-1, out=None):
axis=start_dim, end_axis=end_dim)
def gather(input, dim, index, out=None):
"""Gather elements along the given dimension of index.
The number of dimensions of :attr:`input` and :attr:`index` should be the same.
For a 3-d input, the output is gathered as:
```python
out[i, j, k] = input[index[i, j, k], j, k]  # ``dim`` is 0
out[i, j, k] = input[i, index[i, j, k], k]  # ``dim`` is 1
out[i, j, k] = input[i, j, index[i, j, k]]  # ``dim`` is 2
```
Examples:
```python
x = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [0, 1]])
print(torch.gather(x, 0, index)) # [[1, 2], [1, 4]]
print(torch.gather(x, 1, index)) # [[1, 1], [3, 4]]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return FunctionLib.apply(
'GatherElements', input.device, [input, index],
outputs=[out], axis=dim)
def index_select(input, dim, index, out=None):
"""Select the elements along the given dimension using index.
......@@ -859,6 +898,96 @@ def reshape(input, shape, out=None):
ndim=len(shape), dims=shape)
def scatter(input, dim, index, src, out=None):
"""Update elements along the given dimension of index.
The number of dimensions of :attr:`input`, :attr:`index`, and :attr:`src`
should be the same. For a 3-d input, the output is updated as:
```python
out[index[i, j, k], j, k] = src[i, j, k] # ``dim`` is 0
out[i, index[i, j, k], k] = src[i, j, k] # ``dim`` is 1
out[i, j, index[i, j, k]] = src[i, j, k] # ``dim`` is 2
```
Examples:
```python
y = torch.tensor([[1, 2], [3, 4]])
x = torch.tensor([[5, 6], [7, 8]])
index = torch.tensor([[0, 1], [1, 0]])
print(torch.scatter(y, 0, index, x)) # [[5, 8], [7, 6]]
print(torch.scatter(y, 1, index, x)) # [[5, 6], [8, 7]]
print(torch.scatter(y, 0, index, 8)) # [[8, 8], [8, 8]]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
src : Union[dragon.vm.torch.Tensor, number]
The tensor to update from.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
if not isinstance(src, Tensor):
src = init_ops.full_like(
index, src, dtype=input.dtype, device=input.device)
return FunctionLib.apply(
'ScatterElements', input.device, [input, index, src],
outputs=[out], axis=dim)
def scatter_add(input, dim, index, src, out=None):
"""Add elements along the given dimension of index.
The number of dimensions of :attr:`input`, :attr:`index`, and :attr:`src`
should be the same. For a 3-d input, the output is updated as:
```python
out[index[i, j, k], j, k] += src[i, j, k] # ``dim`` is 0
out[i, index[i, j, k], k] += src[i, j, k] # ``dim`` is 1
out[i, j, index[i, j, k]] += src[i, j, k] # ``dim`` is 2
```
Examples:
```python
y = torch.tensor([[1, 2], [3, 4]])
x = torch.tensor([[5, 6], [7, 8]])
index = torch.tensor([[0, 0], [0, 0]])
print(torch.scatter_add(y, 0, index, x)) # [[13, 16], [3, 4]]
print(torch.scatter_add(y, 1, index, x)) # [[12, 2], [18, 4]]
print(torch.scatter_add(y, 0, index, 8)) # [[17, 18], [3, 4]]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
src : Union[dragon.vm.torch.Tensor, number]
The tensor to add from.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
if not isinstance(src, Tensor):
src = init_ops.full_like(
index, src, dtype=input.dtype, device=input.device)
return FunctionLib.apply(
'ScatterAdd', input.device, [input, index, src],
outputs=[out], axis=dim)
def sort(input, dim=-1, descending=False, out=None):
"""Return the sorted elements along the given dimension.
......
......@@ -17,9 +17,62 @@ from __future__ import print_function
import numpy
from dragon.core.framework import workspace
from dragon.vm.torch.core import cpp
from dragon.vm.torch.core.tensor import Tensor
def from_numpy(ndarray):
"""Create a tensor converting from the given numpy array.
Parameters
----------
ndarray : numpy.ndarray
The numpy array data.
Returns
-------
dragon.vm.torch.Tensor
The torch tensor.
"""
if not isinstance(ndarray, numpy.ndarray):
raise TypeError('<ndarray> should be a numpy array.')
return Tensor(ndarray, copy=False)
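A usage note on the ``copy=False`` path above: the returned tensor aliases the array's memory, so in-place NumPy writes remain visible through it. A minimal sketch:
```python
import numpy as np
from dragon.vm import torch

arr = np.zeros(3, dtype='float32')
t = torch.from_numpy(arr)  # shares memory; no copy is made
arr[0] = 1.0               # the write is visible through ``t``
```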
def tensor(data, dtype=None, device=None, requires_grad=False):
"""Create a tensor initializing from the given data.
Parameters
----------
data : array_like
The data to initialize from.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
``True`` to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
array_data = numpy.array(data, copy=True)
if dtype is None:
dtype = str(array_data.dtype)
else:
array_data = array_data.astype(dtype)
return Tensor(
array_data,
dtype=dtype,
device=cpp.device() if device is None else device,
requires_grad=requires_grad,
)
def remove_scalars(input1, input2):
"""Remove the input scalars."""
if isinstance(input1, Tensor):
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
from dragon.core.util import nest
from dragon.vm.torch.core import cpp
from dragon.vm.torch.core.tensor import Tensor
from dragon.vm.torch.core.autograd.function_impl import FunctionLib
......@@ -84,6 +85,34 @@ def arange(
return out
def empty(*size, dtype=None, device=None, requires_grad=False):
"""Return a tensor filled with uninitialized data.
Parameters
----------
size : int...
The sizes of output tensor.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device option.
requires_grad : bool, optional, default=False
Whether to compute the gradient if necessary.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return Tensor(
*size,
dtype=dtype if dtype else 'float32',
device=cpp.device() if device is None else device,
requires_grad=requires_grad,
)
def eye(
n,
m=None,
......
......@@ -15,6 +15,7 @@ from __future__ import division
from __future__ import print_function
from dragon.core.util import nest
from dragon.vm.torch.core.autograd import grad_mode
from dragon.vm.torch.core.autograd.function_impl import FunctionLib
from dragon.vm.torch.core.ops import array_ops
from dragon.vm.torch.core.ops import constant_ops
......@@ -932,6 +933,29 @@ def floor_(self):
return math_ops.floor(self, self)
def gather(self, dim, index):
"""Gather elements along the given dimension of index.
Parameters
----------
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.gather(...)`_
"""
return array_ops.gather(self, dim, index)
def ge(self, other):
r"""Compute the element-wise greater-equal comparison.
......@@ -1685,6 +1709,40 @@ def neg_(self):
return math_ops.neg(self, self)
def new_empty(self, *size, dtype=None, device=None, requires_grad=False):
"""Return a tensor filled with uninitialized data.
The ``dtype`` and ``device`` are taken from this tensor if not provided.
Parameters
----------
size : int...
The size of output tensor.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
``True`` to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.empty(...)`_
"""
return init_ops.empty(
*nest.flatten(size),
dtype=self.dtype if dtype is None else dtype,
device=self.device if device is None else device,
requires_grad=requires_grad,
)
def new_full(
self,
size,
......@@ -1729,6 +1787,40 @@ def new_full(
)
def new_tensor(self, data, dtype=None, device=None, requires_grad=False):
"""Return a tensor initializing from the given data.
Refer this tensor if ``dtype`` and ``device`` not provided.
Parameters
----------
data : array_like
The data to initialize from.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
``True`` to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.tensor(...)`_
"""
return constant_ops.tensor(
data,
dtype=self.dtype if dtype is None else dtype,
device=self.device if device is None else device,
requires_grad=requires_grad,
)
def nonzero(self):
r"""Return the index of non-zero elements.
......@@ -1977,6 +2069,118 @@ def rsqrt_(self):
return math_ops.rsqrt(self, self)
def scatter(self, dim, index, src):
"""Return a tensor with elements updated from the source.
Parameters
----------
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
src : Union[dragon.vm.torch.Tensor, number]
The tensor to update from.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.scatter(...)`_
"""
return array_ops.scatter(self, dim, index, src)
def scatter_(self, dim, index, src, reduce=None):
"""Update elements from the source.
Parameters
----------
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
src : Union[dragon.vm.torch.Tensor, number]
The tensor to update from.
reduce : str, optional
``'add'`` or ``'multiply'``.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.scatter(...)`_
"""
if reduce:
if reduce == 'add':
return self.scatter_add_(dim, index, src)
elif reduce == 'multiply':
to_mul = init_ops.ones_like(self, self.dtype, device=self.device)
with grad_mode.no_grad():
to_mul.scatter_(dim, index, src)
return math_ops.mul(self, to_mul, self)
else:
raise ValueError('Unknown reduction: ' + reduce)
return array_ops.scatter(self, dim, index, src, self)
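The ``'multiply'`` branch avoids a dedicated scatter-mul kernel: it scatters the factors into a tensor of ones and multiplies elementwise. A NumPy sketch of that identity (using ``put_along_axis`` as the scatter):
```python
import numpy as np

x = np.arange(16, dtype='float32').reshape(4, 4)
index = np.array([[0, 1, 2, 3], [1, 2, 3, 0],
                  [2, 3, 0, 1], [3, 0, 1, 2]])
to_mul = np.ones_like(x)
np.put_along_axis(to_mul, index, 2.33, axis=0)  # scatter the factors
x *= to_mul  # same result as scatter_(0, index, 2.33, reduce='multiply')
```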
def scatter_add(self, dim, index, src):
"""Return a tensor with elements added from the source.
Parameters
----------
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
src : Union[dragon.vm.torch.Tensor, number]
The tensor to add from.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.scatter_add(...)`_
"""
return array_ops.scatter_add(self, dim, index, src)
def scatter_add_(self, dim, index, src):
"""Add elements from the source.
Parameters
----------
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
src : Union[dragon.vm.torch.Tensor, number]
The tensor to add from.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.scatter_add(...)`_
"""
return array_ops.scatter_add(self, dim, index, src, self)
def setitem(self, key, value):
"""Set elements at the specific index.
......@@ -2644,6 +2848,7 @@ Tensor.float = _float
Tensor.float_ = _float_
Tensor.floor = floor
Tensor.floor_ = floor_
Tensor.gather = gather
Tensor.ge = ge
Tensor.gt = gt
Tensor.half = half
......@@ -2680,7 +2885,9 @@ Tensor.narrow = narrow
Tensor.ne = ne
Tensor.neg = neg
Tensor.neg_ = neg_
Tensor.new_empty = new_empty
Tensor.new_full = new_full
Tensor.new_tensor = new_tensor
Tensor.nonzero = nonzero
Tensor.normal_ = normal_
Tensor.permute = permute
......@@ -2694,6 +2901,10 @@ Tensor.round = round
Tensor.round_ = round_
Tensor.rsqrt = rsqrt
Tensor.rsqrt_ = rsqrt_
Tensor.scatter = scatter
Tensor.scatter_ = scatter_
Tensor.scatter_add = scatter_add
Tensor.scatter_add_ = scatter_add_
Tensor.sign = sign
Tensor.sign_ = sign_
Tensor.sin = sin
......
......@@ -12,6 +12,7 @@
# <https://github.com/pytorch/pytorch/blob/master/torch/serialization.py>
#
# ------------------------------------------------------------
"""Serialization utilities."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Tensor class."""
from __future__ import absolute_import
from __future__ import division
......@@ -1113,6 +1114,27 @@ class Tensor(object):
"""
def gather(self, dim, index):
"""Gather elements along the given dimension of index.
Parameters
----------
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.gather(...)`_
"""
def ge(self, other):
r"""Compute the element-wise greater-equal comparison.
......@@ -1783,13 +1805,7 @@ class Tensor(object):
"""
def new_empty(
self,
*size,
dtype=None,
device=None,
requires_grad=False,
):
def new_empty(self, *size, dtype=None, device=None, requires_grad=False):
"""Return a tensor filled with uninitialized data.
The ``dtype`` and ``device`` are taken from this tensor if not provided.
......@@ -1815,12 +1831,6 @@ class Tensor(object):
`torch.empty(...)`_
"""
return empty(
*nest.flatten(size),
dtype=self.dtype if dtype is None else dtype,
device=self.device if device is None else device,
requires_grad=requires_grad,
)
def new_full(
self,
......@@ -1858,13 +1868,7 @@ class Tensor(object):
"""
def new_ones(
self,
*size,
dtype=None,
device=None,
requires_grad=False,
):
def new_ones(self, *size, dtype=None, device=None, requires_grad=False):
"""Return a tensor filled with with ones.
Refer this tensor if ``dtype`` and ``device`` not provided.
......@@ -1891,20 +1895,37 @@ class Tensor(object):
"""
return self.new_full(
nest.flatten(size),
fill_value=1,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
def new_zeros(
self,
*size,
dtype=None,
device=None,
requires_grad=False,
):
nest.flatten(size), fill_value=1, dtype=dtype, device=device,
requires_grad=requires_grad)
def new_tensor(self, data, dtype=None, device=None, requires_grad=False):
"""Return a tensor initializing from the given data.
Refer this tensor if ``dtype`` and ``device`` not provided.
Parameters
----------
data : array_like
The data to initialize from.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
``True`` to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.tensor(...)`_
"""
def new_zeros(self, *size, dtype=None, device=None, requires_grad=False):
"""Return a tensor filled with with zeros.
Refer this tensor if ``dtype`` and ``device`` not provided.
......@@ -1931,12 +1952,8 @@ class Tensor(object):
"""
return self.new_full(
nest.flatten(size),
fill_value=0,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
nest.flatten(size), fill_value=0, dtype=dtype, device=device,
requires_grad=requires_grad)
def nonzero(self):
r"""Return the index of non-zero elements.
......@@ -2197,6 +2214,100 @@ class Tensor(object):
"""
def scatter(self, dim, index, src):
"""Return a tensor with elements updated from the source.
Parameters
----------
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
src : Union[dragon.vm.torch.Tensor, number]
The tensor to update from.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.scatter(...)`_
"""
def scatter_(self, dim, index, src, reduce=None):
"""Update elements from the source.
Parameters
----------
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
src : Union[dragon.vm.torch.Tensor, number]
The tensor to update from.
reduce : str, optional
``'add'`` or ``'multiply'``.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.scatter(...)`_
"""
def scatter_add(self, dim, index, src):
"""Return a tensor with elements added from the source.
Parameters
----------
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
src : Union[dragon.vm.torch.Tensor, number]
The tensor to add from.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.scatter_add(...)`_
"""
def scatter_add_(self, dim, index, src):
"""Add elements from the source.
Parameters
----------
dim : int
The dimension of index values.
index : dragon.vm.torch.Tensor
The index tensor.
src : Union[dragon.vm.torch.Tensor, number]
The tensor to add from.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.scatter_add(...)`_
"""
def sign(self):
r"""Return a tensor taken the sign indication of elements.
......@@ -2894,6 +3005,22 @@ class Tensor(object):
if self._deleter:
self._deleter.release(self._impl.name)
def __eq__(self, other):
r"""Compute the element-wise equal comparison.
Parameters
----------
other : Union[dragon.vm.torch.Tensor, number]
The value to compare.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return self.eq(other)
def __float__(self):
"""Return a float python scalar.
......@@ -2948,6 +3075,7 @@ class Tensor(object):
return self.gt(other)
def __hash__(self):
"""Return the hashable identity."""
return id(self)
def __iadd__(self, other):
......@@ -3132,6 +3260,22 @@ class Tensor(object):
"""
return self.mul(other)
def __ne__(self, other):
"""Compute the element-wise not-equal comparison.
Parameters
----------
other : Union[dragon.vm.torch.Tensor, number]
The value to compare.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return self.ne(other)
def __neg__(self):
"""Compute the element-wise negative.
......@@ -3344,64 +3488,3 @@ class LongTensor(object):
def __new__(cls, *args, **kwargs):
kwargs['dtype'] = 'int64'
return Tensor(*args, **kwargs)
def empty(*size, dtype=None, device=None, requires_grad=False):
"""Return a tensor filled with uninitialized data.
Parameters
----------
size : int...
The sizes of output tensor.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device option.
requires_grad : bool, optional, default=False
Whether to compute the gradient if necessary.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return Tensor(
*size,
dtype=dtype if dtype else 'float32',
device=cpp.device() if device is None else device,
requires_grad=requires_grad,
)
def tensor(data, dtype=None, device=None, requires_grad=False):
"""Create a tensor initializing the content from data.
Parameters
----------
data : array_like
The data to initialize.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
``True`` to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
array_data = numpy.array(data, copy=True)
if dtype is None:
dtype = str(array_data.dtype)
else:
array_data = array_data.astype(dtype)
return Tensor(
array_data,
dtype=dtype,
device=cpp.device() if device is None else device,
requires_grad=requires_grad,
)