Commit 936c351b by Ting PAN

Enhance transpose operators

Summary:
This commit allows transpose to compute in-place by staging the result through a scratch buffer.
We also add CRD mode for space-depth transpose (i.e., pixel shuffle).
1 parent ac051717
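For a quick sense of the two features, here is a minimal eager-mode sketch against the Python APIs added in this commit (`copy=False` on `dragon.transpose` and `mode='CRD'` on `dragon.nn.space_to_depth`); the input values are illustrative only:

```python
import dragon

# In-place transpose: the result is staged in a scratch buffer,
# then copied back into the input's storage.
x = dragon.ones((2, 3, 4))
y = dragon.transpose(x, perm=(0, 2, 1), copy=False)

# CRD mode splits channels as (C/b^2, b, b) instead of DCR's (b, b, C/b^2).
z = dragon.nn.space_to_depth(dragon.ones((1, 3, 4, 4)), 2, mode='CRD')
print(y.shape, z.shape)  # (2, 4, 3) (1, 12, 2, 2)
```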
@@ -39,9 +39,13 @@ shape
 .. autoattribute:: dragon.Tensor.shape
 size
-#####
+####
 .. autoattribute:: dragon.Tensor.size
+T
+#
+.. autoattribute:: dragon.Tensor.T
 Methods
 -------
@@ -85,6 +89,10 @@ tolist
 ######
 .. automethod:: dragon.Tensor.tolist
+transpose
+#########
+.. automethod:: dragon.Tensor.transpose
 truncated_normal
 ################
 .. automethod:: dragon.Tensor.truncated_normal
@@ -168,6 +176,10 @@ __lt__
 ######
 .. automethod:: dragon.Tensor.__lt__
+__matmul__
+##########
+.. automethod:: dragon.Tensor.__matmul__
 __mul__
 #######
 .. automethod:: dragon.Tensor.__mul__
@@ -243,6 +255,7 @@ __xor__
 .. _dragon.math.greater_equal(...): math/greater_equal.html
 .. _dragon.math.less(...): math/less.html
 .. _dragon.math.less_equal(...): math/less_equal.html
+.. _dragon.math.matmul(...): math/matmul.html
 .. _dragon.math.mul(...): math/mul.html
 .. _dragon.math.negative(...): math/negative.html
 .. _dragon.math.not_equal(...): math/not_equal.html
@@ -253,6 +266,7 @@ __xor__
 .. _dragon.random.truncated_normal(...): random/truncated_normal.html
 .. _dragon.random.uniform(...): random/uniform.html
 .. _dragon.reshape(...): reshape.html
+.. _dragon.transpose(...): transpose.html
 .. raw:: html
......
@@ -38,6 +38,10 @@ shape
 #####
 .. autoattribute:: dragon.vm.torch.Tensor.shape
+T
+#
+.. autoattribute:: dragon.vm.torch.Tensor.T
 Methods
 -------
@@ -213,6 +217,10 @@ exp
 ###
 .. automethod:: dragon.vm.torch.Tensor.exp
+exp\_
+#####
+.. automethod:: dragon.vm.torch.Tensor.exp_
 expand
 ######
 .. automethod:: dragon.vm.torch.Tensor.expand
@@ -317,6 +325,10 @@ log
 ###
 .. automethod:: dragon.vm.torch.Tensor.log
+log\_
+#####
+.. automethod:: dragon.vm.torch.Tensor.log_
 logical_and
 ###########
 .. automethod:: dragon.vm.torch.Tensor.logical_and
@@ -461,6 +473,10 @@ permute
 #######
 .. automethod:: dragon.vm.torch.Tensor.permute
+permute\_
+#########
+.. automethod:: dragon.vm.torch.Tensor.permute_
 pow
 ###
 .. automethod:: dragon.vm.torch.Tensor.pow
@@ -593,6 +609,10 @@ transpose
 #########
 .. automethod:: dragon.vm.torch.Tensor.transpose
+transpose\_
+###########
+.. automethod:: dragon.vm.torch.Tensor.transpose_
 tril
 ####
 .. automethod:: dragon.vm.torch.Tensor.tril
......
@@ -197,6 +197,12 @@ vm.torch.nn
 `class Parameter <nn/Parameter.html>`_
 : A wrapped tensor considered to be a module parameter.
+`class PixelShuffle <nn/PixelShuffle.html>`_
+: Rearrange depth elements into pixels.
+`class PixelUnshuffle <nn/PixelUnshuffle.html>`_
+: Rearrange pixels into depth elements.
 `class PReLU <nn/PReLU.html>`_
 : Apply the parametric rectified linear unit.
 `[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_.
@@ -354,6 +360,8 @@ vm.torch.nn
 nn/MultiheadAttention
 nn/NLLLoss
 nn/Parameter
+nn/PixelShuffle
+nn/PixelUnshuffle
 nn/PReLU
 nn/ReflectionPad1d
 nn/ReflectionPad2d
......
PixelShuffle
============
.. autoclass:: dragon.vm.torch.nn.PixelShuffle
__init__
--------
.. automethod:: dragon.vm.torch.nn.PixelShuffle.__init__
.. _torch.nn.functional.pixel_shuffle(...): functional/pixel_shuffle.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
PixelUnshuffle
==============
.. autoclass:: dragon.vm.torch.nn.PixelUnshuffle
__init__
--------
.. automethod:: dragon.vm.torch.nn.PixelUnshuffle.__init__
.. _torch.nn.functional.pixel_unshuffle(...): functional/pixel_unshuffle.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
@@ -155,6 +155,12 @@ vm.torch.nn.functional
 `pad(...) <functional/pad.html>`_
 : Pad the input according to the given sizes.
+`pixel_shuffle(...) <functional/pixel_shuffle.html>`_
+: Rearrange depth elements of input into pixels.
+`pixel_unshuffle(...) <functional/pixel_unshuffle.html>`_
+: Rearrange pixels of input into depth elements.
 `prelu(...) <functional/prelu.html>`_
 : Apply the parametric rectified linear unit to input.
 `[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_.
@@ -256,6 +262,8 @@ vm.torch.nn.functional
 functional/nll_loss
 functional/normalize
 functional/pad
+functional/pixel_shuffle
+functional/pixel_unshuffle
 functional/prelu
 functional/relu
 functional/relu6
......
pixel_shuffle
=============
.. autofunction:: dragon.vm.torch.nn.functional.pixel_shuffle
.. _torch.nn.PixelShuffle(...): ../PixelShuffle.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
pixel_unshuffle
===============
.. autofunction:: dragon.vm.torch.nn.functional.pixel_unshuffle
.. _torch.nn.PixelUnshuffle(...): ../PixelUnshuffle.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
@@ -8,62 +8,25 @@ namespace kernels {
 namespace {
-template <typename T>
-void _GroupNormFusedParams(
-    const int N,
-    const int G,
-    const int D,
-    const T* mu,
-    const T* rsig,
-    const T* gamma,
-    const T* beta,
-    T* scale,
-    T* bias) {
-  const int C = G * D;
-  ConstEigenArrayMap<T> gamma_arr(gamma, D, G);
-  ConstEigenArrayMap<T> beta_arr(beta, D, G);
-  for (int i = 0; i < N; ++i) {
-    EigenArrayMap<T> scale_arr(scale + i * C, D, G);
-    scale_arr = gamma_arr.rowwise() *
-        ConstEigenVectorArrayMap<T>(rsig + i * G, G).transpose();
-    EigenArrayMap<T>(bias + i * C, D, G) = beta_arr -
-        scale_arr.rowwise() *
-            ConstEigenVectorArrayMap<T>(mu + i * G, G).transpose();
-  }
-}
-template <typename T, typename AccT>
-void _GroupNormNCHW(
-    const int N,
-    const int C,
-    const int S,
-    const T* x,
-    const AccT* scale,
-    const AccT* bias,
-    T* y) {
-  EigenArrayMap<T>(y, S, N * C) =
-      (ConstEigenArrayMap<T>(x, S, N * C).rowwise() *
-       ConstEigenVectorArrayMap<AccT>(scale, N * C).transpose())
-          .rowwise() +
-      ConstEigenVectorArrayMap<AccT>(bias, N * C).transpose();
-}
-template <typename T, typename AccT>
-void _GroupNormNHWC(
-    const int N,
-    const int C,
-    const int S,
-    const T* x,
-    const AccT* scale,
-    const AccT* bias,
-    T* y) {
-  const int SC = S * C;
-  for (int i = 0; i < N; ++i) {
-    EigenArrayMap<T>(y + i * SC, C, S) =
-        (ConstEigenArrayMap<T>(x + i * SC, C, S).colwise() *
-         ConstEigenVectorArrayMap<AccT>(scale + i * C, C))
-            .colwise() +
-        ConstEigenVectorArrayMap<AccT>(bias + i * C, C);
+template <typename T, typename AccT, StorageOrder kOrder>
+void _GroupNorm(
+    const std::array<int, 4>& dims,
+    const T* x,
+    const AccT* mu,
+    const AccT* rsig,
+    const AccT* gamma,
+    const AccT* beta,
+    T* y) {
+  const int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
+  const int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
+  const int NxGxDxS = dims[0] * dims[1] * dims[2] * dims[3];
+  std::array<int, 4> index = {0, 0, 0, 0};
+  for (int i = 0; i < NxGxDxS; ++i) {
+    const int ng = index[0] * dims[kGDim] + index[kGDim];
+    const int c = index[kGDim] * dims[kDDim] + index[kDDim];
+    AccT val = (convert::To<AccT>(x[i]) - mu[ng]) * rsig[ng];
+    y[i] = convert::To<T>(val * gamma[c] + beta[c]);
+    math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
   }
 }
@@ -77,13 +40,13 @@ void _GroupNormInternalGrad(
     AccT* db) {
   const int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
   const int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
-  const int NxGxKxS = dims[0] * dims[1] * dims[2] * dims[3];
+  const int NxGxDxS = dims[0] * dims[1] * dims[2] * dims[3];
   std::array<int, 4> index = {0, 0, 0, 0};
-  for (int i = 0; i < NxGxKxS; ++i) {
-    const int mi = index[0] * dims[kGDim] + index[kGDim];
-    const int gi = index[kGDim] * dims[kDDim] + index[kDDim];
-    ds[mi] += gamma[gi] * dy[i] * x[i];
-    db[mi] += gamma[gi] * dy[i];
+  for (int i = 0; i < NxGxDxS; ++i) {
+    const int ng = index[0] * dims[kGDim] + index[kGDim];
+    const int c = index[kGDim] * dims[kDDim] + index[kDDim];
+    ds[ng] += gamma[c] * dy[i] * x[i];
+    db[ng] += gamma[c] * dy[i];
     math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
   }
 }
@@ -103,19 +66,19 @@ void _GroupNormGrad(
     T* dx) {
   const int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
   const int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
-  const int NxGxKxS = dims[0] * dims[1] * dims[2] * dims[3];
+  const int NxGxDxS = dims[0] * dims[1] * dims[2] * dims[3];
   const int S = kOrder == StorageOrder::NCHW ? dims[3] : dims[1];
   const AccT denom = AccT(1) / static_cast<AccT>(dims[kDDim] * S);
   std::array<int, 4> index = {0, 0, 0, 0};
-  for (int i = 0; i < NxGxKxS; ++i) {
-    const int mi = index[0] * dims[kGDim] + index[kGDim];
-    const int gi = index[kGDim] * dims[kDDim] + index[kDDim];
-    const AccT u = (db[mi] * mu[mi] - ds[mi]) * (x[i] - mu[mi]) *
-        math::utils::Cube(rsig[mi]);
-    const AccT v = db[mi] * rsig[mi];
-    dx[i] = gamma[gi] * dy[i] * rsig[mi] + (u - v) * denom;
-    dgamma[gi] += dy[i] * (x[i] - mu[mi]) * rsig[mi];
-    dbeta[gi] += dy[i];
+  for (int i = 0; i < NxGxDxS; ++i) {
+    const int ng = index[0] * dims[kGDim] + index[kGDim];
+    const int c = index[kGDim] * dims[kDDim] + index[kDDim];
+    const AccT u = (db[ng] * mu[ng] - ds[ng]) * (x[i] - mu[ng]) *
+        math::utils::Cube(rsig[ng]);
+    const AccT v = db[ng] * rsig[ng];
+    dx[i] = gamma[c] * dy[i] * rsig[ng] + (u - v) * denom;
+    dgamma[c] += dy[i] * (x[i] - mu[ng]) * rsig[ng];
+    dbeta[c] += dy[i];
     math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
   }
 }
@@ -125,25 +88,6 @@ void _GroupNormGrad(
 /* ------------------- Launcher Separator ------------------- */
 template <>
-void GroupNorm<float16, float, CPUContext>(
-    const int N,
-    const int G,
-    const int D,
-    const int S,
-    const string& data_format,
-    const float16* x,
-    const float* mu,
-    const float* rsig,
-    const float* gamma,
-    const float* beta,
-    float* scale,
-    float* bias,
-    float16* y,
-    CPUContext* tx) {
-  CPU_FP16_NOT_SUPPORTED;
-}
-template <>
 void GroupNormGrad<float16, float, CPUContext>(
     const int N,
     const int G,
@@ -177,16 +121,14 @@ void GroupNormGrad<float16, float, CPUContext>(
       const AccT* rsig, \
       const AccT* gamma, \
       const AccT* beta, \
-      AccT* scale, \
-      AccT* bias, \
       T* y, \
       CPUContext* ctx) { \
-    const int C = G * D; \
-    _GroupNormFusedParams(N, G, D, mu, rsig, gamma, beta, scale, bias); \
     if (data_format == "NCHW") { \
-      _GroupNormNCHW(N, C, S, x, scale, bias, y); \
+      _GroupNorm<T, AccT, StorageOrder::NCHW>( \
+          {N, G, D, S}, x, mu, rsig, gamma, beta, y); \
    } else if (data_format == "NHWC") { \
-      _GroupNormNHWC(N, C, S, x, scale, bias, y); \
+      _GroupNorm<T, AccT, StorageOrder::NHWC>( \
+          {N, S, G, D}, x, mu, rsig, gamma, beta, y); \
    } \
  }
@@ -226,6 +168,7 @@ void GroupNormGrad<float16, float, CPUContext>(
   } \
 }
+DEFINE_KERNEL_LAUNCHER(float16, float);
 DEFINE_KERNEL_LAUNCHER(float, float);
 DEFINE_KERNEL_LAUNCHER(double, double);
 DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
......
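The CPU refactor above collapses the fused-parameter precompute and the two layout-specific kernels into one element-wise pass. A NumPy sketch of what the new `_GroupNorm` computes per element in the NCHW case (function and variable names here are illustrative, not the kernel's API):

```python
import numpy as np

def group_norm_nchw(x, mu, rsig, gamma, beta, groups):
    # x: (N, C, H, W); mu, rsig: (N, G); gamma, beta: (C,)
    n, c, h, w = x.shape
    d = c // groups
    xr = x.reshape(n, groups, d, h * w)
    # Per-element: y = (x - mu[ng]) * rsig[ng] * gamma[c] + beta[c]
    y = (xr - mu[:, :, None, None]) * rsig[:, :, None, None]
    y = y.reshape(n, c, h * w) * gamma[None, :, None] + beta[None, :, None]
    return y.reshape(n, c, h, w)
```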
@@ -14,40 +14,27 @@ namespace {
 #define LDG(x, i) __ldg(x + i)
 #define LDG2(x, i) convert::To<AccT>(__ldg(x + i))
-template <typename T>
-__global__ void _GroupNormFusedParams(
-    const int NxC,
-    const int C,
-    const int D,
-    const T* mu,
-    const T* rsig,
-    const T* gamma,
-    const T* beta,
-    T* scale,
-    T* bias) {
-  CUDA_1D_KERNEL_LOOP(i, NxC) {
-    const int c = i % C;
-    const int ng = i / D;
-    const T scale_val = LDG(gamma, c) * LDG(rsig, ng);
-    scale[i] = scale_val;
-    bias[i] = fma(-scale_val, LDG(mu, ng), LDG(beta, c));
-  }
-}
 template <typename T, typename AccT, StorageOrder kOrder>
-__global__ void _GroupNormAffine(
+__global__ void _GroupNorm(
     const int NxCxS,
-    const int C,
+    const int G,
+    const int D,
     const int S,
     const T* x,
-    const AccT* scale,
-    const AccT* bias,
+    const AccT* mu,
+    const AccT* rsig,
+    const AccT* gamma,
+    const AccT* beta,
     T* y) {
+  const int C = G * D;
   CUDA_1D_KERNEL_LOOP(i, NxCxS) {
-    const int nc =
-        kOrder == StorageOrder::NCHW ? i / S : i / (C * S) * C + i % C;
+    const int ng = kOrder == StorageOrder::NCHW ? i / (D * S)
+                                                : i / (C * S) * G + (i / D % G);
+    const int c = kOrder == StorageOrder::NCHW ? i / S % C : i % C;
     y[i] = convert::To<T>(
-        fma(convert::To<AccT>(x[i]), LDG(scale, nc), LDG(bias, nc)));
+        fma((convert::To<AccT>(x[i]) - __ldg(mu + ng)) * __ldg(rsig + ng),
+            __ldg(gamma + c),
+            __ldg(beta + c)));
   }
 }
@@ -179,30 +166,24 @@ __global__ void _GroupNormGrad(
       const AccT* rsig, \
       const AccT* gamma, \
       const AccT* beta, \
-      AccT* scale, \
-      AccT* bias, \
       T* y, \
       CUDAContext* ctx) { \
-    const auto C = G * D; \
-    const auto NxC = N * C; \
-    const auto NxCxS = NxC * S; \
-    _GroupNormFusedParams<<< \
-        CUDA_BLOCKS(NxC), \
-        CUDA_THREADS, \
-        0, \
-        ctx->cuda_stream()>>>(NxC, C, D, mu, rsig, gamma, beta, scale, bias); \
+    const auto NxCxS = N * G * D * S; \
     DISPATCH_GROUPNORM_KERNEL( \
-        _GroupNormAffine, \
+        _GroupNorm, \
        math::ScalarType<T>::type, \
        AccT, \
        CUDA_BLOCKS(NxCxS), \
        CUDA_THREADS, \
        NxCxS, \
-        C, \
+        G, \
+        D, \
        S, \
        reinterpret_cast<const math::ScalarType<T>::type*>(x), \
-        scale, \
-        bias, \
+        mu, \
+        rsig, \
+        gamma, \
+        beta, \
        reinterpret_cast<math::ScalarType<T>::type*>(y)); \
  }
......
@@ -25,21 +25,22 @@ class NumpyWrapper {
   py::object To(bool copy) {
     const auto& meta = tensor_->meta();
-    const auto& typestr = ::dragon::dtypes::to_string(meta);
+    const auto& dtype = ::dragon::dtypes::to_string(meta);
     CHECK_GT(tensor_->count(), 0) << "\nConvert an empty tensor.";
-    CHECK(typestr != "unknown") << "\nConvert an empty tensor.";
-    if (typestr == "string") {
+    CHECK(dtype != "unknown") << "\nConvert an empty tensor.";
+    if (dtype == "string") {
       CHECK_EQ(tensor_->count(), 1);
       return py::bytes(tensor_->data<string, CPUContext>()[0]);
     }
-    auto typenum = dtypes::to_npy(meta);
     vector<npy_intp> dims({tensor_->dims().begin(), tensor_->dims().end()});
     if (copy) {
       auto* memory = tensor_->memory();
       CHECK(memory) << "\nConvert an empty tensor.";
       auto device_type = memory ? memory->info()["device_type"] : "cpu";
-      auto* array = PyArray_SimpleNew(tensor_->ndim(), dims.data(), typenum);
+      auto* array =
+          PyArray_SimpleNew(dims.size(), dims.data(), dtypes::to_npy(meta));
       if (device_type == "cuda") {
+        CUDADeviceGuard guard(memory->device());
         CUDAContext::Memcpy<CPUContext, CUDAContext>(
             tensor_->nbytes(),
             PyArray_DATA(reinterpret_cast<PyArrayObject*>(array)),
@@ -53,9 +54,11 @@ class NumpyWrapper {
      }
      return py::reinterpret_steal<py::object>(array);
    }
-    auto* data = const_cast<void*>(tensor_->raw_data<CPUContext>());
    auto* array = PyArray_SimpleNewFromData(
-        tensor_->ndim(), dims.data(), dtypes::to_npy(meta), data);
+        dims.size(),
+        dims.data(),
+        dtypes::to_npy(meta),
+        const_cast<void*>(tensor_->raw_data<CPUContext>()));
    return py::reinterpret_steal<py::object>(array);
  }
@@ -71,6 +74,7 @@ class NumpyWrapper {
    if (copy) {
      auto device_type = memory ? memory->info()["device_type"] : "cpu";
      if (device_type == "cuda") {
+        CUDADeviceGuard guard(memory->device());
        CUDAContext::Memcpy<CUDAContext, CPUContext>(
            tensor_->nbytes(),
            tensor_->raw_mutable_data<CUDAContext>(),
......
#include "dragon/operators/array/transpose_op.h" #include "dragon/operators/array/transpose_op.h"
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -7,7 +8,7 @@ namespace dragon { ...@@ -7,7 +8,7 @@ namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void TransposeOp<Context>::DoRunWithType() { void TransposeOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0, {0});
int num_axes, num_dims = X.ndim(); int num_axes, num_dims = X.ndim();
vec64_t X_strides(num_dims), Y_dims(num_dims); vec64_t X_strides(num_dims), Y_dims(num_dims);
...@@ -34,13 +35,25 @@ void TransposeOp<Context>::DoRunWithType() { ...@@ -34,13 +35,25 @@ void TransposeOp<Context>::DoRunWithType() {
Y_dims[i] = X.dim(new_axes[i]); Y_dims[i] = X.dim(new_axes[i]);
} }
auto* scratch = ((void*)&X == (void*)Y)
? ctx()->workspace()->template data<T, Context>({X.count()})[0]
: Y->Reshape(Y_dims)->template mutable_data<T, Context>();
kernels::Transpose( kernels::Transpose(
num_dims, num_dims,
X_strides.data(), X_strides.data(),
Y_dims.data(), Y_dims.data(),
X.template data<T, Context>(), X.template data<T, Context>(),
scratch,
ctx());
if ((void*)&X == (void*)Y) {
math::Copy(
X.count(),
scratch,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(), Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx()); ctx());
}
} }
DEPLOY_CPU_OPERATOR(Transpose); DEPLOY_CPU_OPERATOR(Transpose);
...@@ -54,13 +67,17 @@ OPERATOR_SCHEMA(Transpose) ...@@ -54,13 +67,17 @@ OPERATOR_SCHEMA(Transpose)
/* X */ /* X */
.NumInputs(1) .NumInputs(1)
/* Y */ /* Y */
.NumOutputs(1); .NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(TransposeGradient) OPERATOR_SCHEMA(TransposeGradient)
/* dY */ /* dY */
.NumInputs(1) .NumInputs(1)
/* dX */ /* dX */
.NumOutputs(1); .NumOutputs(1)
/* dY => dX */
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(Transpose, SimpleGradientMaker); REGISTER_GRADIENT(Transpose, SimpleGradientMaker);
......
@@ -37,8 +37,6 @@ void GroupNormOp<Context>::DoRunWithType() {
   math::InvStd(N_ * G_, epsilon_, rsig, rsig, ctx());
   // Fuse parameters to compute affine transformation.
-  auto* scratch =
-      ctx()->workspace()->template data<ParamT, Context>({2 * N_ * C_})[0];
   kernels::GroupNorm(
       N_,
       G_,
@@ -50,8 +48,6 @@ void GroupNormOp<Context>::DoRunWithType() {
       rsig,
       W.template data<ParamT, Context>(),
       B.template data<ParamT, Context>(),
-      scratch,
-      scratch + N_ * C_,
       Y->ReshapeLike(X)->template mutable_data<T, Context>(),
       ctx());
 }
......
#include "dragon/operators/vision/space_to_depth_op.h" #include "dragon/operators/vision/space_to_depth_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h" #include "dragon/utils/op_kernels.h"
namespace dragon { namespace dragon {
...@@ -6,7 +8,7 @@ namespace dragon { ...@@ -6,7 +8,7 @@ namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void SpaceToDepthOp<Context>::DoRunWithType() { void SpaceToDepthOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0, {0});
int start_axis, end_axis, perm_count = 0; int start_axis, end_axis, perm_count = 0;
int num_dims = X.ndim(), num_axes = X.ndim() - 2; int num_dims = X.ndim(), num_axes = X.ndim() - 2;
...@@ -48,8 +50,8 @@ void SpaceToDepthOp<Context>::DoRunWithType() { ...@@ -48,8 +50,8 @@ void SpaceToDepthOp<Context>::DoRunWithType() {
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
for (int i = 0; i < num_axes; i++) { for (int i = 0; i < num_axes; i++) {
perm.insert(perm.begin() + 1, perm.back()); perm.insert(perm.begin() + (mode_ == "DCR" ? 1 : 2), perm.back());
perm.pop_back(); // DCR mode perm.pop_back();
} }
} }
...@@ -66,19 +68,31 @@ void SpaceToDepthOp<Context>::DoRunWithType() { ...@@ -66,19 +68,31 @@ void SpaceToDepthOp<Context>::DoRunWithType() {
Y_dims[i] = X_reshape.dim(perm[i]); Y_dims[i] = X_reshape.dim(perm[i]);
} }
auto* scratch = ((void*)&X == (void*)Y)
? ctx()->workspace()->template data<T, Context>({X.count()})[0]
: Y->Reshape(out_shape)->template mutable_data<T, Context>();
kernels::Transpose( kernels::Transpose(
X_strides.size(), X_strides.size(),
X_strides.data(), X_strides.data(),
Y_dims.data(), Y_dims.data(),
X.template data<T, Context>(), X.template data<T, Context>(),
scratch,
ctx());
if ((void*)&X == (void*)Y) {
math::Copy(
X.count(),
scratch,
Y->Reshape(out_shape)->template mutable_data<T, Context>(), Y->Reshape(out_shape)->template mutable_data<T, Context>(),
ctx()); ctx());
}
} }
template <class Context> template <class Context>
template <typename T> template <typename T>
void DepthToSpaceOp<Context>::DoRunWithType() { void DepthToSpaceOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0, {0});
int start_axis, end_axis; int start_axis, end_axis;
int num_dims = X.ndim(), num_axes = X.ndim() - 2; int num_dims = X.ndim(), num_axes = X.ndim() - 2;
...@@ -94,11 +108,11 @@ void DepthToSpaceOp<Context>::DoRunWithType() { ...@@ -94,11 +108,11 @@ void DepthToSpaceOp<Context>::DoRunWithType() {
start_axis = 2, end_axis = num_dims; start_axis = 2, end_axis = num_dims;
out_shape[1] /= std::pow(block_size_, num_axes); out_shape[1] /= std::pow(block_size_, num_axes);
in_dims = out_shape; in_dims = out_shape;
perm[1] = num_axes + 1; perm[1] = (mode_ == "DCR" ? num_axes + 1 : 1);
for (int i = 0; i < num_axes; i++) { for (int i = 0; i < num_axes; i++) {
perm[i * 2 + 2] = num_axes + i + 2; perm[i * 2 + 2] = num_axes + i + 2;
perm[i * 2 + 3] = i + 1; perm[i * 2 + 3] = i + (mode_ == "DCR" ? 1 : 2);
in_dims.insert(in_dims.begin() + 1, block_size_); in_dims.insert(in_dims.begin() + (mode_ == "DCR" ? 1 : 2), block_size_);
out_shape[start_axis + i] *= block_size_; out_shape[start_axis + i] *= block_size_;
} }
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
...@@ -129,13 +143,25 @@ void DepthToSpaceOp<Context>::DoRunWithType() { ...@@ -129,13 +143,25 @@ void DepthToSpaceOp<Context>::DoRunWithType() {
Y_dims[i] = X_reshape.dim(perm[i]); Y_dims[i] = X_reshape.dim(perm[i]);
} }
auto* scratch = ((void*)&X == (void*)Y)
? ctx()->workspace()->template data<T, Context>({X.count()})[0]
: Y->Reshape(out_shape)->template mutable_data<T, Context>();
kernels::Transpose( kernels::Transpose(
X_strides.size(), X_strides.size(),
X_strides.data(), X_strides.data(),
Y_dims.data(), Y_dims.data(),
X.template data<T, Context>(), X.template data<T, Context>(),
scratch,
ctx());
if ((void*)&X == (void*)Y) {
math::Copy(
X.count(),
scratch,
Y->Reshape(out_shape)->template mutable_data<T, Context>(), Y->Reshape(out_shape)->template mutable_data<T, Context>(),
ctx()); ctx());
}
} }
DEPLOY_CPU_OPERATOR(SpaceToDepth); DEPLOY_CPU_OPERATOR(SpaceToDepth);
...@@ -152,10 +178,16 @@ DEPLOY_CUDA_OPERATOR(DepthToSpace); ...@@ -152,10 +178,16 @@ DEPLOY_CUDA_OPERATOR(DepthToSpace);
REGISTER_CUDA_OPERATOR(DepthToSpaceGradient, SpaceToDepthOp<CUDAContext>); REGISTER_CUDA_OPERATOR(DepthToSpaceGradient, SpaceToDepthOp<CUDAContext>);
#endif #endif
OPERATOR_SCHEMA(SpaceToDepth).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(SpaceToDepth).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SpaceToDepthGradient).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(SpaceToDepthGradient)
OPERATOR_SCHEMA(DepthToSpace).NumInputs(1).NumOutputs(1); .NumInputs(1)
OPERATOR_SCHEMA(DepthToSpaceGradient).NumInputs(1).NumOutputs(1); .NumOutputs(1)
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(DepthToSpace).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(DepthToSpaceGradient)
.NumInputs(1)
.NumOutputs(1)
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(SpaceToDepth, SimpleGradientMaker); REGISTER_GRADIENT(SpaceToDepth, SimpleGradientMaker);
REGISTER_GRADIENT(DepthToSpace, SimpleGradientMaker); REGISTER_GRADIENT(DepthToSpace, SimpleGradientMaker);
......
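To make the `mode_` branches concrete, here is a NumPy sketch of the two `DepthToSpace` orderings for 'NCHW' that the permutation logic above implements, following the ONNX DCR/CRD definitions (this mirrors the perm construction but is not the operator's code):

```python
import numpy as np

def depth_to_space(x, b, mode='DCR'):
    n, c, h, w = x.shape
    if mode == 'DCR':  # depth-column-row: split C as (b, b, C/b^2)
        y = x.reshape(n, b, b, c // b**2, h, w)
        y = y.transpose(0, 3, 4, 1, 5, 2)
    else:  # CRD: split C as (C/b^2, b, b)
        y = x.reshape(n, c // b**2, b, b, h, w)
        y = y.transpose(0, 1, 4, 2, 5, 3)
    return y.reshape(n, c // b**2, h * b, w * b)
```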
@@ -22,7 +22,8 @@ class SpaceToDepthOp final : public Operator<Context> {
  public:
   SpaceToDepthOp(const OperatorDef& def, Workspace* ws)
       : Operator<Context>(def, ws),
-        block_size_(OP_SINGLE_ARG(int, "block_size", 2)) {}
+        block_size_(OP_SINGLE_ARG(int, "block_size", 2)),
+        mode_(OP_SINGLE_ARG(string, "mode", "DCR")) {}
   USE_OPERATOR_FUNCTIONS;
   void RunOnDevice() override {
@@ -33,6 +34,7 @@ class SpaceToDepthOp final : public Operator<Context> {
   void DoRunWithType();
  protected:
+  string mode_;
   int64_t block_size_;
 };
@@ -41,7 +43,8 @@ class DepthToSpaceOp final : public Operator<Context> {
  public:
   DepthToSpaceOp(const OperatorDef& def, Workspace* ws)
       : Operator<Context>(def, ws),
-        block_size_(OP_SINGLE_ARG(int, "block_size", 2)) {}
+        block_size_(OP_SINGLE_ARG(int, "block_size", 2)),
+        mode_(OP_SINGLE_ARG(string, "mode", "DCR")) {}
   USE_OPERATOR_FUNCTIONS;
   void RunOnDevice() override {
@@ -52,6 +55,7 @@ class DepthToSpaceOp final : public Operator<Context> {
   void DoRunWithType();
  protected:
+  string mode_;
   int64_t block_size_;
 };
......
@@ -159,6 +159,7 @@ def cum_reduce_args(**kwargs):
 def depth_space_args(**kwargs):
     return {
         'block_size': kwargs.get('block_size', '2'),
+        'mode': kwargs.get('mode', 'DCR'),
         'data_format': kwargs.get('data_format', 'NCHW'),
     }
......
@@ -180,6 +180,18 @@ class Tensor(types.TensorBase):
             return float('inf')
         return math_util.prod(self._shape)
+    @property
+    def T(self):
+        """Return a tensor with axes reversed.
+        Returns
+        -------
+        dragon.Tensor
+            The output tensor.
+        """
+        return self.transpose()
+
     def astype(self, dtype, copy=True):
         """Convert the data type to a specific one.
@@ -365,6 +377,27 @@ class Tensor(types.TensorBase):
         """
         return self.numpy().tolist()
+    def transpose(self, perm=None, copy=True):
+        """Return a tensor with permuted axes.
+        Parameters
+        ----------
+        perm : Union[Sequence[int], dragon.Tensor], optional
+            The output permutation.
+        copy : bool, optional, default=True
+            Return a new tensor or transpose in-place.
+        Returns
+        -------
+        dragon.Tensor
+            The output tensor.
+        See Also
+        --------
+        `dragon.transpose(...)`_
+        """
+
     def truncated_normal(self, mean=0, std=1):
         r"""Fill self from a truncated normal distribution.
@@ -694,6 +727,25 @@ class Tensor(types.TensorBase):
         """
+    def __matmul__(self, other):
+        """Compute the matrix multiplication.
+        Parameters
+        ----------
+        other : dragon.Tensor
+            The value to multiply.
+        Returns
+        -------
+        dragon.Tensor
+            The output tensor.
+        See Also
+        --------
+        `dragon.math.matmul(...)`_
+        """
+
     def __mul__(self, other):
         """Compute the element-wise multiplication.
......
@@ -1754,7 +1754,7 @@ def tile(inputs, repeats, **kwargs):
 @OpSchema.num_inputs(1)
 @OpSchema.convert_arg('perm')
-def transpose(inputs, perm=None, **kwargs):
+def transpose(inputs, perm=None, copy=True, **kwargs):
     r"""Permute the dimensions of input.
     Examples:
@@ -1774,6 +1774,8 @@ def transpose(inputs, perm=None, **kwargs):
         The input tensor.
     perm : Union[Sequence[int], dragon.Tensor], optional
         The output permutation.
+    copy : bool, optional, default=True
+        Return a new tensor or call in-place.
     Returns
     -------
@@ -1785,6 +1787,7 @@ def transpose(inputs, perm=None, **kwargs):
     if context.executing_eagerly():
         return OpLib.execute(
             'Transpose', inputs,
+            outputs=[None] if copy else inputs,
             ndim=len(args['perm']) if perm is not None else 0,
             perm=args['perm'])
     return OpLib.add('Transpose', **args)
......
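A small usage sketch of the new `copy` flag in eager mode; with `copy=False`, the input is bound as the op's output, and the result is staged through the workspace buffer shown in `transpose_op.cc` above:

```python
import dragon

x = dragon.constant([[1, 2], [3, 4]])
# y reuses x's storage instead of allocating a new output.
y = dragon.transpose(x, perm=(1, 0), copy=False)
print(y.numpy())  # [[1, 3], [2, 4]]
```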
@@ -516,6 +516,27 @@ def lt(self, other):
     return _apply_binary_op([self, other], 'Less')
+def matmul(self, other):
+    """Compute the matrix multiplication.
+    Parameters
+    ----------
+    other : dragon.Tensor
+        The value to multiply.
+    Returns
+    -------
+    dragon.Tensor
+        The output tensor.
+    See Also
+    --------
+    `dragon.math.matmul(...)`_
+    """
+    return _apply_binary_op([self, other], 'MatMul')
+
 def mul(self, other):
     """Compute the element-wise multiplication.
@@ -844,6 +865,29 @@ def sub(self, other):
     return _apply_binary_op([self, other], 'Sub')
+def transpose(self, perm=None, copy=True):
+    """Return a tensor with permuted axes.
+    Parameters
+    ----------
+    perm : Union[Sequence[int], dragon.Tensor], optional
+        The output permutation.
+    copy : bool, optional, default=True
+        Return a new tensor or transpose in-place.
+    Returns
+    -------
+    dragon.Tensor
+        The output tensor.
+    See Also
+    --------
+    `dragon.transpose(...)`_
+    """
+    return array_ops.transpose(self, perm=perm, copy=copy)
+
 def truncated_normal(self, mean=0, std=1):
     r"""Fill self from a truncated normal distribution.
@@ -984,6 +1028,7 @@ Tensor.glorot_normal = glorot_normal
 Tensor.glorot_uniform = glorot_uniform
 Tensor.normal = normal
 Tensor.reshape = reshape
+Tensor.transpose = transpose
 Tensor.truncated_normal = truncated_normal
 Tensor.uniform = uniform
 Tensor.__add__ = add
@@ -1003,6 +1048,7 @@ Tensor.__itruediv__ = idiv
 Tensor.__ixor__ = ixor
 Tensor.__le__ = le
 Tensor.__lt__ = lt
+Tensor.__matmul__ = matmul
 Tensor.__mul__ = mul
 Tensor.__ne__ = ne
 Tensor.__neg__ = neg
......
@@ -831,7 +831,14 @@ def depthwise_conv2d(
 @OpSchema.num_inputs(1)
-def depth_to_space(inputs, block_size, data_format='NCHW', **kwargs):
+def depth_to_space(
+    inputs,
+    block_size,
+    mode='DCR',
+    data_format='NCHW',
+    copy=True,
+    **kwargs
+):
     """Rearrange depth data into spatial blocks.
     Examples:
@@ -851,8 +858,12 @@ def depth_to_space(
         The input tensor.
     block_size : int, required
         The size of spatial block.
+    mode : str, optional, default='DCR'
+        Rearrangement order for ``'NCHW'`` format.
     data_format : str, optional, default='NCHW'
         ``'NCHW'`` or ``'NHWC'``.
+    copy : bool, optional, default=True
+        Return a new tensor or call in-place.
     Returns
     -------
@@ -865,9 +876,10 @@ def depth_to_space(
     if context.executing_eagerly():
         return OpLib.execute(
             'DepthToSpace', inputs,
-            block_size=block_size, data_format=data_format)
+            outputs=[None] if copy else inputs,
+            block_size=block_size, mode=mode.upper(), data_format=data_format)
     return OpLib.add('DepthToSpace', inputs, block_size=block_size,
-                     data_format=data_format, **kwargs)
+                     mode=mode.upper(), data_format=data_format, **kwargs)
 @OpSchema.num_inputs(1)
@@ -1482,7 +1494,14 @@ def roi_pool(
 @OpSchema.num_inputs(1)
-def space_to_depth(inputs, block_size, data_format='NCHW', **kwargs):
+def space_to_depth(
+    inputs,
+    block_size,
+    mode='DCR',
+    data_format='NCHW',
+    copy=True,
+    **kwargs
+):
     """Rearrange blocks of spatial data into depth.
     Examples:
@@ -1492,7 +1511,7 @@ def space_to_depth(
     x = dragon.range(n * c * h * w).reshape((n, c, h, w))
     y = dragon.reshape(x, (n, c, h // bs, bs, w // bs, bs))
     y = dragon.transpose(y, (0, 3, 5, 1, 2, 4))
-    y = dragon.reshape(y, (n, c * (bs ** 2), h // bs, w // bs))
+    y = dragon.reshape(y, (n, (bs ** 2) * c, h // bs, w // bs))
     z = dragon.nn.space_to_depth(x, 2)  # Equivalent
     ```
@@ -1502,8 +1521,12 @@ def space_to_depth(
         The input tensor.
     block_size : int, required
         The size of spatial block.
+    mode : str, optional, default='DCR'
+        Rearrangement order for ``'NCHW'`` format.
     data_format : str, optional, default='NCHW'
         ``'NCHW'`` or ``'NHWC'``.
+    copy : bool, optional, default=True
+        Return a new tensor or call in-place.
     Returns
     -------
@@ -1516,9 +1539,10 @@ def space_to_depth(
     if context.executing_eagerly():
         return OpLib.execute(
             'SpaceToDepth', inputs,
-            block_size=block_size, data_format=data_format)
+            outputs=[None] if copy else inputs,
+            block_size=block_size, mode=mode.upper(), data_format=data_format)
     return OpLib.add('SpaceToDepth', inputs, block_size=block_size,
-                     data_format=data_format, **kwargs)
+                     mode=mode.upper(), data_format=data_format, **kwargs)
 def _normalize_tuple(value, rank):
......
@@ -68,6 +68,9 @@ def depth_space_exporter(op_def, context):
         _assert_data_format(arg)
         if arg.name == 'block_size':
             helper.add_attribute(node, 'blocksize', arg.i)
+        elif arg.name == 'mode':
+            if node.op_type != 'SpaceToDepth':
+                helper.add_attribute(node, 'mode', arg.s)
     return node, const_tensors
......
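Note: the guard above reflects the ONNX spec, which defines a `mode` attribute (DCR/CRD) for `DepthToSpace` only; `SpaceToDepth` has no such attribute, so the exporter must drop it there.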
@@ -918,8 +918,6 @@ void GroupNorm(
     const AccT* rsig,
     const AccT* gamma,
     const AccT* beta,
-    AccT* scale,
-    AccT* bias,
     T* y,
     Context* ctx);
......
@@ -3101,6 +3101,14 @@ class TestTensorOps(OpTestCase):
             x = new_tensor(data)
             self.assertEqual(~x, ~data)
+    def test_matmul(self):
+        for execution in ('EAGER_MODE', 'GRAPH_MODE'):
+            with execution_context().mode(execution):
+                for a_shape, b_shape in self.binary_test_shapes:
+                    data1, data2 = arange((2, 3)), arange((3, 4), 1)
+                    a, b = new_tensor(data1), new_tensor(data2)
+                    self.assertEqual(a.__matmul__(b), data1.__matmul__(data2))
+
     def test_mul(self):
         for execution in ('EAGER_MODE', 'GRAPH_MODE'):
             with execution_context().mode(execution):
@@ -3195,6 +3203,17 @@ class TestTensorOps(OpTestCase):
                     a -= b
                     self.assertEqual(a, data1 - data2)
+    def test_transpose(self):
+        entries = [(0, 2, 1), None]
+        for execution in ('EAGER_MODE', 'GRAPH_MODE'):
+            with execution_context().mode(execution):
+                for perm in entries:
+                    data = arange((2, 3, 4))
+                    x = new_tensor(data)
+                    self.assertEqual(x.transpose(perm), data.transpose(perm))
+                    if perm is None:
+                        self.assertEqual(x.T, data.T)
+
     def test_truncated_normal(self):
         for execution in ('EAGER_MODE', 'GRAPH_MODE'):
             with execution_context().mode(execution):
......
@@ -666,6 +666,14 @@ class TestModules(OpTestCase):
         if m4 is not None:
             self.assertEqual(m4(x), np.pad(data, pads, 'constant'))
+    def test_pixel_shuffle(self):
+        data = np.ones((2, 12, 4, 4), dtype='float32')
+        x = new_tensor(data)
+        m1 = torch.nn.PixelShuffle(2)
+        m2 = torch.nn.PixelUnshuffle(2)
+        _, _ = repr(m1), repr(m2)
+        self.assertEqual(m2(m1(x)), data)
+
     def test_pool1d(self):
         entries = [((2, 2, 2,), (2,), 2, 1, 'MaxPool1d'),
                    ((2, 2, 2,), (2,), 2, 1, 'AvgPool1d'),
......
@@ -463,7 +463,7 @@ class TestTensorOps(OpTestCase):
         for a_shape, b_shape in test_shapes:
             data1, data2 = arange(a_shape), arange(b_shape, 1)
             a, b = new_tensor(data1, False), new_tensor(data2, False)
-            self.assertEqual(a.matmul(b), np.matmul(data1, data2))
+            self.assertEqual(a.__matmul__(b), np.matmul(data1, data2))
     def test_max(self):
         entries = [(0, True), (0, False),
@@ -570,8 +570,13 @@ class TestTensorOps(OpTestCase):
             x = new_tensor(data)
             if perm is None:
                 self.assertEqual(x.permute(), np.transpose(data))
+                self.assertEqual(x.T, data.T)
+                x.permute_()
+                self.assertEqual(x, np.transpose(data))
             else:
                 self.assertEqual(x.permute(*perm), np.transpose(data, perm))
+                x.permute_(*perm)
+                self.assertEqual(x, np.transpose(data, perm))
         entries = [(0, 1), (0, 2), (1, 2)]
         for dim0, dim1 in entries:
             data = arange((2, 3, 4))
@@ -579,6 +584,8 @@ class TestTensorOps(OpTestCase):
             perm = list(range(len(data.shape)))
             perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
             self.assertEqual(x.transpose(dim0, dim1), np.transpose(data, perm))
+            x.transpose_(dim0, dim1)
+            self.assertEqual(x, np.transpose(data, perm))
     def test_pow(self):
         for a_shape, b_shape in self.binary_test_shapes:
......
@@ -81,6 +81,8 @@ from dragon.vm.torch.core.nn.modules.padding import ReplicationPad1d
 from dragon.vm.torch.core.nn.modules.padding import ReplicationPad2d
 from dragon.vm.torch.core.nn.modules.padding import ReplicationPad3d
 from dragon.vm.torch.core.nn.modules.padding import ZeroPad2d
+from dragon.vm.torch.core.nn.modules.pixelshuffle import PixelShuffle
+from dragon.vm.torch.core.nn.modules.pixelshuffle import PixelUnshuffle
 from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool1d
 from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool2d
 from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool3d
......
@@ -60,6 +60,8 @@ from dragon.vm.torch.core.nn.functional import multi_head_attention_forward
 from dragon.vm.torch.core.nn.functional import nll_loss
 from dragon.vm.torch.core.nn.functional import normalize
 from dragon.vm.torch.core.nn.functional import pad
+from dragon.vm.torch.core.nn.functional import pixel_shuffle
+from dragon.vm.torch.core.nn.functional import pixel_unshuffle
 from dragon.vm.torch.core.nn.functional import prelu
 from dragon.vm.torch.core.nn.functional import relu
 from dragon.vm.torch.core.nn.functional import relu6
......
@@ -1746,6 +1746,54 @@ def pad(input, pad, mode='constant', value=0):
         ndim=ndim, pads=pads_begin + pads_end)
+def pixel_shuffle(input, upscale_factor, inplace=False):
+    """Rearrange depth elements of input into pixels.
+    Parameters
+    ----------
+    input : dragon.vm.torch.Tensor
+        The input tensor.
+    upscale_factor : int
+        The factor to upscale pixels.
+    inplace : bool, optional, default=False
+        Whether to do the operation in-place.
+    Returns
+    -------
+    dragon.vm.torch.Tensor
+        The output tensor.
+    """
+    return FunctionLib.apply(
+        'DepthToSpace', input.device, [input],
+        outputs=[input if inplace else None],
+        block_size=int(upscale_factor), mode='CRD', data_format='NCHW')
+
+def pixel_unshuffle(input, downscale_factor, inplace=False):
+    """Rearrange pixels of input into depth elements.
+    Parameters
+    ----------
+    input : dragon.vm.torch.Tensor
+        The input tensor.
+    downscale_factor : int
+        The factor to downscale pixels.
+    inplace : bool, optional, default=False
+        Whether to do the operation in-place.
+    Returns
+    -------
+    dragon.vm.torch.Tensor
+        The output tensor.
+    """
+    return FunctionLib.apply(
+        'SpaceToDepth', input.device, [input],
+        outputs=[input if inplace else None],
+        block_size=int(downscale_factor), mode='CRD', data_format='NCHW')
+
 def prelu(input, weight):
     r"""Apply parametric rectified linear unit to input.
     `[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_.
......
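A short round-trip sketch of the two new functional ops (mirroring the module test added earlier in this commit):

```python
from dragon.vm import torch
from dragon.vm.torch.nn import functional as F

x = torch.ones(2, 12, 4, 4)
y = F.pixel_shuffle(x, 2)    # (2, 3, 8, 8)
z = F.pixel_unshuffle(y, 2)  # back to (2, 12, 4, 4)
print(y.size(), z.size())
```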
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""Shuffle modules."""
+"""Channel shuffle modules."""
 from __future__ import absolute_import
 from __future__ import division
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#     <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Pixel shuffle modules."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from dragon.vm.torch.core.nn import functional as F
from dragon.vm.torch.core.nn.modules.module import Module


class PixelShuffle(Module):
    """Rearrange depth elements into pixels.

    Examples:

    ```python
    m = torch.nn.PixelShuffle(2)
    x = torch.arange(1 * 12 * 1 * 1).reshape((1, 12, 1, 1))
    print(m(x).size())  # [1, 3, 2, 2]
    ```

    See Also
    --------
    `torch.nn.functional.pixel_shuffle(...)`_

    """

    def __init__(self, upscale_factor, inplace=False):
        """Create a ``PixelShuffle`` module.

        Parameters
        ----------
        upscale_factor : int
            The factor to upscale pixels.
        inplace : bool, optional, default=False
            Whether to do the operation in-place.

        """
        super(PixelShuffle, self).__init__()
        self.upscale_factor = upscale_factor
        self.inplace = inplace

    def extra_repr(self):
        inplace_str = ', inplace' if self.inplace else ''
        return 'upscale_factor={}{}'.format(self.upscale_factor, inplace_str)

    def forward(self, input):
        return F.pixel_shuffle(input, self.upscale_factor, self.inplace)


class PixelUnshuffle(Module):
    """Rearrange pixels into depth elements.

    Examples:

    ```python
    m = torch.nn.PixelUnshuffle(2)
    x = torch.arange(1 * 3 * 2 * 2).reshape((1, 3, 2, 2))
    print(m(x).size())  # [1, 12, 1, 1]
    ```

    See Also
    --------
    `torch.nn.functional.pixel_unshuffle(...)`_

    """

    def __init__(self, downscale_factor, inplace=False):
        """Create a ``PixelUnshuffle`` module.

        Parameters
        ----------
        downscale_factor : int
            The factor to downscale pixels.
        inplace : bool, optional, default=False
            Whether to do the operation in-place.

        """
        super(PixelUnshuffle, self).__init__()
        self.downscale_factor = downscale_factor
        self.inplace = inplace

    def extra_repr(self):
        inplace_str = ', inplace' if self.inplace else ''
        return 'downscale_factor={}{}'.format(self.downscale_factor, inplace_str)

    def forward(self, input):
        return F.pixel_unshuffle(input, self.downscale_factor, self.inplace)
@@ -861,7 +861,7 @@ def one_hot(input, depth, on_value=1, off_value=0):
         on_value=float(on_value), off_value=float(off_value))
-def permute(input, dims):
+def permute(input, dims, out=None):
     """Return a tensor with the new order of dimensions.
     Parameters
@@ -870,6 +870,8 @@ def permute(input, dims):
         The input tensor.
     dims : Sequence[int]
         The output of dimensions.
+    out : dragon.vm.torch.Tensor, optional
+        The output tensor.
     Returns
     -------
@@ -878,7 +880,8 @@ def permute(input, dims):
     """
     return FunctionLib.apply(
-        'Transpose', input.device, [input], ndim=len(dims), perm=dims)
+        'Transpose', input.device, [input], outputs=[out],
+        ndim=len(dims), perm=dims)
 def tile(input, reps):
@@ -1338,7 +1341,7 @@ def topk(input, k, dim=-1, largest=True, sorted=True, out=None):
         k=k, axis=dim, largest=largest, sorted=sorted)
-def transpose(input, dim0, dim1):
+def transpose(input, dim0, dim1, out=None):
     """Return a new tensor with two dimensions swapped.
     Examples:
@@ -1356,6 +1359,8 @@ def transpose(input, dim0, dim1):
         The first dimension to be transposed.
     dim1 : int
         The second dimension to be transposed.
+    out : dragon.vm.torch.Tensor, optional
+        The output tensor.
     Returns
     -------
@@ -1366,7 +1371,8 @@ def transpose(input, dim0, dim1):
     dims = list(range(input.ndimension()))
     dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
     return FunctionLib.apply(
-        'Transpose', input.device, [input], ndim=len(dims), perm=dims)
+        'Transpose', input.device, [input], outputs=[out],
+        ndim=len(dims), perm=dims)
 def tril(input, diagonal=0, out=None):
......
@@ -1970,6 +1970,23 @@ def permute(self, *dims):
     return array_ops.permute(self, nest.flatten(dims))
+def permute_(self, *dims):
+    """Reorder the dimensions.
+    Parameters
+    ----------
+    dims : Union[Sequence[int], int...]
+        The new order of dimensions.
+    Returns
+    -------
+    dragon.vm.torch.Tensor
+        The output tensor.
+    """
+    return array_ops.permute(self, nest.flatten(dims), self)
+
 def pow(self, exponent):
     r"""Compute the power.
@@ -2623,6 +2640,29 @@ def transpose(self, dim0, dim1):
     return array_ops.transpose(self, dim0, dim1)
+def transpose_(self, dim0, dim1):
+    """Swap two dimensions.
+    Parameters
+    ----------
+    dim0 : int
+        The first dimension to be transposed.
+    dim1 : int
+        The second dimension to be transposed.
+    Returns
+    -------
+    dragon.vm.torch.Tensor
+        The output tensor.
+    See Also
+    --------
+    `torch.transpose(...)`_
+    """
+    return array_ops.transpose(self, dim0, dim1, self)
+
 def tril(self, k=0):
     r"""Return the lower triangular part.
@@ -2738,12 +2778,12 @@ def triu_(self, k=0):
 def _type(self, dtype=None):
     """Return the data type.
-    If attr:`dtype` is not ``None``, converts to a new tensor.
+    If ``dtype`` is not ``None``, converts to a new tensor.
     Parameters
     ----------
     dtype : str, optional
-        The specified type.
+        The data type to convert to.
     Returns
     -------
@@ -3008,6 +3048,7 @@ Tensor.new_tensor = new_tensor
 Tensor.nonzero = nonzero
 Tensor.normal_ = normal_
 Tensor.permute = permute
+Tensor.permute_ = permute_
 Tensor.pow = pow
 Tensor.reciprocal = reciprocal
 Tensor.reciprocal_ = reciprocal_
@@ -3037,6 +3078,7 @@ Tensor.sub = sub
 Tensor.sub_ = sub_
 Tensor.topk = topk
 Tensor.transpose = transpose
+Tensor.transpose_ = transpose_
 Tensor.tril = tril
 Tensor.tril_ = tril_
 Tensor.triu = triu
......
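And a usage sketch of the in-place tensor methods registered above; each writes the permuted result back into `self`:

```python
from dragon.vm import torch

x = torch.ones(2, 3, 4)
x.permute_(2, 0, 1)   # x is now (4, 2, 3)
x.transpose_(0, 1)    # swap dims 0 and 1 -> (2, 4, 3)
print(x.size())
```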
@@ -201,6 +201,18 @@ class Tensor(object):
         return self.size()
     @property
+    def T(self):
+        """Return a tensor with dimensions reversed.
+        Returns
+        -------
+        dragon.vm.torch.Tensor
+            The output tensor.
+        """
+        return self.permute()
+
+    @property
     def volatile(self):
         """Return whether this tensor is volatile.
@@ -2130,6 +2142,21 @@ class Tensor(object):
         """
+    def permute_(self, *dims):
+        """Reorder the dimensions.
+        Parameters
+        ----------
+        dims : Union[Sequence[int], int...]
+            The new order of dimensions.
+        Returns
+        -------
+        dragon.vm.torch.Tensor
+            The output tensor.
+        """
+
     def pow(self, exponent):
         """Compute the power.
@@ -2776,6 +2803,27 @@ class Tensor(object):
         """
+    def transpose_(self, dim0, dim1):
+        """Swap two dimensions.
+        Parameters
+        ----------
+        dim0 : int
+            The first dimension to be transposed.
+        dim1 : int
+            The second dimension to be transposed.
+        Returns
+        -------
+        dragon.vm.torch.Tensor
+            The output tensor.
+        See Also
+        --------
+        `torch.transpose(...)`_
+        """
+
     def tril(self, k=0):
         r"""Return the lower triangular part.
@@ -2881,17 +2929,19 @@ class Tensor(object):
         """
     def type(self, dtype=None):
-        """Return the data type or copied tensor with specified type.
+        """Return the data type.
+        If ``dtype`` is not ``None``, converts to a new tensor.
         Parameters
         ----------
         dtype : str, optional
-            The specified type to convert to.
+            The data type to convert to.
         Returns
         -------
         Union[str, dragon.vm.torch.Tensor]
-            The data type or copied tensor.
+            The data type or new tensor.
         """
@@ -3365,6 +3415,28 @@ class Tensor(object):
         """
         return self.lt(other)
+    def __matmul__(self, other):
+        r"""Compute the matrix multiplication.
+        .. math:: \text{out} = \text{self} \times \text{tensor2}
+        Parameters
+        ----------
+        other : dragon.vm.torch.Tensor
+            The tensor to multiply.
+        Returns
+        -------
+        dragon.vm.torch.Tensor
+            The output tensor.
+        See Also
+        --------
+        `torch.matmul(...)`_
+        """
+        return self.matmul(other)
+
     def __mul__(self, other):
         """Compute the element-wise multiplication.