Commit 6bfe3e73 by Ting PAN

Reimplement the general matrix multiplication

Summary:
This commit generalizes the fully-connected operation into GEMM,
and enhances the matmul operation via batched Dot, GEMV and GEMM.
New representations and attributes have been consistent with ONNX.
1 parent 73ed1b96
Showing with 3010 additions and 1337 deletions
......@@ -313,8 +313,8 @@ class InnerProduct(Layer):
param = layer_param.inner_product_param
self.arguments = {
'axis': param.axis,
'out_channels': param.num_output,
'transpose_w': not param.transpose,
'n': param.num_output,
'transpose_b': not param.transpose,
}
self.add_blob(filler=self.get_filler(param, 'weight_filler'))
if param.bias_term:
......@@ -322,7 +322,7 @@ class InnerProduct(Layer):
def __call__(self, bottom):
inputs = [bottom] + [blob['data'] for blob in self._blobs]
return math_ops.fully_connected(inputs, **self.arguments)
return math_ops.gemm(inputs, **self.arguments)
class Input(Layer):
......@@ -409,7 +409,7 @@ class Normalize(Layer):
def __call__(self, bottom):
norm_out = [normalization_ops.lp_normalize(bottom, **self.l2norm_arguments)]
norm_out += [blob['data'] for blob in self._blobs]
return math_ops.affine(norm_out, **self.affine_arguments)
return array_ops.channel_affine(norm_out, **self.affine_arguments)
class Permute(Layer):
......@@ -583,7 +583,7 @@ class Scale(Layer):
def __call__(self, bottom):
inputs = [bottom] + [blob['data'] for blob in self._blobs]
return math_ops.affine(inputs, **self.arguments)
return array_ops.channel_affine(inputs, **self.arguments)
class Slice(Layer):
......
......@@ -48,6 +48,9 @@ dragon.math
`floor(...) <math/floor.html>`_
: Compute the largest integer not greater than input.
`gemm(...) <math/gemm.html>`_
: Compute the general matrix multiplication.
`greater(...) <math/greater.html>`_
: Compute the element-wise greater comparison.
......@@ -158,6 +161,7 @@ dragon.math
math/equal
math/exp
math/floor
math/gemm
math/greater
math/greater_equal
math/is_inf
......
fully_connected
===============
gemm
====
.. autofunction:: dragon.nn.fully_connected
.. autofunction:: dragon.math.gemm
.. raw:: html
<style>
h1:before {
content: "dragon.nn.";
content: "dragon.math.";
color: #103d3e;
}
</style>
......@@ -74,9 +74,6 @@ dragon.nn
: Apply the exponential linear unit.
`[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
`fully_connected(...) <nn/fully_connected.html>`_
: Compute the dense matrix multiplication along the given axes.
`group_norm(...) <nn/group_norm.html>`_
: Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
......@@ -167,7 +164,6 @@ dragon.nn
nn/drop_block2d
nn/drop_path
nn/elu
nn/fully_connected
nn/group_norm
nn/hardsigmoid
nn/hardswish
......
......@@ -79,7 +79,7 @@ Name Supported Reference
`Gather`_ |v| :func:`dragon.index_select`
`GatherElements`_
`GatherND`_
`Gemm`_ |v| :func:`dragon.nn.fully_connected`
`Gemm`_ |v| :func:`dragon.math.gemm`
`GlobalAveragePool`_ |v| :func:`dragon.nn.pool2d`
`GlobalLpPool`_
`GlobalMaxPool`_ |v| :func:`dragon.nn.pool2d`
......
......@@ -36,6 +36,9 @@ vm.torch
`add(...) <torch/add.html>`_
: Compute the element-wise addition.
`addmm(...) <torch/addmm.html>`_
: Add input to the result of matrix-matrix multiplication.
`arange(...) <torch/arange.html>`_
: Return a tensor of evenly spaced values within a interval.
......@@ -51,12 +54,18 @@ vm.torch
`axpby(...) <torch/axpby.html>`_
: Compute the element-wise addition from input to output.
`baddbmm(...) <torch/baddbmm.html>`_
: Add input to the result of batched matrix-matrix multiplication.
`bitwise_not(...) <torch/bitwise_not.html>`_
: Compute the element-wise NOT bitwise operation.
`bitwise_xor(...) <torch/bitwise_xor.html>`_
: Compute the element-wise XOR bitwise operation.
`bmm(...) <torch/bmm.html>`_
: Compute the batched matrix-matrix multiplication.
`cat(...) <torch/cat.html>`_
: Concatenate the inputs along the given dimension.
......@@ -148,6 +157,9 @@ vm.torch
`masked_select(...) <torch/logsumexp.html>`_
: Select the input elements where mask is 1.
`matmul(...) <torch/matmul.html>`_
: Compute the matrix multiplication.
`max(...) <torch/max.html>`_
: Compute the max value of elements along the given dimension.
......@@ -281,13 +293,16 @@ vm.torch
torch/Tensor_
torch/abs
torch/add
torch/addmm
torch/arange
torch/argmax
torch/argmin
torch/argsort
torch/axpby
torch/baddbmm
torch/bitwise_not
torch/bitwise_xor
torch/bmm
torch/cat
torch/ceil
torch/channel_affine
......@@ -321,6 +336,7 @@ vm.torch
torch/logsumexp
torch/lt
torch/masked_select
torch/matmul
torch/max
torch/maximum
torch/mean
......
......@@ -53,6 +53,10 @@ add\_
#####
.. automethod:: dragon.vm.torch.Tensor.add_
addmm
#####
.. automethod:: dragon.vm.torch.Tensor.addmm
argmax
######
.. automethod:: dragon.vm.torch.Tensor.argmax
......@@ -69,6 +73,14 @@ backward
########
.. automethod:: dragon.vm.torch.Tensor.backward
baddbmm
#######
.. automethod:: dragon.vm.torch.Tensor.baddbmm
baddbmm\_
#########
.. automethod:: dragon.vm.torch.Tensor.baddbmm_
bitwise_not
###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_not
......@@ -85,6 +97,10 @@ bitwise_xor\_
#############
.. automethod:: dragon.vm.torch.Tensor.bitwise_xor_
bmm
###
.. automethod:: dragon.vm.torch.Tensor.bmm
bool
####
.. automethod:: dragon.vm.torch.Tensor.bool
......@@ -285,6 +301,14 @@ masked_fill\_
#############
.. automethod:: dragon.vm.torch.Tensor.masked_fill_
masked_select
#############
.. automethod:: dragon.vm.torch.Tensor.masked_select
matmul
######
.. automethod:: dragon.vm.torch.Tensor.matmul
max
###
.. automethod:: dragon.vm.torch.Tensor.max
......@@ -293,10 +317,6 @@ maximum
#######
.. automethod:: dragon.vm.torch.Tensor.maximum
masked_select
#############
.. automethod:: dragon.vm.torch.Tensor.masked_select
mean
####
.. automethod:: dragon.vm.torch.Tensor.mean
......@@ -535,11 +555,14 @@ zero\_
.. _torch.abs(...): abs.html
.. _torch.add(...): add.html
.. _torch.addmm(...): addmm.html
.. _torch.argmax(...): argmax.html
.. _torch.argmin(...): argmin.html
.. _torch.argsort(...): argsort.html
.. _torch.baddbmm(...): baddbmm.html
.. _torch.bitwise_not(...): bitwise_not.html
.. _torch.bitwise_xor(...): bitwise_xor.html
.. _torch.bmm(...): bmm.html
.. _torch.ceil(...): ceil.html
.. _torch.clamp(...): clamp.html
.. _torch.cos(...): cos.html
......@@ -557,6 +580,7 @@ zero\_
.. _torch.isnan(...): isnan.html
.. _torch.le(...): le.html
.. _torch.lt(...): lt.html
.. _torch.matmul(...): matmul.html
.. _torch.max(...): max.html
.. _torch.maximum(...): maximum.html
.. _torch.min(...): min.html
......
addmm
=====
.. autofunction:: dragon.vm.torch.addmm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
baddbmm
=======
.. autofunction:: dragon.vm.torch.baddbmm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
bmm
===
.. autofunction:: dragon.vm.torch.bmm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
matmul
======
.. autofunction:: dragon.vm.torch.matmul
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
......@@ -6,6 +6,24 @@ vm.torch.nn
Classes
-------
`class AdaptiveAvgPool1d <nn/AdaptiveAvgPool1d.html>`_
: Apply the 1d adaptive average pooling.
`class AdaptiveAvgPool2d <nn/AdaptiveAvgPool2d.html>`_
: Apply the 2d adaptive average pooling.
`class AdaptiveAvgPool3d <nn/AdaptiveAvgPool3d.html>`_
: Apply the 3d adaptive average pooling.
`class AdaptiveMaxPool1d <nn/AdaptiveMaxPool1d.html>`_
: Apply the 1d adaptive max pooling.
`class AdaptiveMaxPool2d <nn/AdaptiveMaxPool2d.html>`_
: Apply the 2d adaptive max pooling.
`class AdaptiveMaxPool3d <nn/AdaptiveMaxPool3d.html>`_
: Apply the 3d adaptive max pooling.
`class AffineChannel <nn/AffineChannel.html>`_
: Apply affine transformation along the channels.
......@@ -238,6 +256,12 @@ vm.torch.nn
.. toctree::
:hidden:
nn/AdaptiveAvgPool1d
nn/AdaptiveAvgPool2d
nn/AdaptiveAvgPool3d
nn/AdaptiveMaxPool1d
nn/AdaptiveMaxPool2d
nn/AdaptiveMaxPool3d
nn/AffineChannel
nn/AvgPool1d
nn/AvgPool2d
......
AdaptiveAvgPool1d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveAvgPool1d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveAvgPool1d.__init__
.. _torch.nn.functional.adaptive_avg_pool1d(...): functional/adaptive_avg_pool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveAvgPool2d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveAvgPool2d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveAvgPool2d.__init__
.. _torch.nn.functional.adaptive_avg_pool2d(...): functional/adaptive_avg_pool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveAvgPool3d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveAvgPool3d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveAvgPool3d.__init__
.. _torch.nn.functional.adaptive_avg_pool3d(...): functional/adaptive_avg_pool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveMaxPool1d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveMaxPool1d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveMaxPool1d.__init__
.. _torch.nn.functional.adaptive_max_pool1d(...): functional/adaptive_max_pool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveMaxPool2d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveMaxPool2d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveMaxPool2d.__init__
.. _torch.nn.functional.adaptive_max_pool2d(...): functional/adaptive_max_pool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveMaxPool3d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveMaxPool3d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveMaxPool3d.__init__
.. _torch.nn.functional.adaptive_max_pool3d(...): functional/adaptive_max_pool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
......@@ -6,6 +6,24 @@ vm.torch.nn.functional
Functions
---------
`adaptive_avg_pool1d(...) <functional/adaptive_avg_pool1d.html>`_
: Apply the 1d adaptive average pooling to input.
`adaptive_avg_pool2d(...) <functional/adaptive_avg_pool2d.html>`_
: Apply the 2d adaptive average pooling to input.
`adaptive_avg_pool3d(...) <functional/adaptive_avg_pool3d.html>`_
: Apply the 3d adaptive average pooling to input.
`adaptive_max_pool1d(...) <functional/adaptive_max_pool1d.html>`_
: Apply the 1d adaptive max pooling to input.
`adaptive_max_pool2d(...) <functional/adaptive_max_pool2d.html>`_
: Apply the 2d adaptive max pooling to input.
`adaptive_max_pool3d(...) <functional/adaptive_max_pool3d.html>`_
: Apply the 3d adaptive max pooling to input.
`avg_pool1d(...) <functional/avg_pool1d.html>`_
: Apply the 1d average pooling to input.
......@@ -167,6 +185,12 @@ vm.torch.nn.functional
.. toctree::
:hidden:
functional/adaptive_avg_pool1d
functional/adaptive_avg_pool2d
functional/adaptive_avg_pool3d
functional/adaptive_max_pool1d
functional/adaptive_max_pool2d
functional/adaptive_max_pool3d
functional/avg_pool1d
functional/avg_pool2d
functional/avg_pool3d
......
adaptive_avg_pool1d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_avg_pool1d
.. _torch.nn.AdaptiveAvgPool1d(...): ../AdaptiveAvgPool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_avg_pool2d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_avg_pool2d
.. _torch.nn.AdaptiveAvgPool2d(...): ../AdaptiveAvgPool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_avg_pool3d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_avg_pool3d
.. _torch.nn.AdaptiveAvgPool3d(...): ../AdaptiveAvgPool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_max_pool1d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_max_pool1d
.. _torch.nn.AdaptiveMaxPool1d(...): ../AdaptiveMaxPool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_max_pool2d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_max_pool2d
.. _torch.nn.AdaptiveMaxPool2d(...): ../AdaptiveMaxPool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_max_pool3d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_max_pool3d
.. _torch.nn.AdaptiveMaxPool3d(...): ../AdaptiveMaxPool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
......@@ -185,7 +185,6 @@ const Map<string, Map<string, string>>& ONNXBackend::get_node_renamed_attrs()
const {
const static Map<string, Map<string, string>> kPerNodeRenamedAttrs = {
{"DepthToSpace", {{"blocksize", "block_size"}}},
{"Gemm", {{"transB", "transW"}}},
{"RoiAlign",
{
{"output_height", "pooled_h"},
......
......@@ -180,19 +180,7 @@ ONNXImporterReturns ONNXBackend::GemmImporter(
ONNXNode* onnx_node,
const ConversionContext& ctx) {
auto& attributes = onnx_node->attributes;
auto alpha = attributes.get<float>("alpha", 1.f);
auto beta = attributes.get<float>("beta", 1.f);
auto trans_a = attributes.get<int64_t>("transA", 0L);
// Remove the unsupported attributes
if (alpha != 1.f || beta != 1.f) {
LOG(FATAL) << "alpha/beta can not be set currently.";
}
if (trans_a) {
LOG(FATAL) << "Tranposed A is not supported currently.";
}
attributes.remove("alpha");
attributes.remove("beta");
attributes.remove("transA");
attributes.AddRewrittenAttribute("axis")->set_i(-1);
return GenericImporter(onnx_node, ctx);
}
......
......@@ -98,7 +98,6 @@ DECLARE_ELEMENTWISE_OP(SignGradient);
DECLARE_ELEMENTWISE_OP(SinGradient);
DECLARE_ELEMENTWISE_OP(SqrtGradient);
DECLARE_ELEMENTWISE_OP(SquareGradient);
// Binary ElementwiseOp
DECLARE_ELEMENTWISE_OP(Add);
DECLARE_ELEMENTWISE_OP(Sub);
......@@ -122,7 +121,6 @@ DECLARE_ELEMENTWISE_OP(PowGradient);
DECLARE_ELEMENTWISE_OP(DotGradient);
DECLARE_ELEMENTWISE_OP(MinimumGradient);
DECLARE_ELEMENTWISE_OP(MaximumGradient);
#undef DECLARE_ELEMENTWISE_OP
} // namespace dragon
......
#include "dragon/operators/math/fully_connected_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/filler.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void FullyConnectedOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(X);
// Determine the number of output channels
int64_t M = X.count(0, axis), K = X.count(axis), N;
if (out_channels_ <= 0) {
// Infer the "N" from the weights shape
N = W.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer the N from "
<< "the weights shape: " << W.DimString();
} else {
// Use a fixed "N" from the argument
N = out_channels_;
}
vec64_t Y_dims(axis + 1);
for (int i = 0; i < axis + 1; i++) {
Y_dims[i] = i < axis ? X.dim(i) : N;
}
if (transW_ > 0) {
TENSOR_FILL(W, vec64_t({N, K}));
CHECK(W.ndim() == 2 && W.dim(1) == K)
<< "\nWeights dimensions should be [N, K].\n"
<< "Got X as (" << M << ", " << K << "), "
<< "and W as " << W.DimString();
} else {
TENSOR_FILL(W, vec64_t({K, N}));
CHECK(W.ndim() == 2 && W.dim(0) == K)
<< "\nWeights dimensions should be [K, N].\n"
<< "Got X as (" << M << ", " << K << "), "
<< "and W as " << W.DimString();
}
math::Gemm(
CblasNoTrans,
(CBLAS_TRANSPOSE)transW_,
M,
N,
K,
1.f,
X.template data<T, Context>(),
W.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
if (InputSize() > 2) {
TENSOR_FILL(Input(2), vec64_t({N}));
kernel::BiasAdd(
M,
1,
N,
Y->template data<T, Context>(),
Input(2).template data<T, Context>(),
Y->template mutable_data<T, Context>(),
ctx());
}
}
template <class Context>
void FullyConnectedOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void FullyConnectedGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
CANONICALIZE_AXIS_WITH_TENSOR(X);
// Determine the number of output channels
int64_t M = X.count(0, axis), K = X.count(axis), N;
if (out_channels_ <= 0) {
// Infer the "N" from the weights shape
N = W.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer the N from "
<< "the weights shape: " << W.DimString();
} else {
// Use a fixed "N" from the argument
N = out_channels_;
}
if (dX->has_name()) {
if (transW_) {
math::Gemm(
CblasNoTrans,
CblasNoTrans,
M,
K,
N,
1.f,
dY.template data<T, Context>(),
W.template data<T, Context>(),
0.f,
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
} else {
math::Gemm(
CblasNoTrans,
CblasTrans,
M,
K,
N,
1.f,
dY.template data<T, Context>(),
W.template data<T, Context>(),
0.f,
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
}
if (dW->has_name()) {
if (transW_) {
math::Gemm(
CblasTrans,
CblasNoTrans,
N,
K,
M,
1.f,
dY.template data<T, Context>(),
X.template data<T, Context>(),
0.f,
dW->ReshapeLike(W)->template mutable_data<T, Context>(),
ctx());
} else {
math::Gemm(
CblasTrans,
CblasNoTrans,
K,
N,
M,
1.f,
X.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dW->ReshapeLike(W)->template mutable_data<T, Context>(),
ctx());
}
}
if (dB->has_name()) {
vec32_t dims = {(int)M, (int)N}, axes = {0};
math::ReduceSum(
2,
dims.data(),
1,
axes.data(),
1.f,
dY.template data<T, Context>(),
dB->Reshape({N})->template mutable_data<T, Context>(),
ctx());
}
}
template <class Context>
void FullyConnectedGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(FullyConnected);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(FullyConnected);
#endif
DEPLOY_CPU_OPERATOR(FullyConnectedGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(FullyConnectedGradient);
#endif
OPERATOR_SCHEMA(FullyConnected)
/* X, W, B */
.NumInputs(2, 3)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(FullyConnectedGradient)
/* X, W, dY */
.NumInputs(3)
/* dX, dW, dB */
.NumOutputs(3);
namespace {
class GradientMaker : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
vector<OperatorDef> MakeDef() override {
return SingleDef(
def.type() + "Gradient",
"",
vector<string>({I(0), I(1), GO(0)}),
vector<string>({GI(0), GI(1), GI(2)}));
}
};
} // namespace
REGISTER_GRADIENT(FullyConnected, GradientMaker);
} // namespace dragon
#include "dragon/operators/math/gemm_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/filler.h"
#include "dragon/utils/math_functions.h"
namespace dragon {
template <class Context>
template <typename T>
void GemmOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), *Y = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(A);
// Check matrix A
auto M = transA_ ? A.count(axis) : A.count(0, axis);
auto K = transA_ ? A.count(0, axis) : A.count(axis);
// Check matrix B
auto N = n_; // Init "N" from the argument
if (N <= 0) {
// Infer "N" from the B shape
N = B.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer 'N' from "
<< "the B shape: " << B.DimString();
}
if (transB_ > 0) {
TENSOR_FILL(B, vec64_t({N, K}));
CHECK(B.ndim() == 2 && B.dim(1) == K)
<< "\nMatrixB's dimensions should be [N, K].\n"
<< "Got A as (" << M << ", " << K << "), "
<< "and B as " << B.DimString();
} else {
TENSOR_FILL(B, vec64_t({K, N}));
CHECK(B.ndim() == 2 && B.dim(0) == K)
<< "\nMatrixB's dimensions should be [K, N].\n"
<< "Got A as (" << M << ", " << K << "), "
<< "and B as " << B.DimString();
}
// Copy matrix C to Y if provided
vec64_t Y_dims(A.dims().begin(), A.dims().begin() + axis);
Y_dims.insert(transA_ ? Y_dims.begin() : Y_dims.end(), N);
if (InputSize() > 2) {
auto& C = Input(2);
if (C.ndim() == 0) {
TENSOR_FILL(C, vec64_t({N}));
}
if (math::utils::IsBinaryBroadcast(Y_dims, C.dims(), Y_dims)) {
math::Set(
C.ndim(),
C.dims().data(),
Y_dims.size(),
Y_dims.data(),
C.template data<T, Context>(),
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else {
LOG(FATAL) << "Could not broadcast together with shapes: "
<< Tensor::DimString(Y_dims) << " " << C.DimString();
}
}
math::Gemm(
(CBLAS_TRANSPOSE)transA_,
(CBLAS_TRANSPOSE)transB_,
M,
N,
K,
alpha_,
A.template data<T, Context>(),
B.template data<T, Context>(),
InputSize() > 2 ? beta_ : 0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void GemmOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void GemmGradientOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), &dY = Input(3);
auto *dA = Output(0), *dB = Output(1), *dC = Output(2);
CANONICALIZE_AXIS_WITH_TENSOR(A);
// Check matrix A
auto M = transA_ ? A.count(axis) : A.count(0, axis);
auto K = transA_ ? A.count(0, axis) : A.count(axis);
// Check matrix B
auto N = n_; // Init "N" from the argument
if (N <= 0) {
// Infer "N" from the B shape
N = B.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer 'N' from "
<< "the B shape: " << B.DimString();
}
if (dA->has_name()) {
if (transA_ > 0) {
math::Gemm(
transB_ ? CblasTrans : CblasNoTrans,
CblasTrans,
K,
M,
N,
alpha_,
B.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
} else {
math::Gemm(
CblasNoTrans,
transB_ ? CblasNoTrans : CblasTrans,
M,
K,
N,
alpha_,
dY.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
}
if (dB->has_name()) {
if (transB_) {
math::Gemm(
CblasTrans,
transA_ ? CblasTrans : CblasNoTrans,
N,
K,
M,
alpha_,
dY.template data<T, Context>(),
A.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
} else {
math::Gemm(
transA_ ? CblasNoTrans : CblasTrans,
CblasNoTrans,
K,
N,
M,
alpha_,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
}
if (dC->has_name()) {
auto& C = Input(2);
if (C.count() == dY.count()) {
math::Scale(
dY.count(),
beta_,
dY.template data<T, Context>(),
dC->ReshapeLike(C)->template mutable_data<T, Context>(),
ctx());
} else {
vec32_t Y_axes, C_axes;
math::utils::ComputeBinaryBroadcastAxes(
dY.dims(), C.dims(), dY.dims(), Y_axes, C_axes);
math::ReduceSum(
dY.ndim(),
vec32_t{dY.dims().begin(), dY.dims().end()}.data(),
C_axes.size(),
C_axes.data(),
beta_,
dY.template data<T, Context>(),
dC->ReshapeLike(C)->template mutable_data<T, Context>(),
ctx());
}
}
}
template <class Context>
void GemmGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Gemm);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Gemm);
#endif
DEPLOY_CPU_OPERATOR(GemmGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(GemmGradient);
#endif
OPERATOR_SCHEMA(Gemm)
/* A, B, C */
.NumInputs(2, 3)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(GemmGradient)
/* A, B, C, dY */
.NumInputs(4)
/* dA, dB, dC */
.NumOutputs(3);
namespace {
class GradientMaker : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
vector<OperatorDef> MakeDef() override {
return SingleDef(
def.type() + "Gradient",
"",
vector<string>({I(0), I(1), I(2), GO(0)}),
vector<string>({GI(0), GI(1), GI(2)}));
}
};
} // namespace
REGISTER_GRADIENT(Gemm, GradientMaker);
} // namespace dragon
......@@ -10,20 +10,23 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MATH_FULLY_CONNECTED_OP_H_
#define DRAGON_OPERATORS_MATH_FULLY_CONNECTED_OP_H_
#ifndef DRAGON_OPERATORS_MATH_GEMM_OP_H_
#define DRAGON_OPERATORS_MATH_GEMM_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class FullyConnectedOp final : public Operator<Context> {
class GemmOp final : public Operator<Context> {
public:
FullyConnectedOp(const OperatorDef& def, Workspace* ws)
GemmOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
out_channels_(OP_SINGLE_ARG(int64_t, "out_channels", 0)),
transW_(OP_SINGLE_ARG(int64_t, "transW", 1)) {}
n_(OP_SINGLE_ARG(int64_t, "n", 0)),
alpha_(OP_SINGLE_ARG(float, "alpha", 1.f)),
beta_(OP_SINGLE_ARG(float, "beta", 1.f)),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -32,16 +35,20 @@ class FullyConnectedOp final : public Operator<Context> {
void DoRunWithType();
protected:
int64_t out_channels_, transW_;
float alpha_, beta_;
int64_t n_, transA_, transB_;
};
template <class Context>
class FullyConnectedGradientOp final : public Operator<Context> {
class GemmGradientOp final : public Operator<Context> {
public:
FullyConnectedGradientOp(const OperatorDef& def, Workspace* ws)
GemmGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
out_channels_(OP_SINGLE_ARG(int64_t, "out_channels", 0)),
transW_(OP_SINGLE_ARG(int64_t, "transW", 1)) {}
n_(OP_SINGLE_ARG(int64_t, "n", 0)),
alpha_(OP_SINGLE_ARG(float, "alpha", 1.f)),
beta_(OP_SINGLE_ARG(float, "beta", 1.f)),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -50,9 +57,10 @@ class FullyConnectedGradientOp final : public Operator<Context> {
void DoRunWithType();
protected:
int64_t out_channels_, transW_;
float alpha_, beta_;
int64_t n_, transA_, transB_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MATH_FULLY_CONNECTED_OP_H_
#endif // DRAGON_OPERATORS_MATH_GEMM_OP_H_
#include "dragon/operators/math/matmul_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h"
namespace dragon {
......@@ -7,45 +8,205 @@ template <class Context>
template <typename T>
void MatMulOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), *Y = Output(0);
auto A_ndim = A.ndim(), B_ndim = B.ndim();
CHECK_GE(A.ndim(), 2) << "\nTensor(" << A.name() + ") must be a matrix"
<< "(or rank > 2, representing batches of matrices).";
CHECK_GE(B.ndim(), 2) << "\nTensor(" << B.name() + ") must be a matrix"
<< "(or rank > 2, representing batches of matrices).";
auto M1 = A.dim(-2), N1 = A.dim(-1);
auto M2 = B.dim(-2), N2 = B.dim(-1);
auto M = transA_ ? N1 : M1, N = transB_ ? M2 : N2;
auto K1 = transA_ ? M1 : N1, K2 = transB_ ? N2 : M2;
auto A_stride = M1 * N1, B_stride = M2 * N2, Y_stride = M * N;
auto batch_size = A.count() / A_stride;
if (A_ndim == 1 && B_ndim == 1) {
// Vector x Vector
CHECK_EQ(A.count(), B.count()) << "\nExcept equal length of two vectors.";
math::Dot(
A.count(),
A.template data<T, Context>(),
B.template data<T, Context>(),
Y->Reshape({})->template mutable_data<T, Context>(),
ctx());
return;
}
CHECK((K1 == K2) && (batch_size == (B.count() / B_stride)))
<< "\nTensor(" << A.name() << "): " << A.DimString()
<< " can not mul with Tensor"
<< "(" << B.name() << "): " << B.DimString();
if (A_ndim == 1) {
const auto N = A.count();
CHECK_EQ(B.dim(B_ndim - 2), N) << "\nExcept the second last dim of B is "
<< N << ", got " << B.dim(B_ndim - 2);
const auto M = B.dim(B_ndim - 1);
const auto batch_size = B.count() / (M * N);
vec64_t Y_dims(B.dims().begin(), B.dims().end() - 1);
Y_dims.back() = B.dims().back();
if (batch_size == 1) {
// Vector x Matrix
math::Gemv(
CblasTrans,
N,
M,
1.f,
B.template data<T, Context>(),
A.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else {
// Broadcasted Vector x Batched Matrix
math::GemmStridedBatched(
CblasTrans,
CblasNoTrans,
batch_size,
M,
1,
N,
M * N,
0,
M,
1.f,
B.template data<T, Context>(),
A.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
}
return;
}
if (B_ndim == 1) {
// Matrix x Vector
const auto N = B.count();
CHECK_EQ(A.dim(A_ndim - 1), N) << "\nExcept the last dim of A is " << N
<< ", got " << A.dim(A_ndim - 1);
const auto M = A.count() / N;
vec64_t Y_dims(A.dims());
Y_dims[Y_dims.size() - 2] = M;
Y_dims[Y_dims.size() - 1] = N;
Y->Reshape(Y_dims);
Y_dims.erase(Y_dims.end() - 1);
math::Gemv(
CblasNoTrans,
M,
N,
1.f,
A.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
return;
}
// Check matrix A
const auto M = A.dim(A_ndim - 2);
const auto K = A.dim(A_ndim - 1);
auto* a = A.template data<T, Context>();
auto* b = B.template data<T, Context>();
auto* y = Y->template mutable_data<T, Context>();
// Check matrix B
CHECK_EQ(B.dim(B_ndim - 2), K) << "\nExcept the second last dim of B is " << K
<< ", got " << B.dim(B_ndim - 2);
const auto N = B.dim(B_ndim - 1);
for (int i = 0; i < batch_size; ++i) {
// Check batching && broadcasting
vec64_t A_dims(A.dims().begin(), A.dims().end() - 2);
vec64_t B_dims(B.dims().begin(), B.dims().end() - 2);
vec64_t A_batch_dims, B_batch_dims, Y_dims;
if (math::utils::IsBinaryBroadcast(A_dims, B_dims, Y_dims)) {
math::utils::ComputeBinaryBroadcastDims(
A_dims, B_dims, A_batch_dims, B_batch_dims);
} else {
LOG(FATAL) << "Could not broadcast together with shapes " << A.DimString()
<< " " << B.DimString();
}
const int64_t batch_ndim = A_batch_dims.size();
const bool broadcasting = A_batch_dims != B_batch_dims;
Y_dims.push_back(M);
Y_dims.push_back(N);
const auto A_batch_size = std::accumulate(
A_batch_dims.begin(),
A_batch_dims.end(),
1LL,
std::multiplies<int64_t>());
const auto B_batch_size = std::accumulate(
B_batch_dims.begin(),
B_batch_dims.end(),
1LL,
std::multiplies<int64_t>());
const auto Y_batch_size = std::accumulate(
Y_dims.begin(),
Y_dims.begin() + batch_ndim,
1LL,
std::multiplies<int64_t>());
if (B_batch_size == 1) {
// Batched Matrix x Broadcasted Matrix
math::Gemm(
transA_ > 0 ? CblasTrans : CblasNoTrans,
transB_ > 0 ? CblasTrans : CblasNoTrans,
CblasNoTrans,
CblasNoTrans,
A_batch_size * M,
N,
K,
1.f,
A.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else if (A_batch_size == 1) {
// Broadcasted Matrix x Batched Matrix
math::GemmStridedBatched(
CblasNoTrans,
CblasNoTrans,
Y_batch_size,
M,
N,
K1,
K,
0,
K * N,
M * N,
1.f,
a + i * A_stride,
b + i * B_stride,
A.template data<T, Context>(),
B.template data<T, Context>(),
0.0f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else if (!broadcasting) {
// Batched Matrix x Batched Matrix
math::GemmStridedBatched(
CblasNoTrans,
CblasNoTrans,
Y_batch_size,
M,
N,
K,
M * K,
K * N,
M * N,
1.f,
A.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
y + i * Y_stride,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else {
// Broadcasted Matrix x Broadcasted Matrix
vector<const T*> A_arr(Y_batch_size);
vector<const T*> B_arr(Y_batch_size);
vector<T*> Y_arr(Y_batch_size);
vec64_t index(batch_ndim, 0);
auto* A_data = A.template data<T, Context>();
auto* B_data = B.template data<T, Context>();
auto* Y_data = Y->Reshape(Y_dims)->template mutable_data<T, Context>();
for (int64_t Y_i = 0; Y_i < Y_batch_size; ++Y_i) {
const auto A_i = math::utils::GetIndexFromDims(
batch_ndim, A_batch_dims.data(), index.data());
const auto B_i = math::utils::GetIndexFromDims(
batch_ndim, B_batch_dims.data(), index.data());
A_arr[Y_i] = A_data + A_i * M * K;
B_arr[Y_i] = B_data + B_i * K * N;
Y_arr[Y_i] = Y_data + Y_i * M * N;
math::utils::IncreaseIndexInDims(batch_ndim, Y_dims.data(), index.data());
}
math::GemmBatched(
CblasNoTrans,
CblasNoTrans,
Y_batch_size,
M,
N,
K,
1.f,
A_arr.data(),
B_arr.data(),
0.f,
Y_arr.data(),
ctx());
}
}
......@@ -60,95 +221,397 @@ template <typename T>
void MatMulGradientOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), &dY = Input(2);
auto *dA = Output(0), *dB = Output(1);
auto A_ndim = A.ndim(), B_ndim = B.ndim();
CHECK_GE(A.ndim(), 2) << "\nTensor(" << A.name() + ") must be a matrix"
<< "(or rank > 2, representing batches of matrices).";
CHECK_GE(B.ndim(), 2) << "\nTensor(" << B.name() + ") must be a matrix"
<< "(or rank > 2, representing batches of matrices).";
auto M1 = A.dim(-2), N1 = A.dim(-1);
auto M2 = B.dim(-2), N2 = B.dim(-1);
auto M = transA_ ? N1 : M1, N = transB_ ? M2 : N2;
auto K1 = transA_ ? M1 : N1, K2 = transB_ ? N2 : M2;
auto A_stride = M1 * N1, B_stride = M2 * N2, Y_stride = M * N;
auto batch_size = A.count() / A_stride;
if (A_ndim == 1 && B_ndim == 1) {
// Vector x Vector
if (dA->has_name()) {
math::Mul(
dY.ndim(),
dY.dims().data(),
B.ndim(),
B.dims().data(),
dY.template data<T, Context>(),
B.template data<T, Context>(),
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) {
math::Mul(
dY.ndim(),
dY.dims().data(),
A.ndim(),
A.dims().data(),
dY.template data<T, Context>(),
A.template data<T, Context>(),
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
return;
}
CHECK((K1 == K2) && (batch_size == (B.count() / B_stride)))
<< "\nTensor(" << A.name() << "): " << A.DimString()
<< " can not mul with Tensor"
<< "(" << B.name() << "): " << B.DimString();
if (A_ndim == 1) {
const auto N = A.count();
if (dA->has_name()) {
const auto M = B.dim(B_ndim - 1);
const auto batch_size = B.count() / (M * N);
if (batch_size == 1) {
// Vector x Matrix
math::Gemv(
CblasNoTrans,
N,
M,
1.f,
B.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
} else {
// Broadcasted Vector x Batched Matrix
auto* scratch =
ctx()->workspace()->template data<T, Context>({batch_size * N})[0];
math::GemmStridedBatched(
CblasNoTrans,
CblasNoTrans,
batch_size,
N,
1,
M,
M * N,
M,
N,
1.f,
B.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
scratch,
ctx());
math::ReduceSum(
2,
vec32_t{int(batch_size), int(N)}.data(),
1,
vec32_t{0}.data(),
1.f,
scratch,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
}
if (dB->has_name()) {
const auto M = B.dim(B_ndim - 1);
const auto batch_size = B.count() / (M * N);
if (batch_size == 1) {
// Vector x Matrix
math::Gemm(
CblasNoTrans,
CblasNoTrans,
N,
M,
1,
1.f,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
} else {
// Broadcasted Vector x Batched Matrix
math::GemmStridedBatched(
CblasNoTrans,
CblasNoTrans,
batch_size,
N,
M,
1,
0,
M,
M * N,
1.f,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
}
return;
}
if (B_ndim == 1) {
const auto N = B.count();
const auto M = A.count() / N;
// Matrix x Vector
if (dA->has_name()) {
auto* b = B.template data<T, Context>();
auto* dy = dY.template data<T, Context>();
auto* da = dA->ReshapeLike(A)->template mutable_data<T, Context>();
if (transA_ > 0) {
for (int i = 0; i < batch_size; ++i) {
math::Gemm(
transB_ ? CblasTrans : CblasNoTrans,
CblasNoTrans,
CblasNoTrans,
M,
N,
1,
1.f,
dY.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) {
math::Gemv(
CblasTrans,
K1,
M,
N,
1.f,
b + i * B_stride,
dy + i * Y_stride,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
da + i * A_stride,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
return;
}
// Check matrix A && B
const auto M = A.dim(A_ndim - 2);
const auto K = A.dim(A_ndim - 1);
const auto N = B.dim(B_ndim - 1);
// Check batching && broadcasting
vec64_t A_dims(A.dims().begin(), A.dims().end() - 2);
vec64_t B_dims(B.dims().begin(), B.dims().end() - 2);
vec64_t A_batch_dims, B_batch_dims, Y_batch_dims;
vec32_t A_batch_axes, B_batch_axes;
if (math::utils::IsBinaryBroadcast(A_dims, B_dims, Y_batch_dims)) {
math::utils::ComputeBinaryBroadcastDims(
A_dims, B_dims, A_batch_dims, B_batch_dims);
math::utils::ComputeBinaryBroadcastAxes(
A_batch_dims, B_batch_dims, Y_batch_dims, A_batch_axes, B_batch_axes);
} else {
for (int i = 0; i < batch_size; ++i) {
LOG(FATAL) << "Could not broadcast together with shapes " << A.DimString()
<< " " << B.DimString();
}
const int64_t batch_ndim = A_batch_dims.size();
const bool broadcasting = A_batch_dims != B_batch_dims;
const auto A_batch_size = std::accumulate(
A_batch_dims.begin(),
A_batch_dims.end(),
1LL,
std::multiplies<int64_t>());
const auto B_batch_size = std::accumulate(
B_batch_dims.begin(),
B_batch_dims.end(),
1LL,
std::multiplies<int64_t>());
const auto Y_batch_size = std::accumulate(
Y_batch_dims.begin(),
Y_batch_dims.end(),
1LL,
std::multiplies<int64_t>());
if (B_batch_size == 1) {
// Batched Matrix x Broadcasted Matrix
if (dA->has_name()) {
math::Gemm(
CblasNoTrans,
CblasTrans,
A_batch_size * M,
K,
N,
1.f,
dY.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) {
math::Gemm(
CblasTrans,
CblasNoTrans,
transB_ ? CblasNoTrans : CblasTrans,
K,
N,
A_batch_size * M,
1.f,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
} else if (A_batch_size == 1) {
// Broadcasted Matrix x Batched Matrix
if (dA->has_name()) {
auto* scratch = ctx()->workspace()->template data<T, Context>(
{Y_batch_size * M * K})[0];
math::GemmStridedBatched(
CblasNoTrans,
CblasTrans,
Y_batch_size,
M,
K1,
K,
N,
M * N,
K * N,
M * K,
1.f,
dY.template data<T, Context>(),
B.template data<T, Context>(),
0.0f,
scratch,
ctx());
math::ReduceSum(
2,
vec32_t{int(Y_batch_size), int(M * K)}.data(),
1,
vec32_t{0}.data(),
1.f,
scratch,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) {
math::GemmStridedBatched(
CblasTrans,
CblasNoTrans,
Y_batch_size,
K,
N,
M,
0,
M * N,
K * N,
1.f,
dy + i * Y_stride,
b + i * B_stride,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
da + i * A_stride,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
} else if (!broadcasting) {
// Batched Matrix x Batched Matrix
if (dA->has_name()) {
math::GemmStridedBatched(
CblasNoTrans,
CblasTrans,
Y_batch_size,
M,
K,
N,
M * N,
K * N,
M * K,
1.f,
dY.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) {
auto* a = A.template data<T, Context>();
auto* dy = dY.template data<T, Context>();
auto* db = dB->ReshapeLike(B)->template mutable_data<T, Context>();
if (transB_) {
for (int i = 0; i < batch_size; ++i) {
math::Gemm(
math::GemmStridedBatched(
CblasTrans,
transA_ ? CblasTrans : CblasNoTrans,
CblasNoTrans,
Y_batch_size,
K,
N,
K1,
M,
M * K,
M * N,
K * N,
1.f,
dy + i * Y_stride,
a + i * A_stride,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
db + i * B_stride,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
} else {
for (int i = 0; i < batch_size; ++i) {
math::Gemm(
transA_ ? CblasNoTrans : CblasTrans,
// Broadcasted Matrix x Broadcasted Matrix
vector<const T*> A_arr(Y_batch_size);
vector<const T*> B_arr(Y_batch_size);
vector<const T*> dY_arr(Y_batch_size);
vector<T*> dA_arr(Y_batch_size);
vector<T*> dB_arr(Y_batch_size);
if (dA->has_name()) {
vec64_t index(batch_ndim, 0);
vec32_t scratch_dims({Y_batch_dims.begin(), Y_batch_dims.end()});
scratch_dims.push_back(int(M * K));
auto* dY_data = dY.template data<T, Context>();
auto* B_data = B.template data<T, Context>();
auto* scratch = ctx()->workspace()->template data<T, Context>(
{Y_batch_size * std::max(M * K, K * N)})[0];
for (int64_t Y_i = 0; Y_i < Y_batch_size; ++Y_i) {
const auto B_i = math::utils::GetIndexFromDims(
batch_ndim, B_batch_dims.data(), index.data());
dY_arr[Y_i] = dY_data + Y_i * M * N;
B_arr[Y_i] = B_data + B_i * K * N;
dA_arr[Y_i] = scratch + Y_i * M * K;
math::utils::IncreaseIndexInDims(
batch_ndim, Y_batch_dims.data(), index.data());
}
math::GemmBatched(
CblasNoTrans,
K1,
N,
CblasTrans,
Y_batch_size,
M,
K,
N,
1.f,
a + i * A_stride,
dy + i * Y_stride,
dY_arr.data(),
B_arr.data(),
0.f,
db + i * B_stride,
dA_arr.data(),
ctx());
math::ReduceSum(
scratch_dims.size(),
scratch_dims.data(),
A_batch_axes.size(),
A_batch_axes.data(),
1.f,
scratch,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) {
vec64_t index(batch_ndim, 0);
vec32_t scratch_dims({Y_batch_dims.begin(), Y_batch_dims.end()});
scratch_dims.push_back(int(K * N));
auto* dY_data = dY.template data<T, Context>();
auto* A_data = A.template data<T, Context>();
auto* scratch = ctx()->workspace()->template data<T, Context>(
{Y_batch_size * std::max(M * K, K * N)})[0];
for (int64_t Y_i = 0; Y_i < Y_batch_size; ++Y_i) {
const auto A_i = math::utils::GetIndexFromDims(
batch_ndim, A_batch_dims.data(), index.data());
dY_arr[Y_i] = dY_data + Y_i * M * N;
A_arr[Y_i] = A_data + A_i * M * K;
dB_arr[Y_i] = scratch + Y_i * K * N;
math::utils::IncreaseIndexInDims(
batch_ndim, Y_batch_dims.data(), index.data());
}
math::GemmBatched(
CblasTrans,
CblasNoTrans,
Y_batch_size,
K,
N,
M,
1.f,
A_arr.data(),
dY_arr.data(),
0.f,
dB_arr.data(),
ctx());
math::ReduceSum(
scratch_dims.size(),
scratch_dims.data(),
B_batch_axes.size(),
B_batch_axes.data(),
1.f,
scratch,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
}
}
......
......@@ -20,37 +20,25 @@ namespace dragon {
template <class Context>
class MatMulOp final : public Operator<Context> {
public:
MatMulOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
SIMPLE_CTOR_DTOR(MatMulOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
int64_t transA_, transB_;
};
template <class Context>
class MatMulGradientOp final : public Operator<Context> {
public:
MatMulGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
SIMPLE_CTOR_DTOR(MatMulGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
int64_t transA_, transB_;
};
} // namespace dragon
......
......@@ -35,6 +35,7 @@ from dragon.core.ops.math_ops import dot
from dragon.core.ops.math_ops import equal
from dragon.core.ops.math_ops import exp
from dragon.core.ops.math_ops import floor
from dragon.core.ops.math_ops import gemm
from dragon.core.ops.math_ops import greater
from dragon.core.ops.math_ops import greater_equal
from dragon.core.ops.math_ops import is_inf
......
......@@ -33,7 +33,6 @@ from dragon.core.ops.activation_ops import relu6
from dragon.core.ops.activation_ops import selu
from dragon.core.ops.activation_ops import softmax
from dragon.core.ops.activation_ops import swish
from dragon.core.ops.math_ops import fully_connected
from dragon.core.ops.normalization_ops import batch_norm
from dragon.core.ops.normalization_ops import group_norm
from dragon.core.ops.normalization_ops import instance_norm
......
......@@ -414,30 +414,34 @@ def flatten_spec(args, inputs, outputs):
return outputs
@register('FullyConnected')
def fully_connected_spec(args, inputs, outputs):
@register('Gemm')
def gemm_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype
axis, out_channels = args['axis'], args.get('out_channels', None)
axis, n = args['axis'], args.get('n', None)
while axis < 0:
try:
axis += len(inputs[0].shape)
except TypeError:
return outputs
out_shape = [None] * (axis + 1)
if out_channels is None:
break
out_shape = [None] * axis if axis >= 0 else None
if n is None:
try:
if args['transW']:
out_channels = inputs[1].shape[0]
if args['transB']:
n = inputs[1].shape[0]
else:
out_channels = inputs[1].shape[1]
n = inputs[1].shape[1]
except (TypeError, IndexError):
out_channels = None
n = None
try:
out_shape[axis] = out_channels
out_shape[:axis] = inputs[0].shape[:axis]
if out_shape is None or inputs[0].shape is not None:
out_shape = list(inputs[0].shape[:axis])
if args['transA']:
out_shape.insert(0, n)
else:
out_shape.append(n)
outputs[0].shape = out_shape
except (TypeError, IndexError):
pass
outputs[0].shape = out_shape
return outputs
......@@ -510,12 +514,25 @@ def masked_select_spec(args, inputs, outputs):
@register('MatMul')
def matmul_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype
ta, tb = args['transA'], args['transB']
try:
a_shape = list(inputs[0].shape[:])
b_shape = list(inputs[1].shape[:])
a_shape = out_shape = list(inputs[0].shape[:])
out_shape[-2] = a_shape[-1] if ta else a_shape[-2]
out_shape[-1] = b_shape[-2] if tb else b_shape[-1]
if len(a_shape) >= 2 and len(b_shape) >= 2:
out_shape = [1] * max(len(a_shape), len(b_shape))
a_shape = [1] * (len(out_shape) - len(a_shape)) + a_shape
b_shape = [1] * (len(out_shape) - len(b_shape)) + b_shape
for i in range(len(out_shape)):
try:
out_shape[i] = max(a_shape[i], b_shape[i])
except TypeError:
out_shape[i] = None
out_shape[-2] = a_shape[-2]
out_shape[-1] = b_shape[-1]
elif len(a_shape) == 1 and len(b_shape) == 1:
out_shape = []
else:
out_shape = a_shape if len(b_shape) == 1 else b_shape
out_shape.pop(-1 if len(b_shape) == 1 else -2)
except (TypeError, IndexError):
out_shape = None
outputs[0].shape = out_shape
......
......@@ -498,23 +498,30 @@ def floor(inputs, **kwargs):
@OpSchema.num_inputs(2, 3)
def fully_connected(inputs, axis=1, transpose_w=True, **kwargs):
r"""Compute the dense matrix multiplication along the given axes.
.. math:: y = Wx + b
The column of input matrix is determined by:
.. math:: \text{Col} = \text{DimSince}(\text{Input}, \text{Axis})
def gemm(
inputs,
alpha=1.0,
beta=1.0,
transpose_a=False,
transpose_b=False,
**kwargs
):
r"""Compute the general matrix multiplication.
.. math:: \text{out} = \alpha AB + \beta C
Parameters
----------
inputs : Sequence[dragon.Tensor]
The tensor :math:`x`, :math:`W` and :math:`b`.
axis : int, optional, default=1
The start axis to compute, can be negative.
transpose_w : bool, optional, default=True
**True** to transpose :math:`W` before computation.
The matrix :math:`A`, :math:`B` and optional :math:`C`.
alpha : float, optional, default=1.0
The value to :math:`\alpha`.
beta : float, optional, default=1.0
The value to :math:`\beta`.
transpose_a : bool, optional, default=False
**True** to transpose :math:`A` before computation.
transpose_b : bool, optional, default=False
**True** to transpose :math:`B` before computation.
Returns
-------
......@@ -523,15 +530,22 @@ def fully_connected(inputs, axis=1, transpose_w=True, **kwargs):
"""
args = ArgHelper.parse(locals())
op_lib = math_ops_lib.FullyConnected
args['axis'] = kwargs.get('axis', -1)
args['alpha'], args['beta'] = float(alpha), float(beta)
op_lib = math_ops_lib.Gemm
if context.executing_eagerly():
return op_lib \
.instantiate(axis=axis, transpose_w=transpose_w) \
.apply(inputs)
.instantiate(
axis=args['axis'],
alpha=args['alpha'],
beta=args['beta'],
transpose_a=transpose_a,
transpose_b=transpose_b,
).apply(inputs)
else:
args.pop('transpose_w')
args['transW'] = transpose_w
return op_lib.blend('FullyConnected', **args)
args['transA'] = args.pop('transpose_a')
args['transB'] = args.pop('transpose_b')
return op_lib.blend(**args)
@OpSchema.num_inputs(2)
......@@ -812,42 +826,44 @@ def less_equal(inputs, **kwargs):
@OpSchema.num_inputs(2)
def matmul(inputs, transpose_a=False, transpose_b=False, **kwargs):
def matmul(inputs, **kwargs):
r"""Compute the matrix multiplication.
.. math:: y = a \times b
.. math:: \text{out} = \text{input1} \times \text{input2}
The rank of ``a`` and ``b`` should be equal and >= 2:
The behavior depends on the shape of input tensors:
```python
# Ok, a typical matrix multiplication
a = dragon.ones((2, 3), 'float32')
b = dragon.ones((3, 3), 'float32')
print(dragon.math.matmul([a, b]))
* If both tensors are 1d, computes the vector product.
* If tensors are 1d and >=2d, computes the vector-matrix multiplication.
* If tensors are >=2d and 1d, computes the matrix-vector multiplication.
* If both tensors are >= 2d, computes the matrix-matrix multiplication.
* If one tensor is >= 3d, applies batching and broadcasting to the computation.
# Compute a batch matrix multiplication if rank > 2
aa = dragon.ones((4, 2, 3), 'float32')
bb = dragon.ones((4, 3, 3), 'float32')
print(dragon.math.matmul([aa, bb]))
```
If inputs are transposed, remember to transpose them back:
Examples:
```python
# Vector x Vector
a = dragon.ones((2,), 'float32')
b = dragon.ones((2,), 'float32')
print(dragon.math.matmul([a, b]))
# Vector x Matrix
a = dragon.ones((2,), 'float32')
b = dragon.ones((2, 3), 'float32')
print(dragon.math.matmul([a, b]))
# Matrix x Vector
a = dragon.ones((3, 2), 'float32')
b = dragon.ones((3, 3), 'float32')
print(dragon.math.matmul([a, b])) # ``a`` takes the wrong dimensions
print(dragon.math.matmul([a, b], transpose_a=True)) # Ok
b = dragon.ones((2,), 'float32')
print(dragon.math.matmul([a, b]))
# Matrix x Matrix
a = dragon.ones((2, 3), 'float32')
b = dragon.ones((3, 2), 'float32')
print(dragon.math.matmul([a, b]))
```
Parameters
----------
inputs : Sequence[dragon.Tensor]
The matrix :math:`a` and :math:`b`.
transpose_a : bool, optional, default=False
**True** to transpose :math:`a` before computation.
transpose_b : bool, optional, default=False
**True** to transpose :math:`b` before computation.
The input tensors.
Returns
-------
......@@ -858,15 +874,9 @@ def matmul(inputs, transpose_a=False, transpose_b=False, **kwargs):
args = ArgHelper.parse(locals())
op_lib = math_ops_lib.MatMul
if context.executing_eagerly():
return op_lib \
.instantiate(
transpose_a=transpose_a,
transpose_b=transpose_b,
).apply(inputs)
return op_lib.instantiate().apply(inputs)
else:
args.pop('transpose_a')
args.pop('transpose_b')
return op_lib.blend(transA=transpose_a, transB=transpose_b, **args)
return op_lib.blend(**args)
@OpSchema.num_inputs(2)
......
......@@ -80,20 +80,26 @@ class Clip(Operator):
return self.dispatch(inputs, [self.alloc()])
class FullyConnected(Operator):
class Gemm(Operator):
"""FullyConnected operator."""
def __init__(self, key, dev, **kwargs):
super(FullyConnected, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1)
self.transpose_w = kwargs.get('transpose_w', True)
super(Gemm, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
self.alpha = kwargs.get('alpha', 1.0)
self.beta = kwargs.get('beta', 1.0)
self.transpose_a = kwargs.get('transpose_a', False)
self.transpose_b = kwargs.get('transpose_b', False)
def attributes(self):
return {
'op_type': 'FullyConnected',
'op_type': 'Gemm',
'arguments': {
'axis': self.axis,
'transW': self.transpose_w,
'alpha': self.alpha,
'beta': self.beta,
'transA': self.transpose_a,
'transB': self.transpose_b,
}
}
......@@ -104,18 +110,10 @@ class FullyConnected(Operator):
class MatMul(Operator):
"""MatMul operator."""
def __init__(self, key, dev, **kwargs):
super(MatMul, self).__init__(key, dev, **kwargs)
self.transpose_a = kwargs.get('transpose_a', False)
self.transpose_b = kwargs.get('transpose_b', False)
def attributes(self):
return {
'op_type': 'MatMul',
'arguments': {
'transA': self.transpose_a,
'transB': self.transpose_b,
}
'arguments': {},
}
def forward(self, inputs):
......
......@@ -136,6 +136,7 @@ def conv(
data_format=data_format,
bias=len(inputs) > 2,
dtype=inputs[1].dtype,
input_shape=inputs[0].shape,
).apply(inputs)
else:
return op_lib.blend(**args)
......@@ -465,6 +466,7 @@ def conv_transpose(
data_format=data_format,
bias=len(inputs) > 2,
dtype=inputs[1].dtype,
input_shape=inputs[0].shape,
).apply(inputs)
else:
return op_lib.blend(**args)
......
......@@ -44,13 +44,6 @@ def hardsigmoid_exporter(op_def, context):
return node, const_tensors
@export_util.register('PRelu')
def prelu_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
const_tensors = [helper.from_tensor(op_def.input[1], context.ws)]
return node, const_tensors
@export_util.register('Relu')
def relu_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
......
......@@ -81,24 +81,14 @@ def clip_exporter_v11(op_def, context):
return node, const_tensors
@export_util.register('FullyConnected-7')
def fully_connected_exporter_v7(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'Gemm'
helper.add_attribute(node, 'alpha', 1.)
helper.add_attribute(node, 'beta', 1.)
for arg in op_def.arg:
if arg.name == 'transW':
helper.add_attribute(node, 'transB', arg.i)
# Weights and biases
const_tensors = [helper.from_tensor(name, context.ws)
for name in op_def.input[1:]]
return node, const_tensors
@export_util.register('Gemm-7')
def gemm_exporter_v7(op_def, context):
return export_util.translate(**locals())
@export_util.register('FullyConnected')
def fully_connected_exporter(op_def, context):
node, const_tensors = fully_connected_exporter_v7(op_def, context)
@export_util.register('Gemm')
def gemm_exporter(op_def, context):
node, const_tensors = gemm_exporter_v7(op_def, context)
helper.add_attribute(node, 'broadcast', 1) # Removed since opset 7
return node, const_tensors
......
......@@ -29,8 +29,6 @@ def batch_norm_exporter(op_def, context):
elif arg.name == 'momentum_desc':
momentum = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'momentum', float(momentum))
# Weight, bias, running mean and running variance
const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors
......@@ -48,8 +46,6 @@ def group_norm_exporter(op_def, context):
else:
helper.add_attribute(node, 'op_type', 'GroupNorm')
helper.add_attribute(node, 'group', arg.i)
# Weight and bias
const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors
......
......@@ -25,7 +25,7 @@ from dragon.vm.onnx.core.exporters import utils as export_util
'ConvTranspose',
'DepthwiseConv',
])
def convolution(op_def, context):
def conv_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
node.op_type = 'ConvTranspose' if 'Transpose' in op_def.type else 'Conv'
if 'Depthwise' in op_def.type:
......@@ -58,8 +58,6 @@ def convolution(op_def, context):
helper.add_attribute(node, 'output_shape', arg.ints)
elif arg.name == 'output_padding':
helper.add_attribute(node, 'output_padding', arg.ints)
# Weights and biases
const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors
......
......@@ -203,8 +203,7 @@ DRAGON_API void Gemv<float16, CPUContext>(
const float16* x,
const float beta,
float16* y,
CPUContext* ctx,
const std::string math_type) {
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
......@@ -219,8 +218,7 @@ DRAGON_API void Gemv<float16, CPUContext>(
const T* x, \
const float beta, \
T* y, \
CPUContext* ctx, \
const string math_type) { \
CPUContext* ctx) { \
T _alpha_ = alpha, _beta_ = beta; \
EigenVectorMap<T> y_vec(y, TransA == CblasNoTrans ? M : N); \
if (beta == 0.f) \
......@@ -260,8 +258,7 @@ DRAGON_API void Gemm<float16, CPUContext>(
const float16* B,
const float beta,
float16* C,
CPUContext* ctx,
const string math_type) {
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
......@@ -278,8 +275,7 @@ DRAGON_API void Gemm<float16, CPUContext>(
const T* B, \
const float beta, \
T* C, \
CPUContext* ctx, \
const string math_type) { \
CPUContext* ctx) { \
T _alpha_ = alpha, _beta_ = beta; \
auto C_mat = EigenMatrixMap<T>(C, N, M); \
if (beta == 0.f) \
......@@ -328,6 +324,105 @@ DEFINE_GEMM_FUNC(float);
DEFINE_GEMM_FUNC(double);
#undef DEFINE_GEMM_FUNC
template <>
DRAGON_API void GemmBatched<float16, CPUContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const float alpha,
const float16** A,
const float16** B,
const float beta,
float16** C,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_BATCHED_GEMM_FUNC(T) \
template <> \
DRAGON_API void GemmBatched<T, CPUContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int batch_size, \
const int M, \
const int N, \
const int K, \
const float alpha, \
const T** A, \
const T** B, \
const float beta, \
T** C, \
CPUContext* ctx) { \
for (int i = 0; i < batch_size; ++i) { \
Gemm(TransA, TransB, M, N, K, alpha, A[i], B[i], beta, C[i], ctx); \
} \
}
DEFINE_BATCHED_GEMM_FUNC(float);
DEFINE_BATCHED_GEMM_FUNC(double);
#undef DEFINE_BATCHED_GEMM_FUNC
template <>
DRAGON_API void GemmStridedBatched<float16, CPUContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const int A_stride,
const int B_stride,
const int C_stride,
const float alpha,
const float16* A,
const float16* B,
const float beta,
float16* C,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_STRIDED_BATCHED_GEMM_FUNC(T) \
template <> \
DRAGON_API void GemmStridedBatched<T, CPUContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int batch_size, \
const int M, \
const int N, \
const int K, \
const int A_stride, \
const int B_stride, \
const int C_stride, \
const float alpha, \
const T* A, \
const T* B, \
const float beta, \
T* C, \
CPUContext* ctx) { \
for (int i = 0; i < batch_size; ++i) { \
Gemm( \
TransA, \
TransB, \
M, \
N, \
K, \
alpha, \
A + i * A_stride, \
B + i * B_stride, \
beta, \
C + i * C_stride, \
ctx); \
} \
}
DEFINE_STRIDED_BATCHED_GEMM_FUNC(float);
DEFINE_STRIDED_BATCHED_GEMM_FUNC(double);
#undef DEFINE_STRIDED_BATCHED_GEMM_FUNC
} // namespace math
} // namespace dragon
......@@ -2,6 +2,7 @@
#include "dragon/core/context_cuda.h"
#include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_thrust.h"
#include "dragon/utils/math/blas.h"
namespace dragon {
......@@ -456,8 +457,7 @@ DRAGON_API void Gemv<float16, CUDAContext>(
const float16* x,
const float beta,
float16* y,
CUDAContext* ctx,
const string math_type) {
CUDAContext* ctx) {
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
int m = cuTransA == CUBLAS_OP_N ? N : M;
int k = cuTransA == CUBLAS_OP_N ? M : N;
......@@ -465,10 +465,7 @@ DRAGON_API void Gemv<float16, CUDAContext>(
int LDC = m;
CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
if (math_type == "float32") {
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
// GEMV + MATH32 + TENSOR-CORE
CUBLAS_CHECK(cublasGemmEx(
ctx->cublas_handle(),
cuTransA,
......@@ -490,7 +487,6 @@ DRAGON_API void Gemv<float16, CUDAContext>(
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// GEMV + MATH32 + DEFAULT
CUBLAS_CHECK(cublasSgemmEx(
ctx->cublas_handle(),
cuTransA,
......@@ -510,142 +506,43 @@ DRAGON_API void Gemv<float16, CUDAContext>(
CUDA_R_16F,
LDC));
}
#else
CUBLAS_CHECK(cublasSgemmEx(
ctx->cublas_handle(),
cuTransA,
CUBLAS_OP_N,
m,
1,
k,
&alpha,
A,
CUDA_R_16F,
LDA,
x,
CUDA_R_16F,
k,
&beta,
y,
CUDA_R_16F,
LDC));
#endif
} else if (math_type == "float16") {
const half alpha_val = convert::To<half>(alpha);
const half beta_val = convert::To<half>(beta);
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
// GEMV + MATH16 + TENSOR-CORE
CUBLAS_CHECK(cublasGemmEx(
ctx->cublas_handle(),
cuTransA,
CUBLAS_OP_N,
m,
1,
k,
&alpha_val,
A,
CUDA_R_16F,
LDA,
x,
CUDA_R_16F,
k,
&beta_val,
y,
CUDA_R_16F,
LDC,
CUDA_R_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// GEMV + MATH16 + DEFAULT
CUBLAS_CHECK(cublasHgemm(
ctx->cublas_handle(),
cuTransA,
CUBLAS_OP_N,
m,
1,
k,
&alpha_val,
reinterpret_cast<const half*>(A),
LDA,
reinterpret_cast<const half*>(x),
k,
&beta_val,
reinterpret_cast<half*>(y),
LDC));
}
#else
CUBLAS_CHECK(cublasHgemm(
ctx->cublas_handle(),
cuTransA,
CUBLAS_OP_N,
m,
1,
k,
&alpha_val,
reinterpret_cast<const half*>(A),
LDA,
reinterpret_cast<const half*>(x),
k,
&beta_val,
reinterpret_cast<half*>(y),
LDC));
#endif
} else {
LOG(FATAL) << "Unknown math type: " << math_type;
}
}
template <>
DRAGON_API void Gemv<float, CUDAContext>(
const CBLAS_TRANSPOSE TransA,
const int M,
const int N,
const float alpha,
const float* A,
const float* x,
const float beta,
float* y,
CUDAContext* ctx,
const string math_type) {
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasSgemv(
ctx->cublas_handle(), cuTransA, N, M, &alpha, A, N, x, 1, &beta, y, 1));
}
#define DEFINE_GEMV_FUNC(T, cublas_func) \
template <> \
DRAGON_API void Gemv<T, CUDAContext>( \
const CBLAS_TRANSPOSE TransA, \
const int M, \
const int N, \
const float alpha, \
const T* A, \
const T* x, \
const float beta, \
T* y, \
CUDAContext* ctx) { \
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N; \
const auto alpha_val = static_cast<T>(alpha); \
const auto beta_val = static_cast<T>(beta); \
CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
CUBLAS_CHECK(cublas_func( \
ctx->cublas_handle(), \
cuTransA, \
N, \
M, \
&alpha_val, \
A, \
N, \
x, \
1, \
&beta_val, \
y, \
1)); \
}
template <>
DRAGON_API void Gemv<double, CUDAContext>(
const CBLAS_TRANSPOSE TransA,
const int M,
const int N,
const float alpha,
const double* A,
const double* x,
const float beta,
double* y,
CUDAContext* ctx,
const string math_type) {
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto alpha_val = static_cast<double>(alpha);
const auto beta_val = static_cast<double>(beta);
CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasDgemv(
ctx->cublas_handle(),
cuTransA,
N,
M,
&alpha_val,
A,
N,
x,
1,
&beta_val,
y,
1));
}
DEFINE_GEMV_FUNC(float, cublasSgemv);
DEFINE_GEMV_FUNC(double, cublasDgemv);
#undef DEFINE_GEMV_FUNC
template <>
DRAGON_API void Gemm<float16, CUDAContext>(
......@@ -659,18 +556,14 @@ DRAGON_API void Gemm<float16, CUDAContext>(
const float16* B,
const float beta,
float16* C,
CUDAContext* ctx,
const std::string math_type) {
CUDAContext* ctx) {
int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
if (math_type == "float32") {
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
// GEMM + MATH32 + TENSOR-CORE
CUBLAS_CHECK(cublasGemmEx(
ctx->cublas_handle(),
cuTransB,
......@@ -692,7 +585,6 @@ DRAGON_API void Gemm<float16, CUDAContext>(
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// GEMM + MATH32 + DEFAULT
CUBLAS_CHECK(cublasSgemmEx(
ctx->cublas_handle(),
cuTransB,
......@@ -712,113 +604,76 @@ DRAGON_API void Gemm<float16, CUDAContext>(
CUDA_R_16F,
N));
}
#else
CUBLAS_CHECK(cublasSgemmEx(
ctx->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
CUDA_R_16F,
ldb,
A,
CUDA_R_16F,
lda,
&beta,
C,
CUDA_R_16F,
N));
#endif
} else if (math_type == "float16") {
const half alpha_val = convert::To<half>(alpha);
const half beta_val = convert::To<half>(beta);
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
// GEMM + MATH16 + TENSOR-CORE
CUBLAS_CHECK(cublasGemmEx(
ctx->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha_val,
B,
CUDA_R_16F,
ldb,
A,
CUDA_R_16F,
lda,
&beta_val,
C,
CUDA_R_16F,
N,
CUDA_R_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// GEMM + MATH16 + DEFAULT
CUBLAS_CHECK(cublasHgemm(
ctx->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha_val,
reinterpret_cast<const half*>(B),
ldb,
reinterpret_cast<const half*>(A),
lda,
&beta_val,
reinterpret_cast<half*>(C),
N));
}
#else
CUBLAS_CHECK(cublasHgemm(
ctx->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha_val,
reinterpret_cast<const half*>(B),
ldb,
reinterpret_cast<const half*>(A),
lda,
&beta_val,
reinterpret_cast<half*>(C),
N));
#endif
} else {
LOG(FATAL) << "Unknown math type: " << math_type;
}
}
#define DEFINE_GEMM_FUNC(T, cublas_func) \
template <> \
DRAGON_API void Gemm<T, CUDAContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int M, \
const int N, \
const int K, \
const float alpha, \
const T* A, \
const T* B, \
const float beta, \
T* C, \
CUDAContext* ctx) { \
int lda = TransA == CblasNoTrans ? K : M; \
int ldb = TransB == CblasNoTrans ? N : K; \
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
const auto alpha_val = static_cast<T>(alpha); \
const auto beta_val = static_cast<T>(beta); \
CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
CUBLAS_CHECK(cublas_func( \
ctx->cublas_handle(), \
cuTransB, \
cuTransA, \
N, \
M, \
K, \
&alpha_val, \
B, \
ldb, \
A, \
lda, \
&beta_val, \
C, \
N)); \
}
DEFINE_GEMM_FUNC(float, cublasSgemm);
DEFINE_GEMM_FUNC(double, cublasDgemm);
#undef DEFINE_GEMM_FUNC
template <>
DRAGON_API void Gemm<float, CUDAContext>(
DRAGON_API void GemmBatched<float16, CUDAContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const float alpha,
const float* A,
const float* B,
const float16** A,
const float16** B,
const float beta,
float* C,
CUDAContext* ctx,
const string math_type) {
float16** C,
CUDAContext* ctx) {
int lda = TransA == CblasNoTrans ? K : M;
int ldb = TransB == CblasNoTrans ? N : K;
int ldc = N;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
thrust::device_vector<const void*> A_arr(A, A + batch_size);
thrust::device_vector<const void*> B_arr(B, B + batch_size);
thrust::device_vector<void*> C_arr(C, C + batch_size);
CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasSgemm(
CUBLAS_CHECK(cublasGemmBatchedEx(
ctx->cublas_handle(),
cuTransB,
cuTransA,
......@@ -826,54 +681,172 @@ DRAGON_API void Gemm<float, CUDAContext>(
M,
K,
&alpha,
B,
B_arr.data().get(),
CUDA_R_16F,
ldb,
A,
A_arr.data().get(),
CUDA_R_16F,
lda,
&beta,
C,
N));
C_arr.data().get(),
CUDA_R_16F,
ldc,
batch_size,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
#define DEFINE_BATCHED_GEMM_FUNC(T, cublas_func) \
template <> \
DRAGON_API void GemmBatched<T, CUDAContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int batch_size, \
const int M, \
const int N, \
const int K, \
const float alpha, \
const T** A, \
const T** B, \
const float beta, \
T** C, \
CUDAContext* ctx) { \
int lda = TransA == CblasNoTrans ? K : M; \
int ldb = TransB == CblasNoTrans ? N : K; \
int ldc = N; \
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
const auto alpha_val = static_cast<T>(alpha); \
const auto beta_val = static_cast<T>(beta); \
thrust::device_vector<const T*> A_arr(A, A + batch_size); \
thrust::device_vector<const T*> B_arr(B, B + batch_size); \
thrust::device_vector<T*> C_arr(C, C + batch_size); \
CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
CUBLAS_CHECK(cublas_func( \
ctx->cublas_handle(), \
cuTransB, \
cuTransA, \
N, \
M, \
K, \
&alpha_val, \
B_arr.data().get(), \
ldb, \
A_arr.data().get(), \
lda, \
&beta_val, \
C_arr.data().get(), \
ldc, \
batch_size)); \
}
DEFINE_BATCHED_GEMM_FUNC(float, cublasSgemmBatched);
DEFINE_BATCHED_GEMM_FUNC(double, cublasDgemmBatched);
#undef DEFINE_BATCHED_GEMM_FUNC
template <>
DRAGON_API void Gemm<double, CUDAContext>(
DRAGON_API void GemmStridedBatched<float16, CUDAContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const int A_stride,
const int B_stride,
const int C_stride,
const float alpha,
const double* A,
const double* B,
const float16* A,
const float16* B,
const float beta,
double* C,
CUDAContext* ctx,
const string math_type) {
int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K;
float16* C,
CUDAContext* ctx) {
int lda = TransA == CblasNoTrans ? K : M;
int ldb = TransB == CblasNoTrans ? N : K;
int ldc = N;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
const auto alpha_val = static_cast<double>(alpha);
const auto beta_val = static_cast<double>(beta);
CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasDgemm(
CUBLAS_CHECK(cublasGemmStridedBatchedEx(
ctx->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha_val,
&alpha,
B,
CUDA_R_16F,
ldb,
B_stride,
A,
CUDA_R_16F,
lda,
&beta_val,
A_stride,
&beta,
C,
N));
CUDA_R_16F,
ldc,
C_stride,
batch_size,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
#define DEFINE_STRIDED_BATCHED_GEMM_FUNC(T, cublas_func) \
template <> \
DRAGON_API void GemmStridedBatched<T, CUDAContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int batch_size, \
const int M, \
const int N, \
const int K, \
const int A_stride, \
const int B_stride, \
const int C_stride, \
const float alpha, \
const T* A, \
const T* B, \
const float beta, \
T* C, \
CUDAContext* ctx) { \
int lda = TransA == CblasNoTrans ? K : M; \
int ldb = TransB == CblasNoTrans ? N : K; \
int ldc = N; \
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
const auto alpha_val = static_cast<T>(alpha); \
const auto beta_val = static_cast<T>(beta); \
CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
CUBLAS_CHECK(cublas_func( \
ctx->cublas_handle(), \
cuTransB, \
cuTransA, \
N, \
M, \
K, \
&alpha_val, \
B, \
ldb, \
B_stride, \
A, \
lda, \
A_stride, \
&beta_val, \
C, \
ldc, \
C_stride, \
batch_size)); \
}
DEFINE_STRIDED_BATCHED_GEMM_FUNC(float, cublasSgemmStridedBatched);
DEFINE_STRIDED_BATCHED_GEMM_FUNC(double, cublasDgemmStridedBatched);
#undef DEFINE_STRIDED_BATCHED_GEMM_FUNC
} // namespace math
} // namespace dragon
......
......@@ -85,8 +85,7 @@ DRAGON_API void Gemv(
const T* x,
const float beta,
T* y,
Context* ctx,
const string math_type = "float32");
Context* ctx);
template <typename T, class Context>
DRAGON_API void Gemm(
......@@ -100,8 +99,40 @@ DRAGON_API void Gemm(
const T* B,
const float beta,
T* C,
Context* ctx,
const string math_type = "float32");
Context* ctx);
template <typename T, class Context>
DRAGON_API void GemmBatched(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const float alpha,
const T** A,
const T** B,
const float beta,
T** C,
Context* ctx);
template <typename T, class Context>
DRAGON_API void GemmStridedBatched(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const int A_stride,
const int B_stride,
const int C_stride,
const float alpha,
const T* A,
const T* B,
const float beta,
T* C,
Context* ctx);
} // namespace math
......
......@@ -158,15 +158,15 @@ __global__ void _BroadcastWhere(
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_SET_FUNC(T1, T2) \
#define DEFINE_SET_FUNC(T, ScalarT) \
template <> \
DRAGON_API void Set<T1, CUDAContext>( \
DRAGON_API void Set<T, CUDAContext>( \
const int x_ndim, \
const int64_t* x_dims, \
const int y_ndim, \
const int64_t* y_dims, \
const T1* x, \
T1* y, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
int rows, cols; \
vec64_t X_dims(x_dims, x_dims + x_ndim); \
......@@ -189,8 +189,8 @@ __global__ void _BroadcastWhere(
ctx->cuda_stream()>>>( \
nthreads, \
cols, \
reinterpret_cast<const T2*>(x), \
reinterpret_cast<T2*>(y)); \
reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<ScalarT*>(y)); \
return; \
} \
if (math::utils::IsColwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \
......@@ -202,8 +202,8 @@ __global__ void _BroadcastWhere(
ctx->cuda_stream()>>>( \
nthreads, \
cols, \
reinterpret_cast<const T2*>(x), \
reinterpret_cast<T2*>(y)); \
reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<ScalarT*>(y)); \
return; \
} \
vec64_t X_broadcast_strides, _; \
......@@ -226,8 +226,8 @@ __global__ void _BroadcastWhere(
Y_dims.size(), \
strides, \
dims, \
reinterpret_cast<const T2*>(x), \
reinterpret_cast<T2*>(y)); \
reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<ScalarT*>(y)); \
}
DEFINE_SET_FUNC(bool, uint8_t);
......@@ -235,8 +235,8 @@ DEFINE_SET_FUNC(int8_t, int8_t);
DEFINE_SET_FUNC(uint8_t, uint8_t);
DEFINE_SET_FUNC(int, int);
DEFINE_SET_FUNC(int64_t, int64_t);
DEFINE_SET_FUNC(float, float);
DEFINE_SET_FUNC(float16, half);
DEFINE_SET_FUNC(float, float);
DEFINE_SET_FUNC(double, double);
#undef DEFINE_SET_FUNC
......@@ -267,13 +267,31 @@ DEFINE_SET_FUNC(double, double);
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
const auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
_RowwiseBinaryFunc<InputT, OutputT, Functor<InputT>, true> \
_RowwiseBinaryFunc< \
math::ScalarType<InputT>::type, \
math::ScalarType<OutputT>::type, \
Functor<math::ScalarType<InputT>::type>, \
true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Functor<InputT>(), a, b, y); \
nthreads, \
cols, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} else { \
_RowwiseBinaryFunc<InputT, OutputT, Functor<InputT>, false> \
_RowwiseBinaryFunc< \
math::ScalarType<InputT>::type, \
math::ScalarType<OutputT>::type, \
Functor<math::ScalarType<InputT>::type>, \
false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Functor<InputT>(), a, b, y); \
nthreads, \
cols, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} \
return; \
} \
......@@ -281,13 +299,31 @@ DEFINE_SET_FUNC(double, double);
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
const auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
_ColwiseBinaryFunc<InputT, OutputT, Functor<InputT>, true> \
_ColwiseBinaryFunc< \
math::ScalarType<InputT>::type, \
math::ScalarType<OutputT>::type, \
Functor<math::ScalarType<InputT>::type>, \
true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Functor<InputT>(), a, b, y); \
nthreads, \
cols, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} else { \
_ColwiseBinaryFunc<InputT, OutputT, Functor<InputT>, false> \
_ColwiseBinaryFunc< \
math::ScalarType<InputT>::type, \
math::ScalarType<OutputT>::type, \
Functor<math::ScalarType<InputT>::type>, \
false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Functor<InputT>(), a, b, y); \
nthreads, \
cols, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} \
return; \
} \
......@@ -304,9 +340,9 @@ DEFINE_SET_FUNC(double, double);
y_dims.data[i] = Y_dims[i]; \
} \
_BroadcastBinaryFunc< \
InputT, \
OutputT, \
Functor<InputT>, \
math::ScalarType<InputT>::type, \
math::ScalarType<OutputT>::type, \
Functor<math::ScalarType<InputT>::type>, \
CUDA_TENSOR_MAX_DIMS> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
......@@ -314,89 +350,102 @@ DEFINE_SET_FUNC(double, double);
a_strides, \
b_strides, \
y_dims, \
Functor<InputT>(), \
a, \
b, \
y); \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
}
DEFINE_BINARY_FUNC(Add, int8_t, int8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int, int, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int64_t, int64_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float16, float16, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float, float, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, double, double, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int, int, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, float16, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float, float, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, double, double, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int, int, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float16, float16, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float, float, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, double, double, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, int8_t, int8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int, int, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int64_t, int64_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float16, float16, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float, float, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, float16, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, float16, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float, float, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, double, double, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int, int, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, float16, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float, float, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, double, double, math::MaxFunctor);
DEFINE_BINARY_FUNC(Equal, int8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, uint8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int64_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, float16, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, float, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, double, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, uint8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int64_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, float16, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, float, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, double, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, int8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, uint8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int64_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, float16, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, float, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, double, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, int8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, uint8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int64_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, float16, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, float, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, double, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, int8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, uint8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int64_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, float16, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, float, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, double, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, uint8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int64_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, float16, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, float, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, double, bool, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, T, dtype) \
#define DEFINE_BINARY_FUNC(name, T, ScalarT) \
template <> \
DRAGON_API void name<T, CUDAContext>( \
const int a_ndim, \
......@@ -412,9 +461,9 @@ DEFINE_BINARY_FUNC(GreaterEqual, double, bool, math::GreaterEqualFunctor);
a_dims, \
b_ndim, \
b_dims, \
reinterpret_cast<const dtype*>(a), \
reinterpret_cast<const dtype*>(b), \
reinterpret_cast<dtype*>(y), \
reinterpret_cast<const ScalarT*>(a), \
reinterpret_cast<const ScalarT*>(b), \
reinterpret_cast<ScalarT*>(y), \
ctx); \
}
......@@ -423,130 +472,19 @@ DEFINE_BINARY_FUNC(Sub, bool, uint8_t); // Xor
DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, OutputT1, OutputT2, Functor) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int a_ndim, \
const int64_t* a_dims, \
const int b_ndim, \
const int64_t* b_dims, \
const float16* a, \
const float16* b, \
OutputT1* y, \
CUDAContext* ctx) { \
int rows, cols, broadcast_1st; \
vec64_t A_dims(a_dims, a_dims + a_ndim); \
vec64_t B_dims(b_dims, b_dims + b_ndim); \
vec64_t A_broadcast_dims, B_broadcast_dims; \
math::utils::ComputeBinaryBroadcastDims( \
A_dims, B_dims, A_broadcast_dims, B_broadcast_dims); \
if (A_broadcast_dims == B_broadcast_dims) { \
auto count = std::accumulate( \
a_dims, a_dims + a_ndim, 1, std::multiplies<int64_t>()); \
name(count, a, b, y, ctx); \
return; \
} \
if (math::utils::IsRowwiseBroadcast( \
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
_RowwiseBinaryFunc<half, OutputT2, Functor<half>, true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<OutputT2*>(y)); \
} else { \
_RowwiseBinaryFunc<half, OutputT2, Functor<half>, false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<OutputT2*>(y)); \
} \
return; \
} \
if (math::utils::IsColwiseBroadcast( \
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
_ColwiseBinaryFunc<half, OutputT2, Functor<half>, true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<OutputT2*>(y)); \
} else { \
_ColwiseBinaryFunc<half, OutputT2, Functor<half>, false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<OutputT2*>(y)); \
} \
return; \
} \
vec64_t A_broadcast_strides, B_broadcast_strides, Y_dims; \
math::utils::ComputeBinaryBroadcastStrides( \
A_dims, B_dims, A_broadcast_strides, B_broadcast_strides, Y_dims); \
CUDA_TENSOR_DIMS_CHECK((int)Y_dims.size()); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> a_strides, b_strides, y_dims; \
const auto nthreads = std::accumulate( \
Y_dims.begin(), Y_dims.end(), 1, std::multiplies<int64_t>()); \
for (int i = 0; i < Y_dims.size(); ++i) { \
a_strides.data[i] = A_broadcast_strides[i]; \
b_strides.data[i] = B_broadcast_strides[i]; \
y_dims.data[i] = Y_dims[i]; \
} \
_BroadcastBinaryFunc<half, OutputT2, Functor<half>, CUDA_TENSOR_MAX_DIMS> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
Y_dims.size(), \
a_strides, \
b_strides, \
y_dims, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<OutputT2*>(y)); \
}
DEFINE_BINARY_FUNC(Add, float16, half, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, half, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, float16, half, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, float16, half, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, half, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, half, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, half, math::MaxFunctor);
DEFINE_BINARY_FUNC(Equal, bool, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, bool, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, bool, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, bool, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, bool, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, bool, bool, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_WHERE_FUNC(T1, T2) \
#define DEFINE_WHERE_FUNC(T, ScalarT) \
template <> \
DRAGON_API void Where<T1, CUDAContext>( \
DRAGON_API void Where<T, CUDAContext>( \
const int a_ndim, \
const int64_t* a_dims, \
const int b_ndim, \
const int64_t* b_dims, \
const int c_ndim, \
const int64_t* c_dims, \
const T1* a, \
const T1* b, \
const T* a, \
const T* b, \
const bool* c, \
T1* y, \
T* y, \
CUDAContext* ctx) { \
vec64_t A_dims(a_dims, a_dims + a_ndim); \
vec64_t B_dims(b_dims, b_dims + b_ndim); \
......@@ -597,10 +535,10 @@ DEFINE_BINARY_FUNC(GreaterEqual, bool, bool, math::GreaterEqualFunctor);
b_strides, \
c_strides, \
y_dims, \
reinterpret_cast<const T2*>(a), \
reinterpret_cast<const T2*>(b), \
reinterpret_cast<const ScalarT*>(a), \
reinterpret_cast<const ScalarT*>(b), \
reinterpret_cast<const uint8_t*>(c), \
reinterpret_cast<T2*>(y)); \
reinterpret_cast<ScalarT*>(y)); \
}
DEFINE_WHERE_FUNC(bool, uint8_t);
......
......@@ -24,21 +24,21 @@ void _Cast(const int n, const InputT* x, OutputT* y) {
#define DEFINE_CAST_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \
void Cast<InputT, OutputT, CPUContext>( \
DRAGON_API void Cast<InputT, OutputT, CPUContext>( \
const int n, const InputT* x, OutputT* y, CPUContext* ctx) { \
_Cast(n, x, y); \
}
#define DEFINE_UNSUPPORTED_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \
void Cast<InputT, OutputT, CPUContext>( \
DRAGON_API void Cast<InputT, OutputT, CPUContext>( \
const int n, const InputT* x, OutputT* y, CPUContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<InputT>()) << " -> " \
<< types::to_string(TypeMeta::Make<OutputT>()); \
} \
template <> \
void Cast<OutputT, InputT, CPUContext>( \
DRAGON_API void Cast<OutputT, InputT, CPUContext>( \
const int n, const OutputT* x, InputT* y, CPUContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<OutputT>()) << " -> " \
......
......@@ -23,7 +23,7 @@ __global__ void _Cast(const int nthreads, const InputT* x, OutputT* y) {
#define DEFINE_CAST_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \
void Cast<InputT, OutputT, CUDAContext>( \
DRAGON_API void Cast<InputT, OutputT, CUDAContext>( \
const int n, const InputT* x, OutputT* y, CUDAContext* ctx) { \
_Cast<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, \
......@@ -33,14 +33,14 @@ __global__ void _Cast(const int nthreads, const InputT* x, OutputT* y) {
#define DEFINE_UNSUPPORTED_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \
void Cast<InputT, OutputT, CUDAContext>( \
DRAGON_API void Cast<InputT, OutputT, CUDAContext>( \
const int n, const InputT* x, OutputT* y, CUDAContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<InputT>()) << " -> " \
<< types::to_string(TypeMeta::Make<OutputT>()); \
} \
template <> \
void Cast<OutputT, InputT, CUDAContext>( \
DRAGON_API void Cast<OutputT, InputT, CUDAContext>( \
const int n, const OutputT* x, InputT* y, CUDAContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<OutputT>()) << " -> " \
......
......@@ -599,14 +599,14 @@ DEFINE_POWX_FUNC(double);
#define DEFINE_NOT_ZERO_FUNC(T) \
template <> \
void NotZero<T, CUDAContext>( \
DRAGON_API void NotZero<T, CUDAContext>( \
const int count, const T* x, bool* y, CUDAContext* ctx) { \
_NotZero<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, x, y); \
}
template <>
void NotZero<float16, CUDAContext>(
DRAGON_API void NotZero<float16, CUDAContext>(
const int count,
const float16* x,
bool* y,
......@@ -742,7 +742,7 @@ DEFINE_BIAS_FUNC(float);
DEFINE_BIAS_FUNC(double);
#undef DEFINE_BIAS_FUNC
#define DEFINE_BINARY_FUNC(name, InputT, OutputT, Op) \
#define DEFINE_BINARY_FUNC(name, InputT, OutputT, Functor) \
template <> \
DRAGON_API void name<InputT, CUDAContext>( \
const int n, \
......@@ -754,94 +754,112 @@ DEFINE_BIAS_FUNC(double);
CUDA_BLOCKS(n), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(n, Op<InputT>(), a, b, y); \
ctx->cuda_stream()>>>( \
n, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
}
DEFINE_BINARY_FUNC(Add, int8_t, int8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int, int, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int64_t, int64_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float16, float16, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float, float, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, double, double, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int, int, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, float16, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float, float, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, double, double, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int, int, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float16, float16, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float, float, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, double, double, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, int8_t, int8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int, int, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int64_t, int64_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float16, float16, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float, float, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, float16, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, float16, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float, float, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, double, double, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int, int, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, float16, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float, float, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, double, double, math::MaxFunctor);
DEFINE_BINARY_FUNC(Equal, int8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, uint8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int64_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, float16, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, float, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, double, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, uint8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int64_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, float16, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, float, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, double, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, int8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, uint8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int64_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, float16, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, float, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, double, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, int8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, uint8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int64_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, float16, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, float, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, double, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, int8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, uint8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int64_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, float16, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, float, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, double, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, uint8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int64_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, float16, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, float, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, double, bool, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, T, dtype) \
#define DEFINE_BINARY_FUNC(name, T, ScalarT) \
template <> \
DRAGON_API void name<T, CUDAContext>( \
const int n, const T* a, const T* b, T* y, CUDAContext* ctx) { \
name( \
n, \
reinterpret_cast<const dtype*>(a), \
reinterpret_cast<const dtype*>(b), \
reinterpret_cast<dtype*>(y), \
reinterpret_cast<const ScalarT*>(a), \
reinterpret_cast<const ScalarT*>(b), \
reinterpret_cast<ScalarT*>(y), \
ctx); \
}
......@@ -850,76 +868,6 @@ DEFINE_BINARY_FUNC(Sub, bool, uint8_t); // Xor
DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, Functor) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int n, \
const float16* a, \
const float16* b, \
float16* y, \
CUDAContext* ctx) { \
if ((n & 1) == 0) { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n >> 1), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n >> 1, \
Functor<half2>(), \
reinterpret_cast<const half2*>(a), \
reinterpret_cast<const half2*>(b), \
reinterpret_cast<half2*>(y)); \
} else { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<half*>(y)); \
} \
}
DEFINE_BINARY_FUNC(Add, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, math::MaxFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, Functor) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int n, \
const float16* a, \
const float16* b, \
bool* y, \
CUDAContext* ctx) { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
y); \
}
DEFINE_BINARY_FUNC(Equal, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_WHERE_FUNC(T) \
template <> \
DRAGON_API void Where<T, CUDAContext>( \
......
......@@ -219,7 +219,7 @@ DEFINE_REDUCE_FUNC(Sum);
#define DEFINE_KERNEL_LAUNCHER(name) \
template <> \
void Reduce##name<float16, CPUContext>( \
DRAGON_API void Reduce##name<float16, CPUContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \
......@@ -258,7 +258,7 @@ DRAGON_API float16 Sum<float16, CPUContext>(
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void Reduce##name<T, CPUContext>( \
DRAGON_API void Reduce##name<T, CPUContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \
......@@ -298,7 +298,7 @@ DEFINE_KERNEL_LAUNCHER(Sum, double);
*y = val * T(scale); \
} \
template <> \
T Sum<T, CPUContext>( \
DRAGON_API T Sum<T, CPUContext>( \
const int n, const float scale, const T* x, CPUContext* ctx) { \
T val = ConstEigenVectorArrayMap<T>(x, n).sum(); \
return val * T(scale); \
......
......@@ -174,7 +174,7 @@ DEFINE_REDUCE_DISPATCHER(Sum);
// We found that FP16 accumulator drops too many small values in
// empirical experiments.
template <>
void ReduceSum<float16, CUDAContext>(
DRAGON_API void ReduceSum<float16, CUDAContext>(
const int num_dims,
const int* dims,
const int num_axes,
......@@ -199,7 +199,7 @@ void ReduceSum<float16, CUDAContext>(
#define DEFINE_KERNEL_LAUNCHER(name, T, AccT, Reducer, kInit) \
template <> \
void Reduce##name<T, CUDAContext>( \
DRAGON_API void Reduce##name<T, CUDAContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \
......
......@@ -174,7 +174,8 @@ void ComputeBinaryBroadcastDims(
const vec64_t& A_dims,
const vec64_t& B_dims,
vec64_t& A_broadcast_dims,
vec64_t& B_broadcast_dims) {
vec64_t& B_broadcast_dims,
int64_t* C_broadcast_dims) {
auto num_dims = std::max(A_dims.size(), B_dims.size());
A_broadcast_dims.resize(num_dims);
B_broadcast_dims.resize(num_dims);
......@@ -194,6 +195,16 @@ void ComputeBinaryBroadcastDims(
B_dims.begin(),
B_dims.end(),
B_broadcast_dims.begin() + num_dims - B_dims.size());
if (C_broadcast_dims != nullptr) {
for (int i = 0; i < num_dims; ++i) {
if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) {
C_broadcast_dims[i] = 0;
} else {
C_broadcast_dims[i] =
std::max(A_broadcast_dims[i], B_broadcast_dims[i]);
}
}
}
}
void ComputeBinaryBroadcastStrides(
......
......@@ -304,7 +304,8 @@ DRAGON_API void ComputeBinaryBroadcastDims(
const vec64_t& A_dims,
const vec64_t& B_dims,
vec64_t& A_broadcast_dims,
vec64_t& B_broadcast_dims);
vec64_t& B_broadcast_dims,
int64_t* C_broadcast_dims = nullptr);
DRAGON_API void ComputeBinaryBroadcastStrides(
const vec64_t& A_dims,
......@@ -326,22 +327,22 @@ DRAGON_API void TransposeAxesForReduce(
const int* reduce_axes,
int* transpose_axes);
template <typename dim_t, typename stride_t>
template <typename DimT, typename StrideT>
inline void
ComputeStrides(const int num_dims, const dim_t* dims, stride_t* strides) {
ComputeStrides(const int num_dims, const DimT* dims, StrideT* strides) {
int64_t cur_stride = 1;
for (int i = num_dims - 1; i >= 0; --i) {
strides[i] = stride_t(cur_stride);
strides[i] = StrideT(cur_stride);
cur_stride *= int64_t(dims[i]);
}
}
template <typename dim_t, typename axis_t, typename stride_t>
template <typename DimT, typename AxisT, typename StrideT>
inline void ComputeTransposeStrides(
const int num_dims,
const dim_t* dims,
const axis_t* axes,
stride_t* strides) {
const DimT* dims,
const AxisT* axes,
StrideT* strides) {
vec64_t buf(num_dims);
int64_t cur_stride = 1;
for (int i = num_dims - 1; i >= 0; --i) {
......@@ -349,13 +350,25 @@ inline void ComputeTransposeStrides(
cur_stride *= int64_t(dims[i]);
}
for (int i = 0; i < num_dims; ++i) {
strides[i] = stride_t(buf[axes[i]]);
strides[i] = StrideT(buf[axes[i]]);
}
}
template <typename dim_t, typename index_t>
template <typename DimT, typename IndexT>
inline IndexT
GetIndexFromDims(const int num_dims, const DimT* dims, IndexT* index) {
IndexT ret = 0;
for (int i = 0; i < num_dims; ++i) {
if (dims[i] > 1) {
ret = ret * dims[i] + index[i];
}
}
return ret;
}
template <typename DimT, typename IndexT>
inline void
IncreaseIndexInDims(const int num_dims, const dim_t* dims, index_t* index) {
IncreaseIndexInDims(const int num_dims, const DimT* dims, IndexT* index) {
for (int i = num_dims - 1; i >= 0; --i) {
++index[i];
if (index[i] >= dims[i]) {
......
......@@ -116,11 +116,9 @@ class Dense(Layer):
self.built = True
def call(self, inputs):
outputs = math_ops.fully_connected(
[inputs, self.kernel] + [self.bias]
if self.use_bias else [],
axis=-1,
transW=False,
outputs = math_ops.gemm(
[inputs, self.kernel] +
([self.bias] if self.use_bias else []),
)
if self.activation is not None:
return self.activation(outputs)
......
......@@ -703,38 +703,38 @@ def log(x, name=None):
return math_ops.log(x, name=name)
def matmul(
a,
b,
transpose_a=False,
transpose_b=False,
name=None,
):
def matmul(a, b, name=None):
r"""Compute the matrix multiplication.
.. math:: y = a \times b
.. math:: \text{out} = a \times b
The rank of ``a`` and ``b`` should be equal and >= 2:
The behavior depends on the shape of input tensors:
```python
# Ok, a typical matrix multiplication
a = tf.ones((2, 3), 'float32')
b = tf.ones((3, 3), 'float32')
print(tf.linalg.matmul(a, b))
* If both tensors are 1d, computes the vector product.
* If tensors are 1d and >=2d, computes the vector-matrix multiplication.
* If tensors are >=2d and 1d, computes the matrix-vector multiplication.
* If both tensors are >= 2d, computes the matrix-matrix multiplication.
* If one tensor is >= 3d, applies batching and broadcasting to the computation.
# Compute a batch matrix multiplication if rank > 2
aa = tf.ones((4, 2, 3), 'float32')
bb = tf.ones((4, 3, 3), 'float32')
print(tf.linalg.matmul(aa, bb))
```
If inputs are transposed, remember to transpose them back:
Examples:
```python
# Vector x Vector
a = tf.ones((2,), 'float32')
b = tf.ones((2,), 'float32')
print(tf.linalg.matmul(a, b))
# Vector x Matrix
a = tf.ones((2,), 'float32')
b = tf.ones((2, 3), 'float32')
print(tf.linalg.matmul(a, b))
# Matrix x Vector
a = tf.ones((3, 2), 'float32')
b = tf.ones((3, 3), 'float32')
print(tf.linalg.matmul(a, b)) # ``a`` takes the wrong dimensions
print(tf.linalg.matmul(a, b, transpose_a=True)) # Ok
b = tf.ones((2,), 'float32')
print(tf.linalg.matmul(a, b))
# Matrix x Matrix
a = tf.ones((2, 3), 'float32')
b = tf.ones((3, 2), 'float32')
print(tf.linalg.matmul(a, b))
```
Parameters
......@@ -743,10 +743,6 @@ def matmul(
The matrix :math:`a`.
b : dragon.Tensor
The matrix :math:`b`.
transpose_a : bool, optional, default=False
**True** to transpose :math:`a` before computing.
transpose_b : bool, optional, default=False
**True** to transpose :math:`b` before computing.
name : str, optional
The operation name.
......@@ -756,12 +752,7 @@ def matmul(
The output tensor.
"""
return math_ops.matmul(
[a, b],
transpose_a=transpose_a,
transpose_b=transpose_b,
name=name,
)
return math_ops.matmul([a, b], name=name)
def multiply(x, y, name=None):
......
......@@ -85,25 +85,25 @@ class Dense(layer.Layer):
raise AssertionError('The input dimension must be rank 2.'
'Please reshape or flatten it.')
if self.in_channels:
shape = [self.n_units, self.in_channels]
shape = [self.in_channels, self.n_units]
else:
self.in_channels = inputs_shape[1]
shape = [self.n_units, inputs_shape[1]]
shape = [inputs_shape[1], self.n_units]
self.W = self.add_weight(
name="weights",
name='weights',
shape=shape,
init=self.W_init,
)
if self.b_init:
self.b = self.add_weight(
name="biases",
name='biases',
shape=[self.n_units],
init=self.b_init,
)
def forward(self, inputs):
outputs = math_ops.fully_connected(
[inputs, self.W] + ([self.b] if self.b_init else []), axis=1)
outputs = math_ops.gemm(
[inputs, self.W] + ([self.b] if self.b_init else []))
if self.act:
outputs = self.act(outputs)
return outputs
......@@ -281,17 +281,15 @@ class TestOpSpec(unittest.TestCase):
self.assertEqual(dragon.flatten(
self.sym4, axis=1, num_axes=-1).shape, (1, None))
def test_fully_connected(self):
def test_gemm(self):
w = dragon.Tensor((3, 2))
with dragon.graph_mode():
self.assertEqual(dragon.nn.fully_connected(
[self.sym1, w]).shape, (None, 3))
self.assertEqual(dragon.nn.fully_connected(
[self.sym1, w], transpose_w=False).shape, (None, 2))
self.assertEqual(dragon.nn.fully_connected(
[self.sym1, w], axis=-1).shape, None)
self.assertEqual(dragon.nn.fully_connected(
[self.sym1, self.sym1]).shape, (None, None))
self.assertEqual(dragon.math.gemm(
[self.sym1, w]).shape, None)
self.assertEqual(dragon.math.gemm(
[self.sym1, w], axis=1).shape, (None, 2))
self.assertEqual(dragon.math.gemm(
[self.sym1, self.sym1]).shape, None)
def test_index_select(self):
with dragon.graph_mode():
......@@ -325,7 +323,9 @@ class TestOpSpec(unittest.TestCase):
self.assertEqual(dragon.math.matmul(
[self.sym1, self.sym3]).shape, None)
self.assertEqual(dragon.math.matmul(
[self.sym2, self.sym3]).shape, None)
[self.sym2, self.sym3]).shape, (None,))
self.assertEqual(dragon.math.matmul(
[self.sym3, self.sym2]).shape, (1,))
self.assertEqual(dragon.math.matmul(
[self.sym3, self.sym3]).shape, (1, None))
self.assertEqual(dragon.math.matmul(
......
......@@ -1868,22 +1868,22 @@ class TestMathOps(OpTestCase):
with dragon.device('cuda'):
self.test_floor()
def test_fully_connected(self):
def test_gemm(self):
entries = [((2, 3), (3, 4), (4,), False),
((2, 3), (4, 3), (4,), True)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for x_shape, w_shape, b_shape, trans_w in entries:
for x_shape, w_shape, b_shape, trans_b in entries:
data1, data2, data3 = arange(x_shape), arange(w_shape), arange(b_shape)
x, w, b = new_tensor(data1), new_tensor(data2), new_tensor(data3)
with dragon.GradientTape() as tape:
tape.watch([x, w, b])
y = dragon.nn.fully_connected([x, w, b], transpose_w=trans_w)
y = dragon.math.gemm([x, w, b], transpose_b=trans_b)
data4 = arange(y.shape)
dy = new_tensor(data4)
dx, dw, db = tape.gradient(y, [x, w, b], output_gradients=[dy])
result = np.matmul(data1, data2.T if trans_w else data2) + data3
if trans_w:
result = np.matmul(data1, data2.T if trans_b else data2) + data3
if trans_b:
grad1 = np.matmul(data4, data2)
grad2 = np.matmul(data4.T, data1)
else:
......@@ -1894,9 +1894,9 @@ class TestMathOps(OpTestCase):
[result, grad1, grad2, reduce_like(data4, data3)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_fully_connected_cuda(self):
def test_gemm_cuda(self):
with dragon.device('cuda'):
self.test_fully_connected()
self.test_gemm()
def test_greater(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
......@@ -1997,40 +1997,62 @@ class TestMathOps(OpTestCase):
self.test_log()
def test_matmul(self):
entries = [
((2, 3), (3, 4), False, False),
((2, 3), (4, 3), False, True),
((3, 2), (3, 4), True, False),
((3, 2), (4, 3), True, True)]
entries = [((2, 3), (3, 4)),
((1, 2, 3), (2, 3, 4)),
((2, 2, 3), (1, 3, 4)),
((2, 2, 3), (2, 3, 4)),
((2, 1, 2, 3), (2, 3, 4)),
((1, 2, 3), (2, 2, 3, 4)),
((2, 1, 2, 3), (1, 2, 3, 4))]
for execution in ('EAGER_MODE', 'GRAPH_MODE',):
with execution_context().mode(execution):
for a_shape, b_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape)
a, b = new_tensor(data1), new_tensor(data2)
with dragon.GradientTape() as tape:
tape.watch([a, b])
y = dragon.math.matmul([a, b])
data3 = arange(y.shape)
dy = new_tensor(data3)
da, db = tape.gradient(y, [a, b], output_gradients=[dy])
grad1 = np.matmul(data3, transpose_last(data2, 2))
grad2 = np.matmul(transpose_last(data1, 2), data3)
self.assertEqual(
[y, da, db],
[np.matmul(data1, data2),
reduce_like(grad1, data1),
reduce_like(grad2, data2)])
entries = [((2,), (2,), (2, 1), (2, 1), (1, 1)),
((2,), (2, 3), (2, 1), (2, 3), (1, 3)),
((2, 3), (3,), (2, 3), (1, 3), (2, 1)),
((2,), (4, 2, 3), (1, 2, 1), (4, 2, 3), (4, 1, 3)),
((4, 2, 3), (3,), (4, 2, 3), (1, 1, 3), (4, 2, 1))]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for a_shape, b_shape, trans_a, trans_b in entries:
for a_shape, b_shape, da_shape, db_shape, dy_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape)
data4 = data1 if len(a_shape) > len(b_shape) else data2
a, b = new_tensor(data1), new_tensor(data2)
with dragon.GradientTape() as tape:
tape.watch([a, b])
y = dragon.math.matmul([a, b], trans_a, trans_b)
y = dragon.math.matmul([a, b])
data3 = arange(y.shape)
dy = new_tensor(data3)
da, db = tape.gradient(y, [a, b], output_gradients=[dy])
if trans_a:
if trans_b:
grad1 = np.matmul(data2.T, data3.T)
grad2 = np.matmul(data3.T, data1.T)
else:
grad1 = np.matmul(data2, data3.T)
grad2 = np.matmul(data1, data3)
else:
if trans_b:
grad1 = np.matmul(data3, data2)
grad2 = np.matmul(data3.T, data1)
else:
grad1 = np.matmul(data3, data2.T)
grad2 = np.matmul(data1.T, data3)
grad1 = data3.reshape(dy_shape) * data2.reshape(db_shape)
grad2 = data1.reshape(da_shape) * data3.reshape(dy_shape)
grad1_axes, grad2_axes = [], []
for i in range(len(dy_shape)):
if da_shape[i] != db_shape[i]:
if da_shape[i] == 1:
grad1_axes.append(i)
if db_shape[i] == 1:
grad2_axes.append(i)
self.assertEqual(
[y, da, db],
[np.matmul(data1.T if trans_a else data1,
data2.T if trans_b else data2), grad1, grad2])
[np.matmul(data1, data2),
reduce(grad1, tuple(grad1_axes)).reshape(data1.shape),
reduce(grad2, tuple(grad2_axes)).reshape(data2.shape)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_matmul_cuda(self):
......@@ -4145,6 +4167,16 @@ def reduce_like(data, other, reduction='sum'):
return data
def transpose_last(data, num_axes=None, axes=None):
"""Transpose the last axes of data."""
if axes is None and num_axes is not None:
axes = list(range(num_axes))[::-1]
perm = list(range(len(data.shape)))
start_axis = len(perm) - len(axes)
perm[start_axis:] = [v + start_axis for v in axes]
return np.transpose(data, perm)
def uniform(shape, dtype='float32'):
"""Return the uniform data with given shape."""
return np.random.uniform(-1., 1., size=shape).astype(dtype)
......
......@@ -619,39 +619,54 @@ class TestModules(OpTestCase):
self.assertEqual(m4(x), np.pad(data, pads, 'constant'))
def test_pool1d(self):
entries = [((2, 2, 2,), (2,), 2, 1, 'MAX'),
((2, 2, 2,), (2,), 2, 1, 'AVG')]
entries = [((2, 2, 2,), (2,), 2, 1, 'MaxPool1d'),
((2, 2, 2,), (2,), 2, 1, 'AvgPool1d'),
((2, 2, 2,), (1,), 1, 0, 'AdaptiveMaxPool1d'),
((2, 2, 2,), (1,), 1, 0, 'AdaptiveAvgPool1d')]
for x_shape, kernel_shape, strides, pads, mode in entries:
data = arange(x_shape) * .1
module_cls = torch.nn.AvgPool1d if mode == 'AVG' else torch.nn.MaxPool1d
module_cls = getattr(torch.nn, mode)
x = new_tensor(data)
if 'Adaptive' in mode:
m = module_cls(x_shape[-1])
else:
m = module_cls(kernel_shape, strides, pads)
y, _ = m(x), repr(m)
result = data / (np.prod(kernel_shape) if mode == 'AVG' else 1.)
result = data / (np.prod(kernel_shape) if 'Avg' in mode else 1.)
self.assertEqual(y, result)
def test_pool2d(self):
entries = [((2, 2, 2, 2), (2, 2), 2, 1, 'MAX'),
((2, 2, 2, 2), (2, 2), 2, 1, 'AVG')]
entries = [((2, 2, 2, 2), (2, 2), 2, 1, 'MaxPool2d'),
((2, 2, 2, 2), (2, 2), 2, 1, 'AvgPool2d'),
((2, 2, 2, 2), (1, 1), 1, 0, 'AdaptiveMaxPool2d'),
((2, 2, 2, 2), (1, 1), 1, 0, 'AdaptiveAvgPool2d')]
for x_shape, kernel_shape, strides, pads, mode in entries:
data = arange(x_shape) * .1
module_cls = torch.nn.AvgPool2d if mode == 'AVG' else torch.nn.MaxPool2d
module_cls = getattr(torch.nn, mode)
x = new_tensor(data)
if 'Adaptive' in mode:
m = module_cls(x_shape[-1])
else:
m = module_cls(kernel_shape, strides, pads)
y, _ = m(x), repr(m)
result = data / (np.prod(kernel_shape) if mode == 'AVG' else 1.)
result = data / (np.prod(kernel_shape) if 'Avg' in mode else 1.)
self.assertEqual(y, result)
def test_pool3d(self):
entries = [((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'MAX'),
((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'AVG')]
entries = [((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'MaxPool3d'),
((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'AvgPool3d'),
((2, 2, 2, 2, 2), (1, 1, 1), 1, 0, 'AdaptiveMaxPool3d'),
((2, 2, 2, 2, 2), (1, 1, 1), 1, 0, 'AdaptiveAvgPool3d')]
for x_shape, kernel_shape, strides, pads, mode in entries:
data = arange(x_shape) * .1
module_cls = torch.nn.AvgPool3d if mode == 'AVG' else torch.nn.MaxPool3d
module_cls = getattr(torch.nn, mode)
x = new_tensor(data)
if 'Adaptive' in mode:
m = module_cls(x_shape[-1])
else:
m = module_cls(kernel_shape, strides, pads)
y, _ = m(x), repr(m)
result = data / (np.prod(kernel_shape) if mode == 'AVG' else 1.)
result = data / (np.prod(kernel_shape) if 'Avg' in mode else 1.)
self.assertEqual(y, result)
def test_prelu(self):
......
......@@ -95,6 +95,16 @@ class TestTensorOps(OpTestCase):
a += b
self.assertEqual(a, data1 + data2)
def test_addmm(self):
entries = [((2, 3), (3, 4), (2, 4))]
for a_shape, b_shape, c_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape)
data3 = arange(c_shape)
a, b = new_tensor(data1), new_tensor(data2)
c = new_tensor(data3)
y = c.addmm(a, b)
self.assertEqual(y, np.matmul(data1, data2) + data3)
def test_argmax(self):
entries = [(0, True), (0, False), (1, True), (1, False), (None, False)]
for axis, keepdims in entries:
......@@ -115,6 +125,18 @@ class TestTensorOps(OpTestCase):
result = np.expand_dims(result, axis)
self.assertEqual(x.argmin(axis, keepdims), result)
def test_baddbmm(self):
entries = [((2, 2, 3), (2, 3, 4), (2, 2, 4))]
for a_shape, b_shape, c_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape)
data3 = arange(c_shape)
a, b = new_tensor(data1), new_tensor(data2)
c = new_tensor(data3)
y = c.baddbmm(a, b)
self.assertEqual(y, np.matmul(data1, data2) + data3)
c.baddbmm_(a, b)
self.assertEqual(c, np.matmul(data1, data2) + data3)
def test_bitwise_not(self):
for shape in self.unary_test_shapes:
data = np.random.binomial(1, 0.5, shape).astype('bool')
......@@ -132,6 +154,18 @@ class TestTensorOps(OpTestCase):
a.bitwise_xor_(b)
self.assertEqual(a, np.bitwise_xor(data1, data2))
def test_bmm(self):
test_shapes = [((1, 2, 3), (2, 3, 4)),
((2, 2, 3), (1, 3, 4)),
((2, 2, 3), (2, 3, 4)),
((2, 1, 2, 3), (2, 3, 4)),
((1, 2, 3), (2, 2, 3, 4)),
((2, 1, 2, 3), (1, 2, 3, 4))]
for a_shape, b_shape in test_shapes:
data1, data2 = arange(a_shape), arange(b_shape, 1)
a, b = new_tensor(data1, False), new_tensor(data2, False)
self.assertEqual(a.bmm(b), np.matmul(data1, data2))
def test_ceil(self):
data = np.array([1.4, 1.7, 2.0])
x = new_tensor(data)
......@@ -334,6 +368,24 @@ class TestTensorOps(OpTestCase):
data[data > 2] = 0
self.assertEqual(x, data)
def test_matmul(self):
test_shapes = [((2,), (2,)),
((2,), (2, 3)),
((2, 3), (3,)),
((2, 3), (3, 4)),
((2,), (4, 2, 3)),
((4, 2, 3), (3,)),
((1, 2, 3), (2, 3, 4)),
((2, 2, 3), (1, 3, 4)),
((2, 2, 3), (2, 3, 4)),
((2, 1, 2, 3), (2, 3, 4)),
((1, 2, 3), (2, 2, 3, 4)),
((2, 1, 2, 3), (1, 2, 3, 4))]
for a_shape, b_shape in test_shapes:
data1, data2 = arange(a_shape), arange(b_shape, 1)
a, b = new_tensor(data1, False), new_tensor(data2, False)
self.assertEqual(a.matmul(b), np.matmul(data1, data2))
def test_max(self):
entries = [(0, True), (0, False),
(1, True), (1, False),
......@@ -382,20 +434,12 @@ class TestTensorOps(OpTestCase):
self.assertEqual(y, np.minimum(data1, data2))
def test_mm(self):
entries = [
((2, 3), (3, 4), False, False),
((2, 3), (4, 3), False, True),
((3, 2), (3, 4), True, False),
((3, 2), (4, 3), True, True)]
for a_shape, b_shape, trans_a, trans_b in entries:
entries = [((2, 3), (3, 4))]
for a_shape, b_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape)
a, b = new_tensor(data1), new_tensor(data2)
if trans_a or trans_b:
y = torch.mm(a, b, trans_a, trans_b)
else:
y = a.mm(b)
self.assertEqual(y, np.matmul(data1.T if trans_a else data1,
data2.T if trans_b else data2))
self.assertEqual(y, np.matmul(data1, data2))
def test_mul(self):
for a_shape, b_shape in self.binary_test_shapes:
......
......@@ -94,9 +94,12 @@ from dragon.vm.torch.core.ops.init.functional import zeros
from dragon.vm.torch.core.ops.init.functional import zeros_like
from dragon.vm.torch.core.ops.math.functional import abs
from dragon.vm.torch.core.ops.math.functional import add
from dragon.vm.torch.core.ops.math.functional import addmm
from dragon.vm.torch.core.ops.math.functional import axpby
from dragon.vm.torch.core.ops.math.functional import baddbmm
from dragon.vm.torch.core.ops.math.functional import bitwise_not
from dragon.vm.torch.core.ops.math.functional import bitwise_xor
from dragon.vm.torch.core.ops.math.functional import bmm
from dragon.vm.torch.core.ops.math.functional import ceil
from dragon.vm.torch.core.ops.math.functional import clamp
from dragon.vm.torch.core.ops.math.functional import cos
......@@ -112,6 +115,7 @@ from dragon.vm.torch.core.ops.math.functional import le
from dragon.vm.torch.core.ops.math.functional import log
from dragon.vm.torch.core.ops.math.functional import logsumexp
from dragon.vm.torch.core.ops.math.functional import lt
from dragon.vm.torch.core.ops.math.functional import matmul
from dragon.vm.torch.core.ops.math.functional import maximum
from dragon.vm.torch.core.ops.math.functional import minimum
from dragon.vm.torch.core.ops.math.functional import mm
......
......@@ -76,6 +76,12 @@ from dragon.vm.torch.core.nn.modules.padding import ReplicationPad1d
from dragon.vm.torch.core.nn.modules.padding import ReplicationPad2d
from dragon.vm.torch.core.nn.modules.padding import ReplicationPad3d
from dragon.vm.torch.core.nn.modules.padding import ZeroPad2d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool1d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool2d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool3d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveMaxPool1d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveMaxPool2d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveMaxPool3d
from dragon.vm.torch.core.nn.modules.pooling import AvgPool1d
from dragon.vm.torch.core.nn.modules.pooling import AvgPool2d
from dragon.vm.torch.core.nn.modules.pooling import AvgPool3d
......
......@@ -14,6 +14,12 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.vm.torch.core.nn.functional import adaptive_avg_pool1d
from dragon.vm.torch.core.nn.functional import adaptive_avg_pool2d
from dragon.vm.torch.core.nn.functional import adaptive_avg_pool3d
from dragon.vm.torch.core.nn.functional import adaptive_max_pool1d
from dragon.vm.torch.core.nn.functional import adaptive_max_pool2d
from dragon.vm.torch.core.nn.functional import adaptive_max_pool3d
from dragon.vm.torch.core.nn.functional import avg_pool1d
from dragon.vm.torch.core.nn.functional import avg_pool2d
from dragon.vm.torch.core.nn.functional import avg_pool3d
......
......@@ -76,7 +76,7 @@ class Function(object):
Parameters
----------
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......
......@@ -18,16 +18,166 @@ from dragon.core.util import nest
from dragon.vm.torch.core.nn.modules import _functions
from dragon.vm.torch.core.nn import _reduction
from dragon.vm.torch.core.nn.modules import utils
from dragon.vm.torch.core.ops.math import _functions as _math_functions
from dragon.vm.torch.core.ops.math import functional as math_funcs
def adaptive_avg_pool1d(input, output_size):
"""Apply the 1d adaptive average pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveAvgPool1d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-1:], utils._single(output_size))
return _pool(input, _pool_mode='AVG', _nd_util=utils._single, **kwargs)
def adaptive_avg_pool2d(input, output_size):
"""Apply the 2d adaptive average pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveAvgPool2d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-2:], utils._pair(output_size))
return _pool(input, _pool_mode='AVG', _nd_util=utils._pair, **kwargs)
def adaptive_avg_pool3d(input, output_size):
"""Apply the 3d adaptive average pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveAvgPool3d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-3:], utils._triple(output_size))
return _pool(input, _pool_mode='AVG', _nd_util=utils._triple, **kwargs)
def adaptive_max_pool1d(input, output_size):
"""Apply the 1d adaptive max pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveMaxPool1d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-1:], utils._single(output_size))
return _pool(input, _pool_mode='MAX', _nd_util=utils._single, **kwargs)
def adaptive_max_pool2d(input, output_size):
"""Apply the 2d adaptive max pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveMaxPool2d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-2:], utils._pair(output_size))
return _pool(input, _pool_mode='MAX', _nd_util=utils._pair, **kwargs)
def adaptive_max_pool3d(input, output_size):
"""Apply the 3d adaptive max pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveMaxPool3d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-3:], utils._triple(output_size))
return _pool(input, _pool_mode='MAX', _nd_util=utils._triple, **kwargs)
def avg_pool1d(
input,
kernel_size,
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
r"""Apply the 1d average pooling to input.
......@@ -36,15 +186,13 @@ def avg_pool1d(
input : dragon.vm.torch.Tensor
The input tensor.
kernel_size : Union[int, Sequence[int]]
The size of sliding window.
The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns
-------
......@@ -65,7 +213,6 @@ def avg_pool2d(
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
r"""Apply the 2d average pooling to input.
......@@ -74,15 +221,13 @@ def avg_pool2d(
input : dragon.vm.torch.Tensor
The input tensor.
kernel_size : Union[int, Sequence[int]]
The size of sliding window.
The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns
-------
......@@ -103,7 +248,6 @@ def avg_pool3d(
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
r"""Apply the 3d average pooling to input.
......@@ -112,15 +256,13 @@ def avg_pool3d(
input : dragon.vm.torch.Tensor
The input tensor.
kernel_size : Union[int, Sequence[int]]
The size of sliding window.
The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns
-------
......@@ -262,9 +404,9 @@ def conv1d(
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias tensor.
The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
dilation : Union[int, Sequence[int]], optional, default=1
......@@ -303,9 +445,9 @@ def conv2d(
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias tensor.
The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
dilation : Union[int, Sequence[int]], optional, default=1
......@@ -344,9 +486,9 @@ def conv3d(
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias tensor.
The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
dilation : Union[int, Sequence[int]], optional, default=1
......@@ -386,9 +528,9 @@ def conv_transpose1d(
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias.
The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1
The stride of slidaing window.
The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
output_padding : int, optional, default=1
......@@ -430,9 +572,9 @@ def conv_transpose2d(
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias.
The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
output_padding : int, optional, default=1
......@@ -474,9 +616,9 @@ def conv_transpose3d(
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias.
The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
output_padding : int, optional, default=1
......@@ -604,9 +746,9 @@ def depthwise_conv2d(
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias.
The bias tensor.
stride : Sequence[int], default=1
The stride of sliding window.
The stride of convolution window.
padding : Sequence[int], default=0
The zero padding size.
dilation : Sequence[int], default=1
......@@ -1093,7 +1235,7 @@ def leaky_relu(input, negative_slope=0.01, inplace=False):
def linear(input, weight, bias=None):
r"""Apply the linear transformation to input.
.. math:: y = Wx + b
.. math:: \text{out} = \text{input} \times \text{weight}^{T} + \text{bias}
Parameters
----------
......@@ -1102,7 +1244,7 @@ def linear(input, weight, bias=None):
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias.
The bias tensor.
Returns
-------
......@@ -1114,7 +1256,9 @@ def linear(input, weight, bias=None):
`torch.nn.Linear(...)`_
"""
return _functions.Linear.instantiate(input.device).apply(input, weight, bias)
return _math_functions.Gemm \
.instantiate(input.device, transB=True) \
.apply(input, weight, bias)
def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
......@@ -1217,7 +1361,6 @@ def max_pool1d(
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
r"""Apply the 1d max pooling to input.
......@@ -1226,15 +1369,13 @@ def max_pool1d(
input : dragon.vm.torch.Tensor
The input tensor.
kernel_size : Union[int, Sequence[int]]
The size of sliding window.
The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns
-------
......@@ -1255,7 +1396,6 @@ def max_pool2d(
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
r"""Apply the 2d max pooling to input.
......@@ -1264,15 +1404,13 @@ def max_pool2d(
input : dragon.vm.torch.Tensor
The input tensor.
kernel_size : Union[int, Sequence[int]]
The size of sliding window.
The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns
-------
......@@ -1293,7 +1431,6 @@ def max_pool3d(
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
r"""Apply the 3d max pooling to input.
......@@ -1302,15 +1439,13 @@ def max_pool3d(
input : dragon.vm.torch.Tensor
The input tensor.
kernel_size : Union[int, Sequence[int]]
The size of sliding window.
The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns
-------
......@@ -1442,7 +1577,7 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None):
eps : float, optional, default=1e-12
The value to :math:`\epsilon`.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -2095,6 +2230,7 @@ def _conv(
group=groups,
bias=bias is not None,
dtype=weight.dtype,
input_shape=input.shape,
).apply(input, weight, bias)
......@@ -2124,6 +2260,7 @@ def _conv_transpose(
output_padding=_nd_util(output_padding),
bias=bias is not None,
dtype=weight.dtype,
input_shape=input.shape,
).apply(input, weight, bias)
......@@ -2133,7 +2270,6 @@ def _pool(
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
_pool_mode='MAX',
_nd_util=utils._pair,
_pool_fn=_functions.Pool,
......@@ -2145,5 +2281,4 @@ def _pool(
pads=_nd_util(padding),
mode=_pool_mode,
ceil_mode=ceil_mode,
global_pool=global_pool,
).apply(input)
......@@ -86,7 +86,6 @@ class Pool(function.Function):
self.pads = kwargs.get('pads', 0)
self.ceil_mode = kwargs.get('ceil_mode', False)
self.mode = kwargs.get('mode', 'MAX')
self.global_pool = kwargs.get('global_pool', False)
def attributes(self):
return {
......@@ -98,7 +97,6 @@ class Pool(function.Function):
'ceil_mode': self.ceil_mode,
'mode': self.mode,
'data_format': 'NCHW',
'global_pool': self.global_pool,
}
}
......@@ -316,24 +314,6 @@ class L2Loss(Loss):
}
class Linear(function.Function):
"""Linear function."""
def attributes(self):
return {
'op_type': 'FullyConnected',
'arguments': {
'axis': -1,
'transW': True,
},
}
def forward(self, input, weight, bias=None, out=None):
inputs = [input, weight] + ([bias] if bias else [])
outputs = [out] if out else [self.alloc()]
return self.dispatch(inputs, outputs)
class LocalResponseNorm(function.Function):
"""LocalResponseNorm function."""
......
......@@ -48,7 +48,7 @@ class Identity(Module):
class Linear(Module):
r"""Apply the linear transformation.
.. math:: y = Wx + b
.. math:: \text{out} = \text{input} \times \text{weight}^{T} + \text{bias}
Examples:
......
......@@ -18,6 +18,17 @@ from dragon.vm.torch.core.nn import functional as F
from dragon.vm.torch.core.nn.modules.module import Module
class _AdaptivePoolNd(Module):
"""Apply the n-dimension adaptive pooling."""
def __init__(self, output_size):
super(_AdaptivePoolNd, self).__init__()
self.output_size = output_size
def extra_repr(self):
return 'output_size={}'.format(self.output_size)
class _PoolNd(Module):
"""Apply the n-dimension pooling."""
......@@ -27,24 +38,243 @@ class _PoolNd(Module):
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
super(_PoolNd, self).__init__()
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.ceil_mode = ceil_mode
self.global_pool = global_pool
def extra_repr(self):
return 'kernel_size={kernel_size}, ' \
'stride={stride}, ' \
'padding={padding}, ' \
'ceil_mode={ceil_mode}, ' \
'global_pool={global_pool}' \
'ceil_mode={ceil_mode}' \
.format(**self.__dict__)
class AdaptiveAvgPool1d(_AdaptivePoolNd):
r"""Apply the 1d adaptive average pooling.
This module excepts the input size :math:`(N, C, H)`,
and output size is :math:`(N, C, H_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`H` is the height of data.
Examples:
```python
m = torch.nn.AdaptiveAvgPool1d(1)
x = torch.ones(2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_avg_pool1d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveAvgPool1d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveAvgPool1d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_avg_pool1d(input, self.output_size)
class AdaptiveAvgPool2d(_AdaptivePoolNd):
r"""Apply the 2d adaptive average pooling.
This module excepts the input size :math:`(N, C, H, W)`,
and output size is :math:`(N, C, H_{\text{out}}, W_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`H` and :math:`W` are the height and width of data.
Examples:
```python
m = torch.nn.AdaptiveAvgPool2d(1)
x = torch.ones(2, 2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_avg_pool2d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveAvgPool2d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveAvgPool2d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_avg_pool2d(input, self.output_size)
class AdaptiveAvgPool3d(_AdaptivePoolNd):
r"""Apply the 3d adaptive average pooling.
This module excepts the input size :math:`(N, C, D, H, W)`,
and output size is :math:`(N, C, D_{\text{out}}, H_{\text{out}}, W_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`D`, :math:`H` and :math:`W` are the depth, height and width of data.
Examples:
```python
m = torch.nn.AdaptiveAvgPool3d(1)
x = torch.ones(2, 2, 2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_avg_pool3d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveAvgPool3d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveAvgPool3d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_avg_pool3d(input, self.output_size)
class AdaptiveMaxPool1d(_AdaptivePoolNd):
r"""Apply the 1d adaptive max pooling.
This module excepts the input size :math:`(N, C, H)`,
and output size is :math:`(N, C, H_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`H` is the height of data.
Examples:
```python
m = torch.nn.AdaptiveMaxPool1d(1)
x = torch.ones(2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_max_pool1d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveMaxPool1d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveMaxPool1d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_max_pool1d(input, self.output_size)
class AdaptiveMaxPool2d(_AdaptivePoolNd):
r"""Apply the 2d adaptive max pooling.
This module excepts the input size :math:`(N, C, H, W)`,
and output size is :math:`(N, C, H_{\text{out}}, W_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`H` and :math:`W` are the height and width of data.
Examples:
```python
m = torch.nn.AdaptiveMaxPool2d(1)
x = torch.ones(2, 2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_max_pool2d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveMaxPool2d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveMaxPool2d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_max_pool2d(input, self.output_size)
class AdaptiveMaxPool3d(_AdaptivePoolNd):
r"""Apply the 3d adaptive max pooling.
This module excepts the input size :math:`(N, C, D, H, W)`,
and output size is :math:`(N, C, D_{\text{out}}, H_{\text{out}}, W_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`D`, :math:`H` and :math:`W` are the depth, height and width of data.
Examples:
```python
m = torch.nn.AdaptiveMaxPool3d(1)
x = torch.ones(2, 2, 2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_max_pool3d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveMaxPool3d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveMaxPool3d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_max_pool3d(input, self.output_size)
class AvgPool1d(_PoolNd):
r"""Apply the 1d average pooling.
......@@ -73,7 +303,6 @@ class AvgPool1d(_PoolNd):
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
"""Create a ``AvgPool1d`` module.
......@@ -87,8 +316,6 @@ class AvgPool1d(_PoolNd):
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
"""
super(AvgPool1d, self).__init__(
......@@ -96,7 +323,6 @@ class AvgPool1d(_PoolNd):
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
global_pool=global_pool,
)
def forward(self, input):
......@@ -106,7 +332,6 @@ class AvgPool1d(_PoolNd):
stride=self.stride,
padding=self.padding,
ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
)
......@@ -138,7 +363,6 @@ class AvgPool2d(_PoolNd):
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
"""Create a ``AvgPool2d`` module.
......@@ -152,8 +376,6 @@ class AvgPool2d(_PoolNd):
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
"""
super(AvgPool2d, self).__init__(
......@@ -161,7 +383,6 @@ class AvgPool2d(_PoolNd):
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
global_pool=global_pool,
)
def forward(self, input):
......@@ -171,7 +392,6 @@ class AvgPool2d(_PoolNd):
stride=self.stride,
padding=self.padding,
ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
)
......@@ -203,7 +423,6 @@ class AvgPool3d(_PoolNd):
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
"""Create a ``AvgPool3d`` module.
......@@ -217,8 +436,6 @@ class AvgPool3d(_PoolNd):
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
"""
super(AvgPool3d, self).__init__(
......@@ -226,7 +443,6 @@ class AvgPool3d(_PoolNd):
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
global_pool=global_pool,
)
def forward(self, input):
......@@ -236,7 +452,6 @@ class AvgPool3d(_PoolNd):
stride=self.stride,
padding=self.padding,
ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
)
......@@ -268,7 +483,6 @@ class MaxPool1d(_PoolNd):
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
"""Create a ``MaxPool1d`` module.
......@@ -282,8 +496,6 @@ class MaxPool1d(_PoolNd):
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
"""
super(MaxPool1d, self).__init__(
......@@ -291,7 +503,6 @@ class MaxPool1d(_PoolNd):
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
global_pool=global_pool,
)
def forward(self, input):
......@@ -301,7 +512,6 @@ class MaxPool1d(_PoolNd):
stride=self.stride,
padding=self.padding,
ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
)
......@@ -333,7 +543,6 @@ class MaxPool2d(_PoolNd):
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
"""Create a ``MaxPool2d`` module.
......@@ -347,8 +556,6 @@ class MaxPool2d(_PoolNd):
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional
Apply the global pooling or not.
"""
super(MaxPool2d, self).__init__(
......@@ -356,7 +563,6 @@ class MaxPool2d(_PoolNd):
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
global_pool=global_pool,
)
def forward(self, input):
......@@ -366,7 +572,6 @@ class MaxPool2d(_PoolNd):
stride=self.stride,
padding=self.padding,
ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
)
......@@ -398,7 +603,6 @@ class MaxPool3d(_PoolNd):
stride=1,
padding=0,
ceil_mode=False,
global_pool=False,
):
"""Create a ``MaxPool3d`` module.
......@@ -412,8 +616,6 @@ class MaxPool3d(_PoolNd):
The zero padding size.
ceil_mode : bool, optional, default=False
Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
"""
super(MaxPool3d, self).__init__(
......@@ -421,7 +623,6 @@ class MaxPool3d(_PoolNd):
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
global_pool=global_pool,
)
def forward(self, input):
......@@ -431,5 +632,4 @@ class MaxPool3d(_PoolNd):
stride=self.stride,
padding=self.padding,
ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
)
......@@ -22,6 +22,18 @@ import itertools
from dragon.core.util import six
def _get_adaptive_pool_kwargs(input_sizes, output_sizes):
stride, kernel_size = [], []
for input_size, output_size in zip(input_sizes, output_sizes):
if output_size == 1:
stride.append(1)
kernel_size.append(input_size)
else:
stride.append(input_size // output_size)
kernel_size.append(input_size - (output_size - 1) * stride[-1])
return {'kernel_size': kernel_size, 'stride': stride}
def _ntuple(n):
def parse(x):
if isinstance(x, six.collections_abc.Sequence):
......
......@@ -315,12 +315,16 @@ class OneHot(function.Function):
def __init__(self, key, dev, **kwargs):
super(OneHot, self).__init__(key, dev, **kwargs)
self.depth = kwargs.get('depth', 1)
self.on_value = kwargs.get('on_value', 1)
self.off_value = kwargs.get('off_value', 0)
def attributes(self):
return {
'op_type': 'OneHot',
'arguments': {
'depth': self.depth,
'on_value': self.on_value,
'off_value': self.off_value,
},
}
......
......@@ -46,7 +46,7 @@ def argmax(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False
Keep the reduced dimension or not.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -81,7 +81,7 @@ def argmin(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False
Keep the reduced dimension or not.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -174,7 +174,7 @@ def cat(seq, dim=0, out=None):
dim : int, optional
The dim to concatenate.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -197,11 +197,11 @@ def channel_affine(input, weight, bias=None, dim=0, out=None):
weight : dragon.vm.torch.Tensor
The weight tensor.
bias : dragon.vm.torch.Tensor, optional
The optional bias.
The bias tensor.
dim : int, optional, default=0
The start dimension to transform.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -369,7 +369,7 @@ def cumsum(input, dim, out=None):
dim : int
The cumulative dimension.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -429,7 +429,7 @@ def flatten(input, start_dim=0, end_dim=-1, out=None):
end_dim : int, optional, default=-1
The end dimension to flatten.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -465,7 +465,7 @@ def index_select(input, dim, index, out=None):
index : dragon.vm.torch.Tensor
The index tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -523,7 +523,7 @@ def masked_select(input, mask, out=None):
mask : dragon.vm.torch.Tensor
The mask for selecting.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -566,7 +566,7 @@ def max(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False
Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -606,7 +606,7 @@ def mean(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False
Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -646,7 +646,7 @@ def min(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False
Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -721,7 +721,7 @@ def nonzero(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -732,7 +732,7 @@ def nonzero(input, out=None):
return _functions.NonZero.instantiate(input.device).apply(input, out)
def one_hot(input, depth):
def one_hot(input, depth, on_value=1, off_value=0):
r"""Return the one-hot representation for input.
.. math::
......@@ -748,6 +748,10 @@ def one_hot(input, depth):
The input tensor.
depth : int
The depth of channels.
on_value : int, optional, default=1
The value for equal branch.
off_value : int, optional, default=0
The value for not-equal branch.
Returns
-------
......@@ -755,7 +759,12 @@ def one_hot(input, depth):
The output tensor.
"""
return _functions.OneHot.instantiate(input.device, depth=depth).apply(input)
return _functions.OneHot.instantiate(
input.device,
depth=depth,
on_value=on_value,
off_value=off_value,
).apply(input)
def permute(input, dims):
......@@ -812,7 +821,7 @@ def reshape(input, shape, out=None):
shape : Sequence[int]
The new shape.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -986,7 +995,7 @@ def stack(seq, dim=0, out=None):
dim : int, optional, default=0
The dim to stack.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -1030,7 +1039,7 @@ def sum(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False
Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......
......@@ -59,7 +59,7 @@ def arange(
step : number, optional, default=1
The spacing between two elements.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
dtype : str, optional, default='int64'
The optional data type.
device : dragon.vm.torch.device, optional
......@@ -113,7 +113,7 @@ def eye(
m : int, optional
The number output cols.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
dtype : str, optional, default='float32'
The optional data type.
device : dragon.vm.torch.device, optional
......@@ -175,7 +175,7 @@ def full(
fill_value : number
The scalar to fill.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
dtype : str, optional, default='int64'
The optional data type.
device : dragon.vm.torch.device, optional
......@@ -216,7 +216,7 @@ def full_like(
fill_value : number
The scalar to fill.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
dtype : str, optional, default='int64'
The optional data type.
device : dragon.vm.torch.device, optional
......@@ -268,7 +268,7 @@ def linspace(
steps : int, optional, default=100
The number of values to generate.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
dtype : str, optional, default='int64'
The optional data type.
dim : int, optional, default=0
......@@ -326,7 +326,7 @@ def ones(*size, **kwargs):
size : int...
The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
dtype : str, optional, default='float32'
The optional data type.
device : dragon.vm.torch.device, optional
......@@ -378,7 +378,7 @@ def rand(*size, **kwargs):
size : int...
The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
dtype : str, optional, default='float32'
The optional data type.
device : dragon.vm.torch.device, optional
......@@ -404,7 +404,7 @@ def randn(*size, **kwargs):
size : int...
The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
dtype : str, optional, default='float32'
The optional data type.
device : dragon.vm.torch.device, optional
......@@ -436,7 +436,7 @@ def randperm(n, out=None, dtype='int64', device=None, requires_grad=False):
n: number
The end of interval.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
dtype : str, optional, default='int64'
The optional data type.
device : dragon.vm.torch.device, optional
......@@ -479,7 +479,7 @@ def zeros(*size, **kwargs):
size : int...
The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
dtype : str, optional, default='float32'
The optional data type.
device : dragon.vm.torch.device, optional
......
......@@ -77,36 +77,42 @@ class Clip(function.Function):
return self.dispatch([input], [self.alloc(out)])
class UnaryFunc(function.Function):
"""Unary function."""
class Gemm(function.Function):
"""Gemm function."""
def __init__(self, key, dev, **kwargs):
super(UnaryFunc, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
super(Gemm, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.0)
self.beta = kwargs.get('beta', 1.0)
self.transA = kwargs.get('transA', False)
self.transB = kwargs.get('transB', False)
def attributes(self):
return {'op_type': self.op_type, 'arguments': {}}
return {
'op_type': 'Gemm',
'arguments': {
'axis': -1,
'alpha': self.alpha,
'beta': self.beta,
'transA': self.transA,
'transB': self.transB,
},
}
def forward(self, input, out=None):
return self.dispatch([input], [self.alloc(out)])
def forward(self, mat1, mat2, mat3=None, out=None):
inputs = [mat1, mat2] + ([mat3] if mat3 else [])
return self.dispatch(inputs, [self.alloc(out)])
class MatMul(function.Function):
"""MatMul function."""
class UnaryFunc(function.Function):
"""Unary function."""
def __init__(self, key, dev, **kwargs):
super(MatMul, self).__init__(key, dev, **kwargs)
self.transpose_a = kwargs.get('transpose_a', False)
self.transpose_b = kwargs.get('transpose_b', False)
super(UnaryFunc, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '')
def attributes(self):
return {
'op_type': 'MatMul',
'arguments': {
'transA': self.transpose_a,
'transB': self.transpose_b,
},
}
return {'op_type': self.op_type, 'arguments': {}}
def forward(self, mat1, mat2, out=None):
return self.dispatch([mat1, mat2], [self.alloc(out)])
def forward(self, input, out=None):
return self.dispatch([input], [self.alloc(out)])
......@@ -34,7 +34,7 @@ def abs(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -59,7 +59,7 @@ def axpby(input, alpha=1., beta=1., out=None):
beta : float, optional, default=1.
The value to :math:`\beta`.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -87,7 +87,7 @@ def add(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The tensor to add.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -98,6 +98,74 @@ def add(input, other, out=None):
return _binary_func(input, other, 'Add', out)
def addmm(input, mat1, mat2, beta=1, alpha=1, out=None):
r"""Add input to the result of matrix-matrix multiplication.
.. math:: \text{out} = \alpha (\text{mat1} \times \text{mat2}) + \beta \text{input}
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
mat1 : dragon.vm.torch.Tensor
The first matrix.
mat2 : dragon.vm.torch.Tensor
The second matrix.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return _functions.Gemm \
.instantiate(
input.device,
alpha=float(alpha),
beta=float(beta),
).apply(mat1, mat2, input, out=out)
def baddbmm(input, batch1, batch2, beta=1, alpha=1, out=None):
r"""Add input to the result of batched matrix-matrix multiplication.
.. math::
\text{out}_{i} = \alpha (\text{mat1}_{i} \times \text{mat2}_{i}) +
\beta \text{input}_{i}
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
batch1 : dragon.vm.torch.Tensor
The first batch of matrices.
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
input1 = bmm(batch1, batch2)
input2 = input * beta if beta != 1 else input
input1 = input1 * alpha if alpha != 1 else input1
return add(input1, input2, out)
def bitwise_not(input, out=None):
r"""Compute the element-wise NOT bitwise operation.
......@@ -120,7 +188,7 @@ def bitwise_not(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -152,7 +220,7 @@ def bitwise_xor(input, other, out=None):
other : dragon.vm.torch.Tensor
The second input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -163,6 +231,31 @@ def bitwise_xor(input, other, out=None):
return _binary_func(input, other, 'Sub', out)
def bmm(input, mat2, out=None):
r"""Compute the batched matrix-matrix multiplication.
.. math:: \text{out}_{i} = \text{input}_{i} \times \text{mat2}_{i}
Parameters
----------
input : dragon.vm.torch.Tensor
The first batch of matrices.
mat2 : dragon.vm.torch.Tensor
The second batch of matrices.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return _functions.BinaryFunc \
.instantiate(input.device, op_type='MatMul') \
.apply(input, mat2, out=out)
def ceil(input, out=None):
r"""Compute the smallest integer not less than input.
......@@ -180,7 +273,7 @@ def ceil(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -205,7 +298,7 @@ def clamp(input, min=None, max=None, out=None):
max : number, optional
The max value.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -238,7 +331,7 @@ def cos(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -261,7 +354,7 @@ def div(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The tensor to divide.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -284,7 +377,7 @@ def eq(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -305,7 +398,7 @@ def exp(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -333,7 +426,7 @@ def floor(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -356,7 +449,7 @@ def ge(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -379,7 +472,7 @@ def gt(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -454,7 +547,7 @@ def le(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -475,7 +568,7 @@ def log(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -523,7 +616,7 @@ def lt(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -534,6 +627,60 @@ def lt(input, other, out=None):
return _binary_func(input, other, 'Less', out)
def matmul(input, other, out=None):
r"""Compute the matrix multiplication.
.. math:: \text{out} = \text{input} \times \text{other}
The behavior depends on the shape of input tensors:
* If both tensors are 1d, computes the vector product.
* If tensors are 1d and >=2d, computes the vector-matrix multiplication.
* If tensors are >=2d and 1d, computes the matrix-vector multiplication.
* If both tensors are >= 2d, computes the matrix-matrix multiplication.
* If one tensor is >= 3d, applies batching and broadcasting to the computation.
Examples:
```python
# Vector x Vector
a = torch.ones(2)
b = torch.ones(2)
print(torch.matmul(a, b))
# Vector x Matrix
a = torch.ones(2)
b = torch.ones(2, 3)
print(torch.matmul(a, b))
# Matrix x Vector
a = torch.ones(3, 2)
b = torch.ones(2)
print(torch.matmul(a, b))
# Matrix x Matrix
a = torch.ones(2, 3)
b = torch.ones(3, 2)
print(torch.matmul(a, b))
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
other : dragon.vm.torch.Tensor
The tensor to multiply.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.Tensor
The output tensor.
"""
return _functions.BinaryFunc \
.instantiate(input.device, op_type='MatMul') \
.apply(input, other, out=out)
def maximum(input, other, out=None):
r"""Compute the maximum value of inputs.
......@@ -546,7 +693,7 @@ def maximum(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The second input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -575,7 +722,7 @@ def minimum(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The second input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -586,13 +733,11 @@ def minimum(input, other, out=None):
input, other = utils \
.remove_binary_scalar(input, other)
return _functions.BinaryFunc \
.instantiate(
input.device,
op_type='Minimum',
).apply(input, other, out)
.instantiate(input.device, op_type='Minimum') \
.apply(input, other, out)
def mm(input, mat2, transpose_a=False, transpose_b=False, out=None):
def mm(input, mat2, out=None):
r"""Compute the matrix-matrix multiplication.
.. math:: \text{out} = \text{input} \times \text{mat2}
......@@ -603,12 +748,8 @@ def mm(input, mat2, transpose_a=False, transpose_b=False, out=None):
The first matrix.
mat2 : dragon.vm.torch.Tensor
The second matrix.
transpose_a : bool, optional, default=False
Transpose the first matrix before computation or not.
transpose_b : bool, optional, default=False
Transpose the second matrix before computation or not.
out : dragon.vm.torch.Tensor, optional
The optional output.
The output tensor.
Returns
-------
......@@ -616,12 +757,9 @@ def mm(input, mat2, transpose_a=False, transpose_b=False, out=None):
The output tensor.
"""
return _functions.MatMul \
.instantiate(
utils.unify_devices([input, mat2]),
transpose_a=transpose_a,
transpose_b=transpose_b,
).apply(input, mat2, out)
return _functions.Gemm \
.instantiate(input.device) \
.apply(input, mat2, out=out)
def mul(input, other, out=None):
......@@ -636,7 +774,7 @@ def mul(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The tensor to multiply.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -659,7 +797,7 @@ def ne(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -680,7 +818,7 @@ def neg(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -712,7 +850,7 @@ def pow(input, exponent, out=None):
exponent : Union[dragon.vm.torch.Tensor, number]
The exponent tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -740,7 +878,7 @@ def reciprocal(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -768,7 +906,7 @@ def round(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -796,7 +934,7 @@ def rsqrt(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -830,7 +968,7 @@ def sign(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -858,7 +996,7 @@ def sin(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -886,7 +1024,7 @@ def sqrt(input, out=None):
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......@@ -909,7 +1047,7 @@ def sub(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number]
The tensor to subtract.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The output tensor.
Returns
-------
......
......@@ -85,6 +85,35 @@ def add_(self, other):
return math_funcs.add(self, other, self)
def addmm(self, mat1, mat2, beta=1, alpha=1):
r"""Add the result of matrix-matrix multiplication.
.. math:: \text{out} = \alpha (\text{mat1} \times \text{mat2}) + \beta \text{self}
Parameters
----------
mat1 : dragon.vm.torch.Tensor
The first matrix.
mat2 : dragon.vm.torch.Tensor
The second matrix.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.addmm(...)`_
"""
return math_funcs.addmm(self, mat1, mat2, beta=beta, alpha=alpha)
def argmax(self, dim=None, keepdim=False):
"""Return the index of maximum elements.
......@@ -154,6 +183,71 @@ def argsort(self, dim=-1, descending=False):
return array_funcs.argsort(self, dim, descending)
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
r"""Add the result of batched matrix-matrix multiplication.
.. math::
\text{out}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
\beta \text{self}_{i}
Parameters
----------
batch1 : dragon.vm.torch.Tensor
The first batch of matrices.
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.baddbmm(...)`_
"""
return math_funcs.baddbmm(self, batch1, batch2, beta=beta, alpha=alpha)
def baddbmm_(self, batch1, batch2, beta=1, alpha=1):
r"""Add the result of batched matrix-matrix multiplication.
.. math::
\text{self}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
\beta \text{self}_{i}
Parameters
----------
batch1 : dragon.vm.torch.Tensor
The first batch of matrices.
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.baddbmm(...)`_
"""
return math_funcs.baddbmm(
self, batch1, batch2,
beta=beta, alpha=alpha, out=self,
)
def backward(self, gradient=None, retain_graph=False):
"""Compute the derivatives of this tensor w.r.t. graph leaves.
......@@ -254,6 +348,29 @@ def bitwise_xor_(self, other):
return math_funcs.bitwise_xor(self, other, self)
def bmm(self, batch2):
r"""Compute the batched matrix multiplication.
.. math:: \text{out}_{i} = \text{self}_{i} \times \text{batch2}_{i}
Parameters
----------
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`torch.bmm(...)`_
"""
return math_funcs.bmm(self, batch2)
def bool(self):
"""Return a bool tensor with the same data.
......@@ -719,50 +836,6 @@ def floor_(self):
return math_funcs.floor(self, self)
def new_full(
self,
size,
fill_value,
dtype=None,
device=None,
requires_grad=False,
):
"""Return a tensor filled with a scalar.
Refer this tensor if ``dtype`` and ``device`` not provided.
Parameters
----------
size : Sequence[int]
The size of output tensor.
fill_value : number
The scalar to fill.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.full(...)`_
"""
return init_funcs.full(
size,
fill_value,
dtype=self.dtype if dtype is None else dtype,
device=self.device if device is None else device,
requires_grad=requires_grad,
)
def ge(self, other):
r"""Compute the element-wise greater-equal comparison.
......@@ -1104,6 +1177,29 @@ def masked_select(self, mask):
return array_funcs.masked_select(self, mask)
def matmul(self, tensor2):
r"""Compute the matrix multiplication.
.. math:: \text{out} = \text{self} \times \text{tensor2}
Parameters
----------
tensor2 : dragon.vm.torch.Tensor
The tensor to multiply.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`torch.matmul(...)`_
"""
return math_funcs.matmul(self, tensor2)
def max(self, dim=None, keepdim=False):
"""Compute the max value of elements along the given dimension.
......@@ -1383,6 +1479,50 @@ def neg_(self):
return math_funcs.neg(self, self)
def new_full(
self,
size,
fill_value,
dtype=None,
device=None,
requires_grad=False,
):
"""Return a tensor filled with a scalar.
Refer this tensor if ``dtype`` and ``device`` not provided.
Parameters
----------
size : Sequence[int]
The size of output tensor.
fill_value : number
The scalar to fill.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.full(...)`_
"""
return init_funcs.full(
size,
fill_value,
dtype=self.dtype if dtype is None else dtype,
device=self.device if device is None else device,
requires_grad=requires_grad,
)
def nonzero(self):
r"""Return the index of non-zero elements.
......@@ -1735,7 +1875,7 @@ def sort(self, dim=-1, descending=False):
def split(self, split_size_or_sections, dim=0):
"""Return the splited chunks along the given dimension.
"""Return the split chunks along the given dimension.
Parameters
----------
......@@ -2132,14 +2272,18 @@ def _process_index(item):
Tensor.abs = abs
Tensor.add = add
Tensor.add_ = add_
Tensor.addmm = addmm
Tensor.argmax = argmax
Tensor.argmin = argmin
Tensor.argsort = argsort
Tensor.backward = backward
Tensor.baddbmm = baddbmm
Tensor.baddbmm_ = baddbmm_
Tensor.bitwise_not = bitwise_not
Tensor.bitwise_not_ = bitwise_not_
Tensor.bitwise_xor = bitwise_xor
Tensor.bitwise_xor_ = bitwise_xor_
Tensor.bmm = bmm
Tensor.bool = bool
Tensor.bool_ = bool_
Tensor.byte = byte
......@@ -2184,6 +2328,7 @@ Tensor.logsumexp = logsumexp
Tensor.lt = lt
Tensor.masked_fill_ = masked_fill_
Tensor.masked_select = masked_select
Tensor.matmul = matmul
Tensor.max = max
Tensor.maximum = maximum
Tensor.mean = mean
......
......@@ -270,6 +270,33 @@ class Tensor(object):
"""
def addmm(self, mat1, mat2, beta=1, alpha=1):
r"""Add the result of matrix-matrix multiplication.
.. math:: \text{out} = \alpha (\text{mat1} \times \text{mat2}) + \beta \text{self}
Parameters
----------
mat1 : dragon.vm.torch.Tensor
The first matrix.
mat2 : dragon.vm.torch.Tensor
The second matrix.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.addmm(...)`_
"""
def argmax(self, dim=None, keepdim=False):
"""Return the index of maximum elements.
......@@ -345,6 +372,64 @@ class Tensor(object):
"""
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
r"""Add the result of batched matrix-matrix multiplication.
.. math::
\text{out}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
\beta \text{self}_{i}
Parameters
----------
batch1 : dragon.vm.torch.Tensor
The first batch of matrices.
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.baddbmm(...)`_
"""
def baddbmm_(self, batch1, batch2, beta=1, alpha=1):
r"""Add the result of batched matrix-matrix multiplication.
.. math::
\text{self}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
\beta \text{self}_{i}
Parameters
----------
batch1 : dragon.vm.torch.Tensor
The first batch of matrices.
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.baddbmm(...)`_
"""
def bitwise_not(self):
r"""Compute the element-wise NOT bitwise operation.
......@@ -419,6 +504,27 @@ class Tensor(object):
"""
def bmm(self, batch2):
r"""Compute the batched matrix multiplication.
.. math:: \text{out}_{i} = \text{self}_{i} \times \text{batch2}_{i}
Parameters
----------
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`torch.bmm(...)`_
"""
def bool(self):
"""Return a bool tensor with the same data.
......@@ -1192,6 +1298,27 @@ class Tensor(object):
"""
def matmul(self, tensor2):
r"""Compute the matrix multiplication.
.. math:: \text{out} = \text{self} \times \text{tensor2}
Parameters
----------
tensor2 : dragon.vm.torch.Tensor
The tensor to multiply.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`torch.matmul(...)`_
"""
def max(self, dim=None, keepdim=False):
"""Compute the max value of elements along the given dimension.
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!