Commit 936c351b by Ting PAN

Enhance transpose operators

Summary:
This commit allows transpose to compute in-place by leveraging a scratch buffer.
We also add a CRD mode for the space-depth transpose (i.e., pixel shuffle).
1 parent ac051717
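A quick sketch of the two additions (eager mode; the `dragon.range` and `reshape` usage mirrors the examples further down this diff):

```python
import dragon

# In-place transpose: copy=False writes the result back into the input,
# routing through the new scratch-buffer path in TransposeOp.
x = dragon.range(6).reshape((2, 3))
y = dragon.transpose(x, perm=(1, 0), copy=False)

# CRD mode for the space-depth transpose, i.e. the pixel-shuffle layout.
z = dragon.range(16).reshape((1, 4, 2, 2))
w = dragon.nn.depth_to_space(z, block_size=2, mode='CRD')
```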
......@@ -39,9 +39,13 @@ shape
.. autoattribute:: dragon.Tensor.shape
size
#####
####
.. autoattribute:: dragon.Tensor.size
T
#
.. autoattribute:: dragon.Tensor.T
Methods
-------
......@@ -85,6 +89,10 @@ tolist
######
.. automethod:: dragon.Tensor.tolist
transpose
#########
.. automethod:: dragon.Tensor.transpose
truncated_normal
################
.. automethod:: dragon.Tensor.truncated_normal
......@@ -168,6 +176,10 @@ __lt__
######
.. automethod:: dragon.Tensor.__lt__
__matmul__
##########
.. automethod:: dragon.Tensor.__matmul__
__mul__
#######
.. automethod:: dragon.Tensor.__mul__
......@@ -243,6 +255,7 @@ __xor__
.. _dragon.math.greater_equal(...): math/greater_equal.html
.. _dragon.math.less(...): math/less.html
.. _dragon.math.less_equal(...): math/less_equal.html
.. _dragon.math.matmul(...): math/matmul.html
.. _dragon.math.mul(...): math/mul.html
.. _dragon.math.negative(...): math/negative.html
.. _dragon.math.not_equal(...): math/not_equal.html
......@@ -253,6 +266,7 @@ __xor__
.. _dragon.random.truncated_normal(...): random/truncated_normal.html
.. _dragon.random.uniform(...): random/uniform.html
.. _dragon.reshape(...): reshape.html
.. _dragon.transpose(...): transpose.html
.. raw:: html
......
......@@ -38,6 +38,10 @@ shape
#####
.. autoattribute:: dragon.vm.torch.Tensor.shape
T
#
.. autoattribute:: dragon.vm.torch.Tensor.T
Methods
-------
......@@ -213,6 +217,10 @@ exp
###
.. automethod:: dragon.vm.torch.Tensor.exp
exp\_
#####
.. automethod:: dragon.vm.torch.Tensor.exp_
expand
######
.. automethod:: dragon.vm.torch.Tensor.expand
......@@ -317,6 +325,10 @@ log
###
.. automethod:: dragon.vm.torch.Tensor.log
log\_
#####
.. automethod:: dragon.vm.torch.Tensor.log_
logical_and
###########
.. automethod:: dragon.vm.torch.Tensor.logical_and
......@@ -461,6 +473,10 @@ permute
#######
.. automethod:: dragon.vm.torch.Tensor.permute
permute\_
#########
.. automethod:: dragon.vm.torch.Tensor.permute_
pow
###
.. automethod:: dragon.vm.torch.Tensor.pow
......@@ -593,6 +609,10 @@ transpose
#########
.. automethod:: dragon.vm.torch.Tensor.transpose
transpose\_
###########
.. automethod:: dragon.vm.torch.Tensor.transpose_
tril
####
.. automethod:: dragon.vm.torch.Tensor.tril
......
......@@ -197,6 +197,12 @@ vm.torch.nn
`class Parameter <nn/Parameter.html>`_
: A wrapped tensor considered to be a module parameter.
`class PixelShuffle <nn/PixelShuffle.html>`_
: Rearrange depth elements into pixels.
`class PixelUnshuffle <nn/PixelUnshuffle.html>`_
: Rearrange pixels into depth elements.
`class PReLU <nn/PReLU.html>`_
: Apply the parametric rectified linear unit.
`[He et al., 2015] <https://arxiv.org/abs/1502.01852>`_.
......@@ -354,6 +360,8 @@ vm.torch.nn
nn/MultiheadAttention
nn/NLLLoss
nn/Parameter
nn/PixelShuffle
nn/PixelUnshuffle
nn/PReLU
nn/ReflectionPad1d
nn/ReflectionPad2d
......
PixelShuffle
============
.. autoclass:: dragon.vm.torch.nn.PixelShuffle
__init__
--------
.. automethod:: dragon.vm.torch.nn.PixelShuffle.__init__
.. _torch.nn.functional.pixel_shuffle(...): functional/pixel_shuffle.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
PixelUnshuffle
==============
.. autoclass:: dragon.vm.torch.nn.PixelUnshuffle
__init__
--------
.. automethod:: dragon.vm.torch.nn.PixelUnshuffle.__init__
.. _torch.nn.functional.pixel_unshuffle(...): functional/pixel_unshuffle.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
......@@ -155,6 +155,12 @@ vm.torch.nn.functional
`pad(...) <functional/pad.html>`_
: Pad the input according to the given sizes.
`pixel_shuffle(...) <functional/pixel_shuffle.html>`_
: Rearrange depth elements of input into pixels.
`pixel_unshuffle(...) <functional/pixel_unshuffle.html>`_
: Rearrange pixels of input into depth elements.
`prelu(...) <functional/prelu.html>`_
: Apply the parametric rectified linear unit to input.
`[He et al., 2015] <https://arxiv.org/abs/1502.01852>`_.
......@@ -256,6 +262,8 @@ vm.torch.nn.functional
functional/nll_loss
functional/normalize
functional/pad
functional/pixel_shuffle
functional/pixel_unshuffle
functional/prelu
functional/relu
functional/relu6
......
pixel_shuffle
=============
.. autofunction:: dragon.vm.torch.nn.functional.pixel_shuffle
.. _torch.nn.PixelShuffle(...): ../PixelShuffle.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
pixel_unshuffle
===============
.. autofunction:: dragon.vm.torch.nn.functional.pixel_unshuffle
.. _torch.nn.PixelUnshuffle(...): ../PixelUnshuffle.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
......@@ -8,62 +8,25 @@ namespace kernels {
namespace {
template <typename T>
void _GroupNormFusedParams(
const int N,
const int G,
const int D,
const T* mu,
const T* rsig,
const T* gamma,
const T* beta,
T* scale,
T* bias) {
const int C = G * D;
ConstEigenArrayMap<T> gamma_arr(gamma, D, G);
ConstEigenArrayMap<T> beta_arr(beta, D, G);
for (int i = 0; i < N; ++i) {
EigenArrayMap<T> scale_arr(scale + i * C, D, G);
scale_arr = gamma_arr.rowwise() *
ConstEigenVectorArrayMap<T>(rsig + i * G, G).transpose();
EigenArrayMap<T>(bias + i * C, D, G) = beta_arr -
scale_arr.rowwise() *
ConstEigenVectorArrayMap<T>(mu + i * G, G).transpose();
}
}
template <typename T, typename AccT>
void _GroupNormNCHW(
const int N,
const int C,
const int S,
const T* x,
const AccT* scale,
const AccT* bias,
T* y) {
EigenArrayMap<T>(y, S, N * C) =
(ConstEigenArrayMap<T>(x, S, N * C).rowwise() *
ConstEigenVectorArrayMap<AccT>(scale, N * C).transpose())
.rowwise() +
ConstEigenVectorArrayMap<AccT>(bias, N * C).transpose();
}
template <typename T, typename AccT>
void _GroupNormNHWC(
const int N,
const int C,
const int S,
template <typename T, typename AccT, StorageOrder kOrder>
void _GroupNorm(
const std::array<int, 4>& dims,
const T* x,
const AccT* scale,
const AccT* bias,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const AccT* beta,
T* y) {
const int SC = S * C;
for (int i = 0; i < N; ++i) {
EigenArrayMap<T>(y + i * SC, C, S) =
(ConstEigenArrayMap<T>(x + i * SC, C, S).colwise() *
ConstEigenVectorArrayMap<AccT>(scale + i * C, C))
.colwise() +
ConstEigenVectorArrayMap<AccT>(bias + i * C, C);
const int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
const int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
const int NxGxDxS = dims[0] * dims[1] * dims[2] * dims[3];
std::array<int, 4> index = {0, 0, 0, 0};
for (int i = 0; i < NxGxDxS; ++i) {
const int ng = index[0] * dims[kGDim] + index[kGDim];
const int c = index[kGDim] * dims[kDDim] + index[kDDim];
AccT val = (convert::To<AccT>(x[i]) - mu[ng]) * rsig[ng];
y[i] = convert::To<T>(val * gamma[c] + beta[c]);
math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
}
}
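For reference, the fused CPU kernel above is equivalent to this NumPy sketch (`group_norm_ref` is a hypothetical helper, not part of the diff): each `(n, g)` group is normalized with its precomputed `mu`/`rsig`, then scaled by the per-channel `gamma`/`beta`, with the four logical dims ordered `{N, G, D, S}` for NCHW and `{N, S, G, D}` for NHWC.

```python
import numpy as np

def group_norm_ref(x, mu, rsig, gamma, beta, G, data_format='NCHW'):
    """NumPy reference for the fused _GroupNorm kernel."""
    N = x.shape[0]
    if data_format == 'NCHW':
        C, S = x.shape[1], int(np.prod(x.shape[2:]))
        x4 = x.reshape(N, G, C // G, S)  # dims = {N, G, D, S}
        y = (x4 - mu.reshape(N, G, 1, 1)) * rsig.reshape(N, G, 1, 1)
        y = y * gamma.reshape(1, G, C // G, 1) + beta.reshape(1, G, C // G, 1)
    else:  # NHWC
        C, S = x.shape[-1], int(np.prod(x.shape[1:-1]))
        x4 = x.reshape(N, S, G, C // G)  # dims = {N, S, G, D}
        y = (x4 - mu.reshape(N, 1, G, 1)) * rsig.reshape(N, 1, G, 1)
        y = y * gamma.reshape(1, 1, G, C // G) + beta.reshape(1, 1, G, C // G)
    return y.reshape(x.shape)
```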
......@@ -77,13 +40,13 @@ void _GroupNormInternalGrad(
AccT* db) {
const int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
const int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
const int NxGxKxS = dims[0] * dims[1] * dims[2] * dims[3];
const int NxGxDxS = dims[0] * dims[1] * dims[2] * dims[3];
std::array<int, 4> index = {0, 0, 0, 0};
for (int i = 0; i < NxGxKxS; ++i) {
const int mi = index[0] * dims[kGDim] + index[kGDim];
const int gi = index[kGDim] * dims[kDDim] + index[kDDim];
ds[mi] += gamma[gi] * dy[i] * x[i];
db[mi] += gamma[gi] * dy[i];
for (int i = 0; i < NxGxDxS; ++i) {
const int ng = index[0] * dims[kGDim] + index[kGDim];
const int c = index[kGDim] * dims[kDDim] + index[kDDim];
ds[ng] += gamma[c] * dy[i] * x[i];
db[ng] += gamma[c] * dy[i];
math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
}
}
......@@ -103,19 +66,19 @@ void _GroupNormGrad(
T* dx) {
const int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
const int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
const int NxGxKxS = dims[0] * dims[1] * dims[2] * dims[3];
const int NxGxDxS = dims[0] * dims[1] * dims[2] * dims[3];
const int S = kOrder == StorageOrder::NCHW ? dims[3] : dims[1];
const AccT denom = AccT(1) / static_cast<AccT>(dims[kDDim] * S);
std::array<int, 4> index = {0, 0, 0, 0};
for (int i = 0; i < NxGxKxS; ++i) {
const int mi = index[0] * dims[kGDim] + index[kGDim];
const int gi = index[kGDim] * dims[kDDim] + index[kDDim];
const AccT u = (db[mi] * mu[mi] - ds[mi]) * (x[i] - mu[mi]) *
math::utils::Cube(rsig[mi]);
const AccT v = db[mi] * rsig[mi];
dx[i] = gamma[gi] * dy[i] * rsig[mi] + (u - v) * denom;
dgamma[gi] += dy[i] * (x[i] - mu[mi]) * rsig[mi];
dbeta[gi] += dy[i];
for (int i = 0; i < NxGxDxS; ++i) {
const int ng = index[0] * dims[kGDim] + index[kGDim];
const int c = index[kGDim] * dims[kDDim] + index[kDDim];
const AccT u = (db[ng] * mu[ng] - ds[ng]) * (x[i] - mu[ng]) *
math::utils::Cube(rsig[ng]);
const AccT v = db[ng] * rsig[ng];
dx[i] = gamma[c] * dy[i] * rsig[ng] + (u - v) * denom;
dgamma[c] += dy[i] * (x[i] - mu[ng]) * rsig[ng];
dbeta[c] += dy[i];
math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
}
}
......@@ -125,25 +88,6 @@ void _GroupNormGrad(
/* ------------------- Launcher Separator ------------------- */
template <>
void GroupNorm<float16, float, CPUContext>(
const int N,
const int G,
const int D,
const int S,
const string& data_format,
const float16* x,
const float* mu,
const float* rsig,
const float* gamma,
const float* beta,
float* scale,
float* bias,
float16* y,
CPUContext* tx) {
CPU_FP16_NOT_SUPPORTED;
}
template <>
void GroupNormGrad<float16, float, CPUContext>(
const int N,
const int G,
......@@ -177,16 +121,14 @@ void GroupNormGrad<float16, float, CPUContext>(
const AccT* rsig, \
const AccT* gamma, \
const AccT* beta, \
AccT* scale, \
AccT* bias, \
T* y, \
CPUContext* ctx) { \
const int C = G * D; \
_GroupNormFusedParams(N, G, D, mu, rsig, gamma, beta, scale, bias); \
if (data_format == "NCHW") { \
_GroupNormNCHW(N, C, S, x, scale, bias, y); \
_GroupNorm<T, AccT, StorageOrder::NCHW>( \
{N, G, D, S}, x, mu, rsig, gamma, beta, y); \
} else if (data_format == "NHWC") { \
_GroupNormNHWC(N, C, S, x, scale, bias, y); \
_GroupNorm<T, AccT, StorageOrder::NHWC>( \
{N, S, G, D}, x, mu, rsig, gamma, beta, y); \
} \
}
......@@ -226,6 +168,7 @@ void GroupNormGrad<float16, float, CPUContext>(
} \
}
DEFINE_KERNEL_LAUNCHER(float16, float);
DEFINE_KERNEL_LAUNCHER(float, float);
DEFINE_KERNEL_LAUNCHER(double, double);
DEFINE_GRAD_KERNEL_LAUNCHER(float, float);
......
......@@ -14,40 +14,27 @@ namespace {
#define LDG(x, i) __ldg(x + i)
#define LDG2(x, i) convert::To<AccT>(__ldg(x + i))
template <typename T>
__global__ void _GroupNormFusedParams(
const int NxC,
const int C,
const int D,
const T* mu,
const T* rsig,
const T* gamma,
const T* beta,
T* scale,
T* bias) {
CUDA_1D_KERNEL_LOOP(i, NxC) {
const int c = i % C;
const int ng = i / D;
const T scale_val = LDG(gamma, c) * LDG(rsig, ng);
scale[i] = scale_val;
bias[i] = fma(-scale_val, LDG(mu, ng), LDG(beta, c));
}
}
template <typename T, typename AccT, StorageOrder kOrder>
__global__ void _GroupNormAffine(
__global__ void _GroupNorm(
const int NxCxS,
const int C,
const int G,
const int D,
const int S,
const T* x,
const AccT* scale,
const AccT* bias,
const AccT* mu,
const AccT* rsig,
const AccT* gamma,
const AccT* beta,
T* y) {
const int C = G * D;
CUDA_1D_KERNEL_LOOP(i, NxCxS) {
const int nc =
kOrder == StorageOrder::NCHW ? i / S : i / (C * S) * C + i % C;
const int ng = kOrder == StorageOrder::NCHW ? i / (D * S)
: i / (C * S) * G + (i / D % G);
const int c = kOrder == StorageOrder::NCHW ? i / S % C : i % C;
y[i] = convert::To<T>(
fma(convert::To<AccT>(x[i]), LDG(scale, nc), LDG(bias, nc)));
fma((convert::To<AccT>(x[i]) - __ldg(mu + ng)) * __ldg(rsig + ng),
__ldg(gamma + c),
__ldg(beta + c)));
}
}
......@@ -179,30 +166,24 @@ __global__ void _GroupNormGrad(
const AccT* rsig, \
const AccT* gamma, \
const AccT* beta, \
AccT* scale, \
AccT* bias, \
T* y, \
CUDAContext* ctx) { \
const auto C = G * D; \
const auto NxC = N * C; \
const auto NxCxS = NxC * S; \
_GroupNormFusedParams<<< \
CUDA_BLOCKS(NxC), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(NxC, C, D, mu, rsig, gamma, beta, scale, bias); \
const auto NxCxS = N * G * D * S; \
DISPATCH_GROUPNORM_KERNEL( \
_GroupNormAffine, \
_GroupNorm, \
math::ScalarType<T>::type, \
AccT, \
CUDA_BLOCKS(NxCxS), \
CUDA_THREADS, \
NxCxS, \
C, \
G, \
D, \
S, \
reinterpret_cast<const math::ScalarType<T>::type*>(x), \
scale, \
bias, \
mu, \
rsig, \
gamma, \
beta, \
reinterpret_cast<math::ScalarType<T>::type*>(y)); \
}
......
......@@ -25,21 +25,22 @@ class NumpyWrapper {
py::object To(bool copy) {
const auto& meta = tensor_->meta();
const auto& typestr = ::dragon::dtypes::to_string(meta);
const auto& dtype = ::dragon::dtypes::to_string(meta);
CHECK_GT(tensor_->count(), 0) << "\nConvert an empty tensor.";
CHECK(typestr != "unknown") << "\nConvert an empty tensor.";
if (typestr == "string") {
CHECK(dtype != "unknown") << "\nConvert an empty tensor.";
if (dtype == "string") {
CHECK_EQ(tensor_->count(), 1);
return py::bytes(tensor_->data<string, CPUContext>()[0]);
}
auto typenum = dtypes::to_npy(meta);
vector<npy_intp> dims({tensor_->dims().begin(), tensor_->dims().end()});
if (copy) {
auto* memory = tensor_->memory();
CHECK(memory) << "\nConvert an empty tensor.";
auto device_type = memory ? memory->info()["device_type"] : "cpu";
auto* array = PyArray_SimpleNew(tensor_->ndim(), dims.data(), typenum);
auto* array =
PyArray_SimpleNew(dims.size(), dims.data(), dtypes::to_npy(meta));
if (device_type == "cuda") {
CUDADeviceGuard guard(memory->device());
CUDAContext::Memcpy<CPUContext, CUDAContext>(
tensor_->nbytes(),
PyArray_DATA(reinterpret_cast<PyArrayObject*>(array)),
......@@ -53,9 +54,11 @@ class NumpyWrapper {
}
return py::reinterpret_steal<py::object>(array);
}
auto* data = const_cast<void*>(tensor_->raw_data<CPUContext>());
auto* array = PyArray_SimpleNewFromData(
tensor_->ndim(), dims.data(), dtypes::to_npy(meta), data);
dims.size(),
dims.data(),
dtypes::to_npy(meta),
const_cast<void*>(tensor_->raw_data<CPUContext>()));
return py::reinterpret_steal<py::object>(array);
}
......@@ -71,6 +74,7 @@ class NumpyWrapper {
if (copy) {
auto device_type = memory ? memory->info()["device_type"] : "cpu";
if (device_type == "cuda") {
CUDADeviceGuard guard(memory->device());
CUDAContext::Memcpy<CUDAContext, CPUContext>(
tensor_->nbytes(),
tensor_->raw_mutable_data<CUDAContext>(),
......
#include "dragon/operators/array/transpose_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -7,7 +8,7 @@ namespace dragon {
template <class Context>
template <typename T>
void TransposeOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
auto &X = Input(0), *Y = Output(0, {0});
int num_axes, num_dims = X.ndim();
vec64_t X_strides(num_dims), Y_dims(num_dims);
......@@ -34,13 +35,25 @@ void TransposeOp<Context>::DoRunWithType() {
Y_dims[i] = X.dim(new_axes[i]);
}
auto* scratch = ((void*)&X == (void*)Y)
? ctx()->workspace()->template data<T, Context>({X.count()})[0]
: Y->Reshape(Y_dims)->template mutable_data<T, Context>();
kernels::Transpose(
num_dims,
X_strides.data(),
Y_dims.data(),
X.template data<T, Context>(),
scratch,
ctx());
if ((void*)&X == (void*)Y) {
math::Copy(
X.count(),
scratch,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
}
}
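A transpose cannot be computed elementwise over its own storage, so the in-place path above stages the result in a workspace scratch buffer and copies it back; the same pattern in NumPy terms (`transpose_inplace` is a hypothetical illustration):

```python
import numpy as np

def transpose_inplace(x, perm):
    """Stage the transpose in scratch, then copy back into x's storage."""
    scratch = np.transpose(x, perm).copy()  # kernels::Transpose -> scratch
    y = x.reshape(scratch.shape)            # Y aliases X, reshaped to Y_dims
    np.copyto(y, scratch)                   # math::Copy(scratch -> Y)
    return y
```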
DEPLOY_CPU_OPERATOR(Transpose);
......@@ -54,13 +67,17 @@ OPERATOR_SCHEMA(Transpose)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
.NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(TransposeGradient)
/* dY */
.NumInputs(1)
/* dX */
.NumOutputs(1);
.NumOutputs(1)
/* dY => dX */
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(Transpose, SimpleGradientMaker);
......
......@@ -37,8 +37,6 @@ void GroupNormOp<Context>::DoRunWithType() {
math::InvStd(N_ * G_, epsilon_, rsig, rsig, ctx());
// Fuse parameters to compute affine transformation.
auto* scratch =
ctx()->workspace()->template data<ParamT, Context>({2 * N_ * C_})[0];
kernels::GroupNorm(
N_,
G_,
......@@ -50,8 +48,6 @@ void GroupNormOp<Context>::DoRunWithType() {
rsig,
W.template data<ParamT, Context>(),
B.template data<ParamT, Context>(),
scratch,
scratch + N_ * C_,
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
......
#include "dragon/operators/vision/space_to_depth_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......@@ -6,7 +8,7 @@ namespace dragon {
template <class Context>
template <typename T>
void SpaceToDepthOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
auto &X = Input(0), *Y = Output(0, {0});
int start_axis, end_axis, perm_count = 0;
int num_dims = X.ndim(), num_axes = X.ndim() - 2;
......@@ -48,8 +50,8 @@ void SpaceToDepthOp<Context>::DoRunWithType() {
if (data_format() == "NCHW") {
for (int i = 0; i < num_axes; i++) {
perm.insert(perm.begin() + 1, perm.back());
perm.pop_back(); // DCR mode
perm.insert(perm.begin() + (mode_ == "DCR" ? 1 : 2), perm.back());
perm.pop_back();
}
}
......@@ -66,19 +68,31 @@ void SpaceToDepthOp<Context>::DoRunWithType() {
Y_dims[i] = X_reshape.dim(perm[i]);
}
auto* scratch = ((void*)&X == (void*)Y)
? ctx()->workspace()->template data<T, Context>({X.count()})[0]
: Y->Reshape(out_shape)->template mutable_data<T, Context>();
kernels::Transpose(
X_strides.size(),
X_strides.data(),
Y_dims.data(),
X.template data<T, Context>(),
scratch,
ctx());
if ((void*)&X == (void*)Y) {
math::Copy(
X.count(),
scratch,
Y->Reshape(out_shape)->template mutable_data<T, Context>(),
ctx());
}
}
template <class Context>
template <typename T>
void DepthToSpaceOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
auto &X = Input(0), *Y = Output(0, {0});
int start_axis, end_axis;
int num_dims = X.ndim(), num_axes = X.ndim() - 2;
......@@ -94,11 +108,11 @@ void DepthToSpaceOp<Context>::DoRunWithType() {
start_axis = 2, end_axis = num_dims;
out_shape[1] /= std::pow(block_size_, num_axes);
in_dims = out_shape;
perm[1] = num_axes + 1;
perm[1] = (mode_ == "DCR" ? num_axes + 1 : 1);
for (int i = 0; i < num_axes; i++) {
perm[i * 2 + 2] = num_axes + i + 2;
perm[i * 2 + 3] = i + 1;
in_dims.insert(in_dims.begin() + 1, block_size_);
perm[i * 2 + 3] = i + (mode_ == "DCR" ? 1 : 2);
in_dims.insert(in_dims.begin() + (mode_ == "DCR" ? 1 : 2), block_size_);
out_shape[start_axis + i] *= block_size_;
}
} else if (data_format() == "NHWC") {
......@@ -129,13 +143,25 @@ void DepthToSpaceOp<Context>::DoRunWithType() {
Y_dims[i] = X_reshape.dim(perm[i]);
}
auto* scratch = ((void*)&X == (void*)Y)
? ctx()->workspace()->template data<T, Context>({X.count()})[0]
: Y->Reshape(out_shape)->template mutable_data<T, Context>();
kernels::Transpose(
X_strides.size(),
X_strides.data(),
Y_dims.data(),
X.template data<T, Context>(),
scratch,
ctx());
if ((void*)&X == (void*)Y) {
math::Copy(
X.count(),
scratch,
Y->Reshape(out_shape)->template mutable_data<T, Context>(),
ctx());
}
}
DEPLOY_CPU_OPERATOR(SpaceToDepth);
......@@ -152,10 +178,16 @@ DEPLOY_CUDA_OPERATOR(DepthToSpace);
REGISTER_CUDA_OPERATOR(DepthToSpaceGradient, SpaceToDepthOp<CUDAContext>);
#endif
OPERATOR_SCHEMA(SpaceToDepth).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(SpaceToDepthGradient).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(DepthToSpace).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(DepthToSpaceGradient).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(SpaceToDepth).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SpaceToDepthGradient)
.NumInputs(1)
.NumOutputs(1)
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(DepthToSpace).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(DepthToSpaceGradient)
.NumInputs(1)
.NumOutputs(1)
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(SpaceToDepth, SimpleGradientMaker);
REGISTER_GRADIENT(DepthToSpace, SimpleGradientMaker);
......
......@@ -22,7 +22,8 @@ class SpaceToDepthOp final : public Operator<Context> {
public:
SpaceToDepthOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
block_size_(OP_SINGLE_ARG(int, "block_size", 2)) {}
block_size_(OP_SINGLE_ARG(int, "block_size", 2)),
mode_(OP_SINGLE_ARG(string, "mode", "DCR")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
......@@ -33,6 +34,7 @@ class SpaceToDepthOp final : public Operator<Context> {
void DoRunWithType();
protected:
string mode_;
int64_t block_size_;
};
......@@ -41,7 +43,8 @@ class DepthToSpaceOp final : public Operator<Context> {
public:
DepthToSpaceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
block_size_(OP_SINGLE_ARG(int, "block_size", 2)) {}
block_size_(OP_SINGLE_ARG(int, "block_size", 2)),
mode_(OP_SINGLE_ARG(string, "mode", "DCR")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
......@@ -52,6 +55,7 @@ class DepthToSpaceOp final : public Operator<Context> {
void DoRunWithType();
protected:
string mode_;
int64_t block_size_;
};
......
......@@ -159,6 +159,7 @@ def cum_reduce_args(**kwargs):
def depth_space_args(**kwargs):
return {
'block_size': kwargs.get('block_size', '2'),
'mode': kwargs.get('mode', 'DCR'),
'data_format': kwargs.get('data_format', 'NCHW'),
}
......
......@@ -180,6 +180,18 @@ class Tensor(types.TensorBase):
return float('inf')
return math_util.prod(self._shape)
@property
def T(self):
"""Return a tensor with axes reversed.
Returns
-------
dragon.Tensor
The output tensor.
"""
return self.transpose()
def astype(self, dtype, copy=True):
"""Convert the data type to a specific one.
......@@ -365,6 +377,27 @@ class Tensor(types.TensorBase):
"""
return self.numpy().tolist()
def transpose(self, perm=None, copy=True):
"""Return a tensor with permuted axes.
Parameters
----------
perm : Union[Sequence[int], dragon.Tensor], optional
The output permutation.
copy : bool, optional, default=True
Return a new tensor or transpose in-place.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`dragon.transpose(...)`_
"""
def truncated_normal(self, mean=0, std=1):
r"""Fill self from a truncated normal distribution.
......@@ -694,6 +727,25 @@ class Tensor(types.TensorBase):
"""
def __matmul__(self, other):
"""Compute the matrix multiplication.
Parameters
----------
other : dragon.Tensor
The tensor to multiply.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`dragon.math.matmul(...)`_
"""
def __mul__(self, other):
"""Compute the element-wise multiplication.
......
......@@ -1754,7 +1754,7 @@ def tile(inputs, repeats, **kwargs):
@OpSchema.num_inputs(1)
@OpSchema.convert_arg('perm')
def transpose(inputs, perm=None, **kwargs):
def transpose(inputs, perm=None, copy=True, **kwargs):
r"""Permute the dimensions of input.
Examples:
......@@ -1774,6 +1774,8 @@ def transpose(inputs, perm=None, **kwargs):
The input tensor.
perm : Union[Sequence[int], dragon.Tensor], optional
The output permutation.
copy : bool, optional, default=True
Return a new tensor or compute in-place.
Returns
-------
......@@ -1785,6 +1787,7 @@ def transpose(inputs, perm=None, **kwargs):
if context.executing_eagerly():
return OpLib.execute(
'Transpose', inputs,
outputs=[None] if copy else inputs,
ndim=len(args['perm']) if perm is not None else 0,
perm=args['perm'])
return OpLib.add('Transpose', **args)
......
......@@ -516,6 +516,27 @@ def lt(self, other):
return _apply_binary_op([self, other], 'Less')
def matmul(self, other):
"""Compute the matrix multiplication.
Parameters
----------
other : dragon.Tensor
The tensor to multiply.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`dragon.math.matmul(...)`_
"""
return _apply_binary_op([self, other], 'MatMul')
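A usage sketch for the new overload (eager mode; `dragon.ones` is assumed here just to build float inputs):

```python
import dragon

a = dragon.ones((2, 3), dtype='float32')
b = dragon.ones((3, 4), dtype='float32')
c = a @ b                     # dispatches to Tensor.__matmul__
d = dragon.math.matmul(a, b)  # the equivalent functional form
```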
def mul(self, other):
"""Compute the element-wise multiplication.
......@@ -844,6 +865,29 @@ def sub(self, other):
return _apply_binary_op([self, other], 'Sub')
def transpose(self, perm=None, copy=True):
"""Return a tensor with permuted axes.
Parameters
----------
perm : Union[Sequence[int], dragon.Tensor], optional
The output permutation.
copy : bool, optional, default=True
Return a new tensor or transpose in-place.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`dragon.transpose(...)`_
"""
return array_ops.transpose(self, perm=perm, copy=copy)
def truncated_normal(self, mean=0, std=1):
r"""Fill self from a truncated normal distribution.
......@@ -984,6 +1028,7 @@ Tensor.glorot_normal = glorot_normal
Tensor.glorot_uniform = glorot_uniform
Tensor.normal = normal
Tensor.reshape = reshape
Tensor.transpose = transpose
Tensor.truncated_normal = truncated_normal
Tensor.uniform = uniform
Tensor.__add__ = add
......@@ -1003,6 +1048,7 @@ Tensor.__itruediv__ = idiv
Tensor.__ixor__ = ixor
Tensor.__le__ = le
Tensor.__lt__ = lt
Tensor.__matmul__ = matmul
Tensor.__mul__ = mul
Tensor.__ne__ = ne
Tensor.__neg__ = neg
......
......@@ -831,7 +831,14 @@ def depthwise_conv2d(
@OpSchema.num_inputs(1)
def depth_to_space(inputs, block_size, data_format='NCHW', **kwargs):
def depth_to_space(
inputs,
block_size,
mode='DCR',
data_format='NCHW',
copy=True,
**kwargs
):
"""Rearrange depth data into spatial blocks.
Examples:
......@@ -851,8 +858,12 @@ def depth_to_space(inputs, block_size, data_format='NCHW', **kwargs):
The input tensor.
block_size : int, required
The size of spatial block.
mode : str, optional, default='DCR'
The rearrangement order for ``'NCHW'`` format, either ``'DCR'`` or ``'CRD'``.
data_format : str, optional, default='NCHW'
``'NCHW'`` or ``'NHWC'``.
copy : bool, optional, default=True
Return a new tensor or compute in-place.
Returns
-------
......@@ -865,9 +876,10 @@ def depth_to_space(inputs, block_size, data_format='NCHW', **kwargs):
if context.executing_eagerly():
return OpLib.execute(
'DepthToSpace', inputs,
block_size=block_size, data_format=data_format)
outputs=[None] if copy else inputs,
block_size=block_size, mode=mode.upper(), data_format=data_format)
return OpLib.add('DepthToSpace', inputs, block_size=block_size,
data_format=data_format, **kwargs)
mode=mode.upper(), data_format=data_format, **kwargs)
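The two modes differ only in where the block axes sit before flattening; a NumPy reference for the ``'NCHW'`` case (`depth_to_space_ref` is a hypothetical helper mirroring the permutations built in the C++ operator above):

```python
import numpy as np

def depth_to_space_ref(x, bs, mode='DCR'):
    """NumPy reference for NCHW depth_to_space in both modes."""
    n, c, h, w = x.shape
    co = c // (bs * bs)
    if mode == 'DCR':  # depth-column-row: block axes lead the channel axis
        y = x.reshape(n, bs, bs, co, h, w).transpose(0, 3, 4, 1, 5, 2)
    else:              # CRD: channel axis leads, the pixel-shuffle layout
        y = x.reshape(n, co, bs, bs, h, w).transpose(0, 1, 4, 2, 5, 3)
    return y.reshape(n, co, h * bs, w * bs)
```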
@OpSchema.num_inputs(1)
......@@ -1482,7 +1494,14 @@ def roi_pool(
@OpSchema.num_inputs(1)
def space_to_depth(inputs, block_size, data_format='NCHW', **kwargs):
def space_to_depth(
inputs,
block_size,
mode='DCR',
data_format='NCHW',
copy=True,
**kwargs
):
"""Rearrange blocks of spatial data into depth.
Examples:
......@@ -1492,7 +1511,7 @@ def space_to_depth(inputs, block_size, data_format='NCHW', **kwargs):
x = dragon.range(n * c * h * w).reshape((n, c, h, w))
y = dragon.reshape(x, (n, c, h // bs, bs, w // bs, bs))
y = dragon.transpose(y, (0, 3, 5, 1, 2, 4))
y = dragon.reshape(y, (n, c * (bs ** 2), h // bs, w // bs))
y = dragon.reshape(y, (n, (bs ** 2) * c, h // bs, w // bs))
z = dragon.nn.space_to_depth(x, 2) # Equivalent
```
......@@ -1502,8 +1521,12 @@ def space_to_depth(inputs, block_size, data_format='NCHW', **kwargs):
The input tensor.
block_size : int, required
The size of spatial block.
mode : str, optional, default='DCR'
The rearrangement order for ``'NCHW'`` format, either ``'DCR'`` or ``'CRD'``.
data_format : str, optional, default='NCHW'
``'NCHW'`` or ``'NHWC'``.
copy : bool, optional, default=True
Return a new tensor or compute in-place.
Returns
-------
......@@ -1516,9 +1539,10 @@ def space_to_depth(inputs, block_size, data_format='NCHW', **kwargs):
if context.executing_eagerly():
return OpLib.execute(
'SpaceToDepth', inputs,
block_size=block_size, data_format=data_format)
outputs=[None] if copy else inputs,
block_size=block_size, mode=mode.upper(), data_format=data_format)
return OpLib.add('SpaceToDepth', inputs, block_size=block_size,
data_format=data_format, **kwargs)
mode=mode.upper(), data_format=data_format, **kwargs)
def _normalize_tuple(value, rank):
......
......@@ -68,6 +68,9 @@ def depth_space_exporter(op_def, context):
_assert_data_format(arg)
if arg.name == 'block_size':
helper.add_attribute(node, 'blocksize', arg.i)
elif arg.name == 'mode':
if node.op_type != 'SpaceToDepth':
helper.add_attribute(node, 'mode', arg.s)
return node, const_tensors
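For context: as of opset 11, ONNX's ``DepthToSpace`` carries a ``mode`` attribute (``DCR``/``CRD``) while ``SpaceToDepth`` does not, which is why the exporter above only forwards ``mode`` for the former. A minimal sketch with the standard ``onnx`` helper:

```python
from onnx import helper

# 'mode' is legal on DepthToSpace only; SpaceToDepth takes just 'blocksize'.
node = helper.make_node(
    'DepthToSpace', inputs=['x'], outputs=['y'],
    blocksize=2, mode='CRD')
```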
......
......@@ -918,8 +918,6 @@ void GroupNorm(
const AccT* rsig,
const AccT* gamma,
const AccT* beta,
AccT* scale,
AccT* bias,
T* y,
Context* ctx);
......
......@@ -3101,6 +3101,14 @@ class TestTensorOps(OpTestCase):
x = new_tensor(data)
self.assertEqual(~x, ~data)
def test_matmul(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
data1, data2 = arange((2, 3)), arange((3, 4), 1)
a, b = new_tensor(data1), new_tensor(data2)
self.assertEqual(a.__matmul__(b), data1.__matmul__(data2))
def test_mul(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
......@@ -3195,6 +3203,17 @@ class TestTensorOps(OpTestCase):
a -= b
self.assertEqual(a, data1 - data2)
def test_transpose(self):
entries = [(0, 2, 1), None]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for perm in entries:
data = arange((2, 3, 4))
x = new_tensor(data)
self.assertEqual(x.transpose(perm), data.transpose(perm))
if perm is None:
self.assertEqual(x.T, data.T)
def test_truncated_normal(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
......
......@@ -666,6 +666,14 @@ class TestModules(OpTestCase):
if m4 is not None:
self.assertEqual(m4(x), np.pad(data, pads, 'constant'))
def test_pixel_shuffle(self):
data = np.ones((2, 12, 4, 4), dtype='float32')
x = new_tensor(data)
m1 = torch.nn.PixelShuffle(2)
m2 = torch.nn.PixelUnshuffle(2)
_, _ = repr(m1), repr(m2)
self.assertEqual(m2(m1(x)), data)
def test_pool1d(self):
entries = [((2, 2, 2,), (2,), 2, 1, 'MaxPool1d'),
((2, 2, 2,), (2,), 2, 1, 'AvgPool1d'),
......
......@@ -463,7 +463,7 @@ class TestTensorOps(OpTestCase):
for a_shape, b_shape in test_shapes:
data1, data2 = arange(a_shape), arange(b_shape, 1)
a, b = new_tensor(data1, False), new_tensor(data2, False)
self.assertEqual(a.matmul(b), np.matmul(data1, data2))
self.assertEqual(a.__matmul__(b), np.matmul(data1, data2))
def test_max(self):
entries = [(0, True), (0, False),
......@@ -570,8 +570,13 @@ class TestTensorOps(OpTestCase):
x = new_tensor(data)
if perm is None:
self.assertEqual(x.permute(), np.transpose(data))
self.assertEqual(x.T, data.T)
x.permute_()
self.assertEqual(x, np.transpose(data))
else:
self.assertEqual(x.permute(*perm), np.transpose(data, perm))
x.permute_(*perm)
self.assertEqual(x, np.transpose(data, perm))
entries = [(0, 1), (0, 2), (1, 2)]
for dim0, dim1 in entries:
data = arange((2, 3, 4))
......@@ -579,6 +584,8 @@ class TestTensorOps(OpTestCase):
perm = list(range(len(data.shape)))
perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
self.assertEqual(x.transpose(dim0, dim1), np.transpose(data, perm))
x.transpose_(dim0, dim1)
self.assertEqual(x, np.transpose(data, perm))
def test_pow(self):
for a_shape, b_shape in self.binary_test_shapes:
......
......@@ -81,6 +81,8 @@ from dragon.vm.torch.core.nn.modules.padding import ReplicationPad1d
from dragon.vm.torch.core.nn.modules.padding import ReplicationPad2d
from dragon.vm.torch.core.nn.modules.padding import ReplicationPad3d
from dragon.vm.torch.core.nn.modules.padding import ZeroPad2d
from dragon.vm.torch.core.nn.modules.pixelshuffle import PixelShuffle
from dragon.vm.torch.core.nn.modules.pixelshuffle import PixelUnshuffle
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool1d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool2d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool3d
......
......@@ -60,6 +60,8 @@ from dragon.vm.torch.core.nn.functional import multi_head_attention_forward
from dragon.vm.torch.core.nn.functional import nll_loss
from dragon.vm.torch.core.nn.functional import normalize
from dragon.vm.torch.core.nn.functional import pad
from dragon.vm.torch.core.nn.functional import pixel_shuffle
from dragon.vm.torch.core.nn.functional import pixel_unshuffle
from dragon.vm.torch.core.nn.functional import prelu
from dragon.vm.torch.core.nn.functional import relu
from dragon.vm.torch.core.nn.functional import relu6
......
......@@ -1746,6 +1746,54 @@ def pad(input, pad, mode='constant', value=0):
ndim=ndim, pads=pads_begin + pads_end)
def pixel_shuffle(input, upscale_factor, inplace=False):
"""Rearrange depth elements of input into pixels.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
upscale_factor : int
The factor to upscale pixels.
inplace : bool, optional, default=False
Whether to do the operation in-place.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return FunctionLib.apply(
'DepthToSpace', input.device, [input],
outputs=[input if inplace else None],
block_size=int(upscale_factor), mode='CRD', data_format='NCHW')
def pixel_unshuffle(input, downscale_factor, inplace=False):
"""Rearrange pixels of input into depth elements.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
downscale_factor : int
The factor to downscale pixels.
inplace : bool, optional, default=False
Whether to do the operation in-place.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return FunctionLib.apply(
'SpaceToDepth', input.device, [input],
outputs=[input if inplace else None],
block_size=int(downscale_factor), mode='CRD', data_format='NCHW')
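A round-trip sketch of the two functionals (assuming the public ``dragon.vm.torch.nn.functional`` path; shapes follow the module examples later in this diff):

```python
from dragon.vm import torch
from dragon.vm.torch.nn import functional as F

x = torch.arange(12).reshape((1, 12, 1, 1))
y = F.pixel_shuffle(x, 2)    # shape: (1, 3, 2, 2)
z = F.pixel_unshuffle(y, 2)  # round-trips to (1, 12, 1, 1)
```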
def prelu(input, weight):
r"""Apply parametric rectified linear unit to input.
`[He et al., 2015] <https://arxiv.org/abs/1502.01852>`_.
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Shuffle modules."""
"""Channel shuffle modules."""
from __future__ import absolute_import
from __future__ import division
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Pixel shuffle modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch.core.nn import functional as F
from dragon.vm.torch.core.nn.modules.module import Module
class PixelShuffle(Module):
"""Rearrange depth elements into pixels.
Examples:
```python
m = torch.nn.PixelShuffle(2)
x = torch.arange(1 * 12 * 1 * 1).reshape((1, 12, 1, 1))
print(m(x).size()) # [1, 3, 2, 2]
```
See Also
--------
`torch.nn.functional.pixel_shuffle(...)`_
"""
def __init__(self, upscale_factor, inplace=False):
"""Create a ``PixelShuffle`` module.
Parameters
----------
upscale_factor : int
The factor to upscale pixels.
inplace : bool, optional, default=False
Whether to do the operation in-place.
"""
super(PixelShuffle, self).__init__()
self.upscale_factor = upscale_factor
self.inplace = inplace
def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
return 'upscale_factor={}{}'.format(self.upscale_factor, inplace_str)
def forward(self, input):
return F.pixel_shuffle(input, self.upscale_factor, self.inplace)
class PixelUnshuffle(Module):
"""Rearrange pixels into depth elements.
Examples:
```python
m = torch.nn.PixelUnshuffle(2)
x = torch.arange(1 * 3 * 2 * 2).reshape((1, 3, 2, 2))
print(m(x).size()) # [1, 12, 1, 1]
```
See Also
--------
`torch.nn.functional.pixel_unshuffle(...)`_
"""
def __init__(self, downscale_factor, inplace=False):
"""Create a ``PixelUnshuffle`` module.
Parameters
----------
downscale_factor : int
The factor to downscale pixels.
inplace : bool, optional, default=False
Whether to do the operation in-place.
"""
super(PixelUnshuffle, self).__init__()
self.downscale_factor = downscale_factor
self.inplace = inplace
def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
return 'downscale_factor={}{}'.format(self.downscale_factor, inplace_str)
def forward(self, input):
return F.pixel_unshuffle(input, self.downscale_factor, self.inplace)
......@@ -861,7 +861,7 @@ def one_hot(input, depth, on_value=1, off_value=0):
on_value=float(on_value), off_value=float(off_value))
def permute(input, dims):
def permute(input, dims, out=None):
"""Return a tensor with the new order of dimensions.
Parameters
......@@ -870,6 +870,8 @@ def permute(input, dims):
The input tensor.
dims : Sequence[int]
The new order of dimensions.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
......@@ -878,7 +880,8 @@ def permute(input, dims):
"""
return FunctionLib.apply(
'Transpose', input.device, [input], ndim=len(dims), perm=dims)
'Transpose', input.device, [input], outputs=[out],
ndim=len(dims), perm=dims)
def tile(input, reps):
......@@ -1338,7 +1341,7 @@ def topk(input, k, dim=-1, largest=True, sorted=True, out=None):
k=k, axis=dim, largest=largest, sorted=sorted)
def transpose(input, dim0, dim1):
def transpose(input, dim0, dim1, out=None):
"""Return a new tensor with two dimensions swapped.
Examples:
......@@ -1356,6 +1359,8 @@ def transpose(input, dim0, dim1):
The first dimension to be transposed.
dim1 : int
The second dimension to be transposed.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
......@@ -1366,7 +1371,8 @@ def transpose(input, dim0, dim1):
dims = list(range(input.ndimension()))
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
return FunctionLib.apply(
'Transpose', input.device, [input], ndim=len(dims), perm=dims)
'Transpose', input.device, [input], outputs=[out],
ndim=len(dims), perm=dims)
def tril(input, diagonal=0, out=None):
......
......@@ -1970,6 +1970,23 @@ def permute(self, *dims):
return array_ops.permute(self, nest.flatten(dims))
def permute_(self, *dims):
"""Reorder the dimensions.
Parameters
----------
dims : Union[Sequence[int], int...]
The new order of dimensions.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return array_ops.permute(self, nest.flatten(dims), self)
def pow(self, exponent):
r"""Compute the power.
......@@ -2623,6 +2640,29 @@ def transpose(self, dim0, dim1):
return array_ops.transpose(self, dim0, dim1)
def transpose_(self, dim0, dim1):
"""Swap two dimensions.
Parameters
----------
dim0 : int
The first dimension to be transposed.
dim1 : int
The second dimension to be transposed.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.transpose(...)`_
"""
return array_ops.transpose(self, dim0, dim1, self)
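A usage sketch for the new in-place variants:

```python
from dragon.vm import torch

x = torch.arange(24).reshape((2, 3, 4))
x.transpose_(0, 2)   # swap two dims in place -> shape (4, 3, 2)
x.permute_(2, 1, 0)  # reorder all dims in place -> shape (2, 3, 4)
```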
def tril(self, k=0):
r"""Return the lower triangular part.
......@@ -2738,12 +2778,12 @@ def triu_(self, k=0):
def _type(self, dtype=None):
"""Return the data type.
If attr:`dtype` is not ``None``, converts to a new tensor.
If ``dtype`` is not ``None``, converts to a new tensor.
Parameters
----------
dtype : str, optional
The specified type.
The data type to convert to.
Returns
-------
......@@ -3008,6 +3048,7 @@ Tensor.new_tensor = new_tensor
Tensor.nonzero = nonzero
Tensor.normal_ = normal_
Tensor.permute = permute
Tensor.permute_ = permute_
Tensor.pow = pow
Tensor.reciprocal = reciprocal
Tensor.reciprocal_ = reciprocal_
......@@ -3037,6 +3078,7 @@ Tensor.sub = sub
Tensor.sub_ = sub_
Tensor.topk = topk
Tensor.transpose = transpose
Tensor.transpose_ = transpose_
Tensor.tril = tril
Tensor.tril_ = tril_
Tensor.triu = triu
......
......@@ -201,6 +201,18 @@ class Tensor(object):
return self.size()
@property
def T(self):
"""Return a tensor with dimensions reversed.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return self.permute()
@property
def volatile(self):
"""Return whether this tensor is volatile.
......@@ -2130,6 +2142,21 @@ class Tensor(object):
"""
def permute_(self, *dims):
"""Reorder the dimensions.
Parameters
----------
dims : Union[Sequence[int], int...]
The new order of dimensions.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
def pow(self, exponent):
"""Compute the power.
......@@ -2776,6 +2803,27 @@ class Tensor(object):
"""
def transpose_(self, dim0, dim1):
"""Swap two dimensions.
Parameters
----------
dim0 : int
The first dimension to be transposed.
dim1 : int
The second dimension to be transposed.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.transpose(...)`_
"""
def tril(self, k=0):
r"""Return the lower triangular part.
......@@ -2881,17 +2929,19 @@ class Tensor(object):
"""
def type(self, dtype=None):
"""Return the data type or copied tensor with specified type.
"""Return the data type.
If ``dtype`` is not ``None``, converts to a new tensor.
Parameters
----------
dtype : str, optional
The specified type to convert to.
The data type to convert to.
Returns
-------
Union[str, dragon.vm.torch.Tensor]
The data type or copied tensor.
The data type or new tensor.
"""
......@@ -3365,6 +3415,28 @@ class Tensor(object):
"""
return self.lt(other)
def __matmul__(self, other):
r"""Compute the matrix multiplication.
.. math:: \text{out} = \text{self} \times \text{other}
Parameters
----------
other : dragon.vm.torch.Tensor
The tensor to multiply.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.matmul(...)`_
"""
return self.matmul(other)
def __mul__(self, other):
"""Compute the element-wise multiplication.
......