Commit 6bfe3e73 by Ting PAN

Reimplement the general matrix multiplication

Summary:
This commit generalizes the fully-connected operation into GEMM,
and enhances the matmul operation via batched Dot, GEMV and GEMM.
New representations and attributes have been consistent with ONNX.
1 parent 73ed1b96
Showing with 3106 additions and 1433 deletions
...@@ -313,8 +313,8 @@ class InnerProduct(Layer): ...@@ -313,8 +313,8 @@ class InnerProduct(Layer):
param = layer_param.inner_product_param param = layer_param.inner_product_param
self.arguments = { self.arguments = {
'axis': param.axis, 'axis': param.axis,
'out_channels': param.num_output, 'n': param.num_output,
'transpose_w': not param.transpose, 'transpose_b': not param.transpose,
} }
self.add_blob(filler=self.get_filler(param, 'weight_filler')) self.add_blob(filler=self.get_filler(param, 'weight_filler'))
if param.bias_term: if param.bias_term:
...@@ -322,7 +322,7 @@ class InnerProduct(Layer): ...@@ -322,7 +322,7 @@ class InnerProduct(Layer):
def __call__(self, bottom): def __call__(self, bottom):
inputs = [bottom] + [blob['data'] for blob in self._blobs] inputs = [bottom] + [blob['data'] for blob in self._blobs]
return math_ops.fully_connected(inputs, **self.arguments) return math_ops.gemm(inputs, **self.arguments)
class Input(Layer): class Input(Layer):
...@@ -409,7 +409,7 @@ class Normalize(Layer): ...@@ -409,7 +409,7 @@ class Normalize(Layer):
def __call__(self, bottom): def __call__(self, bottom):
norm_out = [normalization_ops.lp_normalize(bottom, **self.l2norm_arguments)] norm_out = [normalization_ops.lp_normalize(bottom, **self.l2norm_arguments)]
norm_out += [blob['data'] for blob in self._blobs] norm_out += [blob['data'] for blob in self._blobs]
return math_ops.affine(norm_out, **self.affine_arguments) return array_ops.channel_affine(norm_out, **self.affine_arguments)
class Permute(Layer): class Permute(Layer):
...@@ -583,7 +583,7 @@ class Scale(Layer): ...@@ -583,7 +583,7 @@ class Scale(Layer):
def __call__(self, bottom): def __call__(self, bottom):
inputs = [bottom] + [blob['data'] for blob in self._blobs] inputs = [bottom] + [blob['data'] for blob in self._blobs]
return math_ops.affine(inputs, **self.arguments) return array_ops.channel_affine(inputs, **self.arguments)
class Slice(Layer): class Slice(Layer):
......
...@@ -48,6 +48,9 @@ dragon.math ...@@ -48,6 +48,9 @@ dragon.math
`floor(...) <math/floor.html>`_ `floor(...) <math/floor.html>`_
: Compute the largest integer not greater than input. : Compute the largest integer not greater than input.
`gemm(...) <math/gemm.html>`_
: Compute the general matrix multiplication.
`greater(...) <math/greater.html>`_ `greater(...) <math/greater.html>`_
: Compute the element-wise greater comparison. : Compute the element-wise greater comparison.
...@@ -158,6 +161,7 @@ dragon.math ...@@ -158,6 +161,7 @@ dragon.math
math/equal math/equal
math/exp math/exp
math/floor math/floor
math/gemm
math/greater math/greater
math/greater_equal math/greater_equal
math/is_inf math/is_inf
......
fully_connected gemm
=============== ====
.. autofunction:: dragon.nn.fully_connected .. autofunction:: dragon.math.gemm
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.nn."; content: "dragon.math.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -74,9 +74,6 @@ dragon.nn ...@@ -74,9 +74,6 @@ dragon.nn
: Apply the exponential linear unit. : Apply the exponential linear unit.
`[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_. `[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
`fully_connected(...) <nn/fully_connected.html>`_
: Compute the dense matrix multiplication along the given axes.
`group_norm(...) <nn/group_norm.html>`_ `group_norm(...) <nn/group_norm.html>`_
: Apply the group normalization. : Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_. `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
...@@ -167,7 +164,6 @@ dragon.nn ...@@ -167,7 +164,6 @@ dragon.nn
nn/drop_block2d nn/drop_block2d
nn/drop_path nn/drop_path
nn/elu nn/elu
nn/fully_connected
nn/group_norm nn/group_norm
nn/hardsigmoid nn/hardsigmoid
nn/hardswish nn/hardswish
......
...@@ -79,7 +79,7 @@ Name Supported Reference ...@@ -79,7 +79,7 @@ Name Supported Reference
`Gather`_ |v| :func:`dragon.index_select` `Gather`_ |v| :func:`dragon.index_select`
`GatherElements`_ `GatherElements`_
`GatherND`_ `GatherND`_
`Gemm`_ |v| :func:`dragon.nn.fully_connected` `Gemm`_ |v| :func:`dragon.math.gemm`
`GlobalAveragePool`_ |v| :func:`dragon.nn.pool2d` `GlobalAveragePool`_ |v| :func:`dragon.nn.pool2d`
`GlobalLpPool`_ `GlobalLpPool`_
`GlobalMaxPool`_ |v| :func:`dragon.nn.pool2d` `GlobalMaxPool`_ |v| :func:`dragon.nn.pool2d`
......
...@@ -36,6 +36,9 @@ vm.torch ...@@ -36,6 +36,9 @@ vm.torch
`add(...) <torch/add.html>`_ `add(...) <torch/add.html>`_
: Compute the element-wise addition. : Compute the element-wise addition.
`addmm(...) <torch/addmm.html>`_
: Add input to the result of matrix-matrix multiplication.
`arange(...) <torch/arange.html>`_ `arange(...) <torch/arange.html>`_
: Return a tensor of evenly spaced values within a interval. : Return a tensor of evenly spaced values within a interval.
...@@ -51,12 +54,18 @@ vm.torch ...@@ -51,12 +54,18 @@ vm.torch
`axpby(...) <torch/axpby.html>`_ `axpby(...) <torch/axpby.html>`_
: Compute the element-wise addition from input to output. : Compute the element-wise addition from input to output.
`baddbmm(...) <torch/baddbmm.html>`_
: Add input to the result of batched matrix-matrix multiplication.
`bitwise_not(...) <torch/bitwise_not.html>`_ `bitwise_not(...) <torch/bitwise_not.html>`_
: Compute the element-wise NOT bitwise operation. : Compute the element-wise NOT bitwise operation.
`bitwise_xor(...) <torch/bitwise_xor.html>`_ `bitwise_xor(...) <torch/bitwise_xor.html>`_
: Compute the element-wise XOR bitwise operation. : Compute the element-wise XOR bitwise operation.
`bmm(...) <torch/bmm.html>`_
: Compute the batched matrix-matrix multiplication.
`cat(...) <torch/cat.html>`_ `cat(...) <torch/cat.html>`_
: Concatenate the inputs along the given dimension. : Concatenate the inputs along the given dimension.
...@@ -148,6 +157,9 @@ vm.torch ...@@ -148,6 +157,9 @@ vm.torch
`masked_select(...) <torch/logsumexp.html>`_ `masked_select(...) <torch/logsumexp.html>`_
: Select the input elements where mask is 1. : Select the input elements where mask is 1.
`matmul(...) <torch/matmul.html>`_
: Compute the matrix multiplication.
`max(...) <torch/max.html>`_ `max(...) <torch/max.html>`_
: Compute the max value of elements along the given dimension. : Compute the max value of elements along the given dimension.
...@@ -281,13 +293,16 @@ vm.torch ...@@ -281,13 +293,16 @@ vm.torch
torch/Tensor_ torch/Tensor_
torch/abs torch/abs
torch/add torch/add
torch/addmm
torch/arange torch/arange
torch/argmax torch/argmax
torch/argmin torch/argmin
torch/argsort torch/argsort
torch/axpby torch/axpby
torch/baddbmm
torch/bitwise_not torch/bitwise_not
torch/bitwise_xor torch/bitwise_xor
torch/bmm
torch/cat torch/cat
torch/ceil torch/ceil
torch/channel_affine torch/channel_affine
...@@ -321,6 +336,7 @@ vm.torch ...@@ -321,6 +336,7 @@ vm.torch
torch/logsumexp torch/logsumexp
torch/lt torch/lt
torch/masked_select torch/masked_select
torch/matmul
torch/max torch/max
torch/maximum torch/maximum
torch/mean torch/mean
......
...@@ -53,6 +53,10 @@ add\_ ...@@ -53,6 +53,10 @@ add\_
##### #####
.. automethod:: dragon.vm.torch.Tensor.add_ .. automethod:: dragon.vm.torch.Tensor.add_
addmm
#####
.. automethod:: dragon.vm.torch.Tensor.addmm
argmax argmax
###### ######
.. automethod:: dragon.vm.torch.Tensor.argmax .. automethod:: dragon.vm.torch.Tensor.argmax
...@@ -69,6 +73,14 @@ backward ...@@ -69,6 +73,14 @@ backward
######## ########
.. automethod:: dragon.vm.torch.Tensor.backward .. automethod:: dragon.vm.torch.Tensor.backward
baddbmm
#######
.. automethod:: dragon.vm.torch.Tensor.baddbmm
baddbmm\_
#########
.. automethod:: dragon.vm.torch.Tensor.baddbmm_
bitwise_not bitwise_not
########### ###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_not .. automethod:: dragon.vm.torch.Tensor.bitwise_not
...@@ -85,6 +97,10 @@ bitwise_xor\_ ...@@ -85,6 +97,10 @@ bitwise_xor\_
############# #############
.. automethod:: dragon.vm.torch.Tensor.bitwise_xor_ .. automethod:: dragon.vm.torch.Tensor.bitwise_xor_
bmm
###
.. automethod:: dragon.vm.torch.Tensor.bmm
bool bool
#### ####
.. automethod:: dragon.vm.torch.Tensor.bool .. automethod:: dragon.vm.torch.Tensor.bool
...@@ -285,6 +301,14 @@ masked_fill\_ ...@@ -285,6 +301,14 @@ masked_fill\_
############# #############
.. automethod:: dragon.vm.torch.Tensor.masked_fill_ .. automethod:: dragon.vm.torch.Tensor.masked_fill_
masked_select
#############
.. automethod:: dragon.vm.torch.Tensor.masked_select
matmul
######
.. automethod:: dragon.vm.torch.Tensor.matmul
max max
### ###
.. automethod:: dragon.vm.torch.Tensor.max .. automethod:: dragon.vm.torch.Tensor.max
...@@ -293,10 +317,6 @@ maximum ...@@ -293,10 +317,6 @@ maximum
####### #######
.. automethod:: dragon.vm.torch.Tensor.maximum .. automethod:: dragon.vm.torch.Tensor.maximum
masked_select
#############
.. automethod:: dragon.vm.torch.Tensor.masked_select
mean mean
#### ####
.. automethod:: dragon.vm.torch.Tensor.mean .. automethod:: dragon.vm.torch.Tensor.mean
...@@ -535,11 +555,14 @@ zero\_ ...@@ -535,11 +555,14 @@ zero\_
.. _torch.abs(...): abs.html .. _torch.abs(...): abs.html
.. _torch.add(...): add.html .. _torch.add(...): add.html
.. _torch.addmm(...): addmm.html
.. _torch.argmax(...): argmax.html .. _torch.argmax(...): argmax.html
.. _torch.argmin(...): argmin.html .. _torch.argmin(...): argmin.html
.. _torch.argsort(...): argsort.html .. _torch.argsort(...): argsort.html
.. _torch.baddbmm(...): baddbmm.html
.. _torch.bitwise_not(...): bitwise_not.html .. _torch.bitwise_not(...): bitwise_not.html
.. _torch.bitwise_xor(...): bitwise_xor.html .. _torch.bitwise_xor(...): bitwise_xor.html
.. _torch.bmm(...): bmm.html
.. _torch.ceil(...): ceil.html .. _torch.ceil(...): ceil.html
.. _torch.clamp(...): clamp.html .. _torch.clamp(...): clamp.html
.. _torch.cos(...): cos.html .. _torch.cos(...): cos.html
...@@ -557,6 +580,7 @@ zero\_ ...@@ -557,6 +580,7 @@ zero\_
.. _torch.isnan(...): isnan.html .. _torch.isnan(...): isnan.html
.. _torch.le(...): le.html .. _torch.le(...): le.html
.. _torch.lt(...): lt.html .. _torch.lt(...): lt.html
.. _torch.matmul(...): matmul.html
.. _torch.max(...): max.html .. _torch.max(...): max.html
.. _torch.maximum(...): maximum.html .. _torch.maximum(...): maximum.html
.. _torch.min(...): min.html .. _torch.min(...): min.html
......
addmm
=====
.. autofunction:: dragon.vm.torch.addmm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
baddbmm
=======
.. autofunction:: dragon.vm.torch.baddbmm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
bmm
===
.. autofunction:: dragon.vm.torch.bmm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
matmul
======
.. autofunction:: dragon.vm.torch.matmul
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
...@@ -6,6 +6,24 @@ vm.torch.nn ...@@ -6,6 +6,24 @@ vm.torch.nn
Classes Classes
------- -------
`class AdaptiveAvgPool1d <nn/AdaptiveAvgPool1d.html>`_
: Apply the 1d adaptive average pooling.
`class AdaptiveAvgPool2d <nn/AdaptiveAvgPool2d.html>`_
: Apply the 2d adaptive average pooling.
`class AdaptiveAvgPool3d <nn/AdaptiveAvgPool3d.html>`_
: Apply the 3d adaptive average pooling.
`class AdaptiveMaxPool1d <nn/AdaptiveMaxPool1d.html>`_
: Apply the 1d adaptive max pooling.
`class AdaptiveMaxPool2d <nn/AdaptiveMaxPool2d.html>`_
: Apply the 2d adaptive max pooling.
`class AdaptiveMaxPool3d <nn/AdaptiveMaxPool3d.html>`_
: Apply the 3d adaptive max pooling.
`class AffineChannel <nn/AffineChannel.html>`_ `class AffineChannel <nn/AffineChannel.html>`_
: Apply affine transformation along the channels. : Apply affine transformation along the channels.
...@@ -238,6 +256,12 @@ vm.torch.nn ...@@ -238,6 +256,12 @@ vm.torch.nn
.. toctree:: .. toctree::
:hidden: :hidden:
nn/AdaptiveAvgPool1d
nn/AdaptiveAvgPool2d
nn/AdaptiveAvgPool3d
nn/AdaptiveMaxPool1d
nn/AdaptiveMaxPool2d
nn/AdaptiveMaxPool3d
nn/AffineChannel nn/AffineChannel
nn/AvgPool1d nn/AvgPool1d
nn/AvgPool2d nn/AvgPool2d
......
AdaptiveAvgPool1d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveAvgPool1d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveAvgPool1d.__init__
.. _torch.nn.functional.adaptive_avg_pool1d(...): functional/adaptive_avg_pool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveAvgPool2d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveAvgPool2d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveAvgPool2d.__init__
.. _torch.nn.functional.adaptive_avg_pool2d(...): functional/adaptive_avg_pool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveAvgPool3d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveAvgPool3d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveAvgPool3d.__init__
.. _torch.nn.functional.adaptive_avg_pool3d(...): functional/adaptive_avg_pool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveMaxPool1d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveMaxPool1d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveMaxPool1d.__init__
.. _torch.nn.functional.adaptive_max_pool1d(...): functional/adaptive_max_pool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveMaxPool2d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveMaxPool2d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveMaxPool2d.__init__
.. _torch.nn.functional.adaptive_max_pool2d(...): functional/adaptive_max_pool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveMaxPool3d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveMaxPool3d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveMaxPool3d.__init__
.. _torch.nn.functional.adaptive_max_pool3d(...): functional/adaptive_max_pool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
...@@ -6,6 +6,24 @@ vm.torch.nn.functional ...@@ -6,6 +6,24 @@ vm.torch.nn.functional
Functions Functions
--------- ---------
`adaptive_avg_pool1d(...) <functional/adaptive_avg_pool1d.html>`_
: Apply the 1d adaptive average pooling to input.
`adaptive_avg_pool2d(...) <functional/adaptive_avg_pool2d.html>`_
: Apply the 2d adaptive average pooling to input.
`adaptive_avg_pool3d(...) <functional/adaptive_avg_pool3d.html>`_
: Apply the 3d adaptive average pooling to input.
`adaptive_max_pool1d(...) <functional/adaptive_max_pool1d.html>`_
: Apply the 1d adaptive max pooling to input.
`adaptive_max_pool2d(...) <functional/adaptive_max_pool2d.html>`_
: Apply the 2d adaptive max pooling to input.
`adaptive_max_pool3d(...) <functional/adaptive_max_pool3d.html>`_
: Apply the 3d adaptive max pooling to input.
`avg_pool1d(...) <functional/avg_pool1d.html>`_ `avg_pool1d(...) <functional/avg_pool1d.html>`_
: Apply the 1d average pooling to input. : Apply the 1d average pooling to input.
...@@ -167,6 +185,12 @@ vm.torch.nn.functional ...@@ -167,6 +185,12 @@ vm.torch.nn.functional
.. toctree:: .. toctree::
:hidden: :hidden:
functional/adaptive_avg_pool1d
functional/adaptive_avg_pool2d
functional/adaptive_avg_pool3d
functional/adaptive_max_pool1d
functional/adaptive_max_pool2d
functional/adaptive_max_pool3d
functional/avg_pool1d functional/avg_pool1d
functional/avg_pool2d functional/avg_pool2d
functional/avg_pool3d functional/avg_pool3d
......
adaptive_avg_pool1d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_avg_pool1d
.. _torch.nn.AdaptiveAvgPool1d(...): ../AdaptiveAvgPool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_avg_pool2d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_avg_pool2d
.. _torch.nn.AdaptiveAvgPool2d(...): ../AdaptiveAvgPool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_avg_pool3d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_avg_pool3d
.. _torch.nn.AdaptiveAvgPool3d(...): ../AdaptiveAvgPool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_max_pool1d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_max_pool1d
.. _torch.nn.AdaptiveMaxPool1d(...): ../AdaptiveMaxPool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_max_pool2d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_max_pool2d
.. _torch.nn.AdaptiveMaxPool2d(...): ../AdaptiveMaxPool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_max_pool3d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_max_pool3d
.. _torch.nn.AdaptiveMaxPool3d(...): ../AdaptiveMaxPool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
...@@ -185,7 +185,6 @@ const Map<string, Map<string, string>>& ONNXBackend::get_node_renamed_attrs() ...@@ -185,7 +185,6 @@ const Map<string, Map<string, string>>& ONNXBackend::get_node_renamed_attrs()
const { const {
const static Map<string, Map<string, string>> kPerNodeRenamedAttrs = { const static Map<string, Map<string, string>> kPerNodeRenamedAttrs = {
{"DepthToSpace", {{"blocksize", "block_size"}}}, {"DepthToSpace", {{"blocksize", "block_size"}}},
{"Gemm", {{"transB", "transW"}}},
{"RoiAlign", {"RoiAlign",
{ {
{"output_height", "pooled_h"}, {"output_height", "pooled_h"},
......
...@@ -180,19 +180,7 @@ ONNXImporterReturns ONNXBackend::GemmImporter( ...@@ -180,19 +180,7 @@ ONNXImporterReturns ONNXBackend::GemmImporter(
ONNXNode* onnx_node, ONNXNode* onnx_node,
const ConversionContext& ctx) { const ConversionContext& ctx) {
auto& attributes = onnx_node->attributes; auto& attributes = onnx_node->attributes;
auto alpha = attributes.get<float>("alpha", 1.f); attributes.AddRewrittenAttribute("axis")->set_i(-1);
auto beta = attributes.get<float>("beta", 1.f);
auto trans_a = attributes.get<int64_t>("transA", 0L);
// Remove the unsupported attributes
if (alpha != 1.f || beta != 1.f) {
LOG(FATAL) << "alpha/beta can not be set currently.";
}
if (trans_a) {
LOG(FATAL) << "Tranposed A is not supported currently.";
}
attributes.remove("alpha");
attributes.remove("beta");
attributes.remove("transA");
return GenericImporter(onnx_node, ctx); return GenericImporter(onnx_node, ctx);
} }
......
...@@ -98,7 +98,6 @@ DECLARE_ELEMENTWISE_OP(SignGradient); ...@@ -98,7 +98,6 @@ DECLARE_ELEMENTWISE_OP(SignGradient);
DECLARE_ELEMENTWISE_OP(SinGradient); DECLARE_ELEMENTWISE_OP(SinGradient);
DECLARE_ELEMENTWISE_OP(SqrtGradient); DECLARE_ELEMENTWISE_OP(SqrtGradient);
DECLARE_ELEMENTWISE_OP(SquareGradient); DECLARE_ELEMENTWISE_OP(SquareGradient);
// Binary ElementwiseOp // Binary ElementwiseOp
DECLARE_ELEMENTWISE_OP(Add); DECLARE_ELEMENTWISE_OP(Add);
DECLARE_ELEMENTWISE_OP(Sub); DECLARE_ELEMENTWISE_OP(Sub);
...@@ -122,7 +121,6 @@ DECLARE_ELEMENTWISE_OP(PowGradient); ...@@ -122,7 +121,6 @@ DECLARE_ELEMENTWISE_OP(PowGradient);
DECLARE_ELEMENTWISE_OP(DotGradient); DECLARE_ELEMENTWISE_OP(DotGradient);
DECLARE_ELEMENTWISE_OP(MinimumGradient); DECLARE_ELEMENTWISE_OP(MinimumGradient);
DECLARE_ELEMENTWISE_OP(MaximumGradient); DECLARE_ELEMENTWISE_OP(MaximumGradient);
#undef DECLARE_ELEMENTWISE_OP #undef DECLARE_ELEMENTWISE_OP
} // namespace dragon } // namespace dragon
......
#include "dragon/operators/math/fully_connected_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/filler.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void FullyConnectedOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), *Y = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(X);
// Determine the number of output channels
int64_t M = X.count(0, axis), K = X.count(axis), N;
if (out_channels_ <= 0) {
// Infer the "N" from the weights shape
N = W.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer the N from "
<< "the weights shape: " << W.DimString();
} else {
// Use a fixed "N" from the argument
N = out_channels_;
}
vec64_t Y_dims(axis + 1);
for (int i = 0; i < axis + 1; i++) {
Y_dims[i] = i < axis ? X.dim(i) : N;
}
if (transW_ > 0) {
TENSOR_FILL(W, vec64_t({N, K}));
CHECK(W.ndim() == 2 && W.dim(1) == K)
<< "\nWeights dimensions should be [N, K].\n"
<< "Got X as (" << M << ", " << K << "), "
<< "and W as " << W.DimString();
} else {
TENSOR_FILL(W, vec64_t({K, N}));
CHECK(W.ndim() == 2 && W.dim(0) == K)
<< "\nWeights dimensions should be [K, N].\n"
<< "Got X as (" << M << ", " << K << "), "
<< "and W as " << W.DimString();
}
math::Gemm(
CblasNoTrans,
(CBLAS_TRANSPOSE)transW_,
M,
N,
K,
1.f,
X.template data<T, Context>(),
W.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
if (InputSize() > 2) {
TENSOR_FILL(Input(2), vec64_t({N}));
kernel::BiasAdd(
M,
1,
N,
Y->template data<T, Context>(),
Input(2).template data<T, Context>(),
Y->template mutable_data<T, Context>(),
ctx());
}
}
template <class Context>
void FullyConnectedOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void FullyConnectedGradientOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1), &dY = Input(2);
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
CANONICALIZE_AXIS_WITH_TENSOR(X);
// Determine the number of output channels
int64_t M = X.count(0, axis), K = X.count(axis), N;
if (out_channels_ <= 0) {
// Infer the "N" from the weights shape
N = W.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer the N from "
<< "the weights shape: " << W.DimString();
} else {
// Use a fixed "N" from the argument
N = out_channels_;
}
if (dX->has_name()) {
if (transW_) {
math::Gemm(
CblasNoTrans,
CblasNoTrans,
M,
K,
N,
1.f,
dY.template data<T, Context>(),
W.template data<T, Context>(),
0.f,
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
} else {
math::Gemm(
CblasNoTrans,
CblasTrans,
M,
K,
N,
1.f,
dY.template data<T, Context>(),
W.template data<T, Context>(),
0.f,
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
}
if (dW->has_name()) {
if (transW_) {
math::Gemm(
CblasTrans,
CblasNoTrans,
N,
K,
M,
1.f,
dY.template data<T, Context>(),
X.template data<T, Context>(),
0.f,
dW->ReshapeLike(W)->template mutable_data<T, Context>(),
ctx());
} else {
math::Gemm(
CblasTrans,
CblasNoTrans,
K,
N,
M,
1.f,
X.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dW->ReshapeLike(W)->template mutable_data<T, Context>(),
ctx());
}
}
if (dB->has_name()) {
vec32_t dims = {(int)M, (int)N}, axes = {0};
math::ReduceSum(
2,
dims.data(),
1,
axes.data(),
1.f,
dY.template data<T, Context>(),
dB->Reshape({N})->template mutable_data<T, Context>(),
ctx());
}
}
template <class Context>
void FullyConnectedGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(FullyConnected);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(FullyConnected);
#endif
DEPLOY_CPU_OPERATOR(FullyConnectedGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(FullyConnectedGradient);
#endif
OPERATOR_SCHEMA(FullyConnected)
/* X, W, B */
.NumInputs(2, 3)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(FullyConnectedGradient)
/* X, W, dY */
.NumInputs(3)
/* dX, dW, dB */
.NumOutputs(3);
namespace {
class GradientMaker : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
vector<OperatorDef> MakeDef() override {
return SingleDef(
def.type() + "Gradient",
"",
vector<string>({I(0), I(1), GO(0)}),
vector<string>({GI(0), GI(1), GI(2)}));
}
};
} // namespace
REGISTER_GRADIENT(FullyConnected, GradientMaker);
} // namespace dragon
#include "dragon/operators/math/gemm_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/filler.h"
#include "dragon/utils/math_functions.h"
namespace dragon {
template <class Context>
template <typename T>
void GemmOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), *Y = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR(A);
// Check matrix A
auto M = transA_ ? A.count(axis) : A.count(0, axis);
auto K = transA_ ? A.count(0, axis) : A.count(axis);
// Check matrix B
auto N = n_; // Init "N" from the argument
if (N <= 0) {
// Infer "N" from the B shape
N = B.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer 'N' from "
<< "the B shape: " << B.DimString();
}
if (transB_ > 0) {
TENSOR_FILL(B, vec64_t({N, K}));
CHECK(B.ndim() == 2 && B.dim(1) == K)
<< "\nMatrixB's dimensions should be [N, K].\n"
<< "Got A as (" << M << ", " << K << "), "
<< "and B as " << B.DimString();
} else {
TENSOR_FILL(B, vec64_t({K, N}));
CHECK(B.ndim() == 2 && B.dim(0) == K)
<< "\nMatrixB's dimensions should be [K, N].\n"
<< "Got A as (" << M << ", " << K << "), "
<< "and B as " << B.DimString();
}
// Copy matrix C to Y if provided
vec64_t Y_dims(A.dims().begin(), A.dims().begin() + axis);
Y_dims.insert(transA_ ? Y_dims.begin() : Y_dims.end(), N);
if (InputSize() > 2) {
auto& C = Input(2);
if (C.ndim() == 0) {
TENSOR_FILL(C, vec64_t({N}));
}
if (math::utils::IsBinaryBroadcast(Y_dims, C.dims(), Y_dims)) {
math::Set(
C.ndim(),
C.dims().data(),
Y_dims.size(),
Y_dims.data(),
C.template data<T, Context>(),
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else {
LOG(FATAL) << "Could not broadcast together with shapes: "
<< Tensor::DimString(Y_dims) << " " << C.DimString();
}
}
math::Gemm(
(CBLAS_TRANSPOSE)transA_,
(CBLAS_TRANSPOSE)transB_,
M,
N,
K,
alpha_,
A.template data<T, Context>(),
B.template data<T, Context>(),
InputSize() > 2 ? beta_ : 0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void GemmOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void GemmGradientOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), &dY = Input(3);
auto *dA = Output(0), *dB = Output(1), *dC = Output(2);
CANONICALIZE_AXIS_WITH_TENSOR(A);
// Check matrix A
auto M = transA_ ? A.count(axis) : A.count(0, axis);
auto K = transA_ ? A.count(0, axis) : A.count(axis);
// Check matrix B
auto N = n_; // Init "N" from the argument
if (N <= 0) {
// Infer "N" from the B shape
N = B.count() / K;
CHECK_GT(N, 0) << "\nFailed to infer 'N' from "
<< "the B shape: " << B.DimString();
}
if (dA->has_name()) {
if (transA_ > 0) {
math::Gemm(
transB_ ? CblasTrans : CblasNoTrans,
CblasTrans,
K,
M,
N,
alpha_,
B.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
} else {
math::Gemm(
CblasNoTrans,
transB_ ? CblasNoTrans : CblasTrans,
M,
K,
N,
alpha_,
dY.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
}
if (dB->has_name()) {
if (transB_) {
math::Gemm(
CblasTrans,
transA_ ? CblasTrans : CblasNoTrans,
N,
K,
M,
alpha_,
dY.template data<T, Context>(),
A.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
} else {
math::Gemm(
transA_ ? CblasNoTrans : CblasTrans,
CblasNoTrans,
K,
N,
M,
alpha_,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
}
if (dC->has_name()) {
auto& C = Input(2);
if (C.count() == dY.count()) {
math::Scale(
dY.count(),
beta_,
dY.template data<T, Context>(),
dC->ReshapeLike(C)->template mutable_data<T, Context>(),
ctx());
} else {
vec32_t Y_axes, C_axes;
math::utils::ComputeBinaryBroadcastAxes(
dY.dims(), C.dims(), dY.dims(), Y_axes, C_axes);
math::ReduceSum(
dY.ndim(),
vec32_t{dY.dims().begin(), dY.dims().end()}.data(),
C_axes.size(),
C_axes.data(),
beta_,
dY.template data<T, Context>(),
dC->ReshapeLike(C)->template mutable_data<T, Context>(),
ctx());
}
}
}
template <class Context>
void GemmGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Gemm);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Gemm);
#endif
DEPLOY_CPU_OPERATOR(GemmGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(GemmGradient);
#endif
OPERATOR_SCHEMA(Gemm)
/* A, B, C */
.NumInputs(2, 3)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(GemmGradient)
/* A, B, C, dY */
.NumInputs(4)
/* dA, dB, dC */
.NumOutputs(3);
namespace {
class GradientMaker : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
vector<OperatorDef> MakeDef() override {
return SingleDef(
def.type() + "Gradient",
"",
vector<string>({I(0), I(1), I(2), GO(0)}),
vector<string>({GI(0), GI(1), GI(2)}));
}
};
} // namespace
REGISTER_GRADIENT(Gemm, GradientMaker);
} // namespace dragon
...@@ -10,20 +10,23 @@ ...@@ -10,20 +10,23 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_OPERATORS_MATH_FULLY_CONNECTED_OP_H_ #ifndef DRAGON_OPERATORS_MATH_GEMM_OP_H_
#define DRAGON_OPERATORS_MATH_FULLY_CONNECTED_OP_H_ #define DRAGON_OPERATORS_MATH_GEMM_OP_H_
#include "dragon/core/operator.h" #include "dragon/core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class FullyConnectedOp final : public Operator<Context> { class GemmOp final : public Operator<Context> {
public: public:
FullyConnectedOp(const OperatorDef& def, Workspace* ws) GemmOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
out_channels_(OP_SINGLE_ARG(int64_t, "out_channels", 0)), n_(OP_SINGLE_ARG(int64_t, "n", 0)),
transW_(OP_SINGLE_ARG(int64_t, "transW", 1)) {} alpha_(OP_SINGLE_ARG(float, "alpha", 1.f)),
beta_(OP_SINGLE_ARG(float, "beta", 1.f)),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -32,16 +35,20 @@ class FullyConnectedOp final : public Operator<Context> { ...@@ -32,16 +35,20 @@ class FullyConnectedOp final : public Operator<Context> {
void DoRunWithType(); void DoRunWithType();
protected: protected:
int64_t out_channels_, transW_; float alpha_, beta_;
int64_t n_, transA_, transB_;
}; };
template <class Context> template <class Context>
class FullyConnectedGradientOp final : public Operator<Context> { class GemmGradientOp final : public Operator<Context> {
public: public:
FullyConnectedGradientOp(const OperatorDef& def, Workspace* ws) GemmGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
out_channels_(OP_SINGLE_ARG(int64_t, "out_channels", 0)), n_(OP_SINGLE_ARG(int64_t, "n", 0)),
transW_(OP_SINGLE_ARG(int64_t, "transW", 1)) {} alpha_(OP_SINGLE_ARG(float, "alpha", 1.f)),
beta_(OP_SINGLE_ARG(float, "beta", 1.f)),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -50,9 +57,10 @@ class FullyConnectedGradientOp final : public Operator<Context> { ...@@ -50,9 +57,10 @@ class FullyConnectedGradientOp final : public Operator<Context> {
void DoRunWithType(); void DoRunWithType();
protected: protected:
int64_t out_channels_, transW_; float alpha_, beta_;
int64_t n_, transA_, transB_;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_MATH_FULLY_CONNECTED_OP_H_ #endif // DRAGON_OPERATORS_MATH_GEMM_OP_H_
#include "dragon/operators/math/matmul_op.h" #include "dragon/operators/math/matmul_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/math_functions.h" #include "dragon/utils/math_functions.h"
namespace dragon { namespace dragon {
...@@ -7,45 +8,205 @@ template <class Context> ...@@ -7,45 +8,205 @@ template <class Context>
template <typename T> template <typename T>
void MatMulOp<Context>::DoRunWithType() { void MatMulOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), *Y = Output(0); auto &A = Input(0), &B = Input(1), *Y = Output(0);
auto A_ndim = A.ndim(), B_ndim = B.ndim();
CHECK_GE(A.ndim(), 2) << "\nTensor(" << A.name() + ") must be a matrix" if (A_ndim == 1 && B_ndim == 1) {
<< "(or rank > 2, representing batches of matrices)."; // Vector x Vector
CHECK_GE(B.ndim(), 2) << "\nTensor(" << B.name() + ") must be a matrix" CHECK_EQ(A.count(), B.count()) << "\nExcept equal length of two vectors.";
<< "(or rank > 2, representing batches of matrices)."; math::Dot(
A.count(),
auto M1 = A.dim(-2), N1 = A.dim(-1); A.template data<T, Context>(),
auto M2 = B.dim(-2), N2 = B.dim(-1); B.template data<T, Context>(),
auto M = transA_ ? N1 : M1, N = transB_ ? M2 : N2; Y->Reshape({})->template mutable_data<T, Context>(),
auto K1 = transA_ ? M1 : N1, K2 = transB_ ? N2 : M2; ctx());
auto A_stride = M1 * N1, B_stride = M2 * N2, Y_stride = M * N; return;
auto batch_size = A.count() / A_stride; }
CHECK((K1 == K2) && (batch_size == (B.count() / B_stride))) if (A_ndim == 1) {
<< "\nTensor(" << A.name() << "): " << A.DimString() const auto N = A.count();
<< " can not mul with Tensor" CHECK_EQ(B.dim(B_ndim - 2), N) << "\nExcept the second last dim of B is "
<< "(" << B.name() << "): " << B.DimString(); << N << ", got " << B.dim(B_ndim - 2);
const auto M = B.dim(B_ndim - 1);
vec64_t Y_dims(A.dims()); const auto batch_size = B.count() / (M * N);
Y_dims[Y_dims.size() - 2] = M; vec64_t Y_dims(B.dims().begin(), B.dims().end() - 1);
Y_dims[Y_dims.size() - 1] = N; Y_dims.back() = B.dims().back();
Y->Reshape(Y_dims); if (batch_size == 1) {
// Vector x Matrix
auto* a = A.template data<T, Context>(); math::Gemv(
auto* b = B.template data<T, Context>(); CblasTrans,
auto* y = Y->template mutable_data<T, Context>(); N,
M,
for (int i = 0; i < batch_size; ++i) { 1.f,
B.template data<T, Context>(),
A.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else {
// Broadcasted Vector x Batched Matrix
math::GemmStridedBatched(
CblasTrans,
CblasNoTrans,
batch_size,
M,
1,
N,
M * N,
0,
M,
1.f,
B.template data<T, Context>(),
A.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
}
return;
}
if (B_ndim == 1) {
// Matrix x Vector
const auto N = B.count();
CHECK_EQ(A.dim(A_ndim - 1), N) << "\nExcept the last dim of A is " << N
<< ", got " << A.dim(A_ndim - 1);
const auto M = A.count() / N;
vec64_t Y_dims(A.dims());
Y_dims.erase(Y_dims.end() - 1);
math::Gemv(
CblasNoTrans,
M,
N,
1.f,
A.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
return;
}
// Check matrix A
const auto M = A.dim(A_ndim - 2);
const auto K = A.dim(A_ndim - 1);
// Check matrix B
CHECK_EQ(B.dim(B_ndim - 2), K) << "\nExcept the second last dim of B is " << K
<< ", got " << B.dim(B_ndim - 2);
const auto N = B.dim(B_ndim - 1);
// Check batching && broadcasting
vec64_t A_dims(A.dims().begin(), A.dims().end() - 2);
vec64_t B_dims(B.dims().begin(), B.dims().end() - 2);
vec64_t A_batch_dims, B_batch_dims, Y_dims;
if (math::utils::IsBinaryBroadcast(A_dims, B_dims, Y_dims)) {
math::utils::ComputeBinaryBroadcastDims(
A_dims, B_dims, A_batch_dims, B_batch_dims);
} else {
LOG(FATAL) << "Could not broadcast together with shapes " << A.DimString()
<< " " << B.DimString();
}
const int64_t batch_ndim = A_batch_dims.size();
const bool broadcasting = A_batch_dims != B_batch_dims;
Y_dims.push_back(M);
Y_dims.push_back(N);
const auto A_batch_size = std::accumulate(
A_batch_dims.begin(),
A_batch_dims.end(),
1LL,
std::multiplies<int64_t>());
const auto B_batch_size = std::accumulate(
B_batch_dims.begin(),
B_batch_dims.end(),
1LL,
std::multiplies<int64_t>());
const auto Y_batch_size = std::accumulate(
Y_dims.begin(),
Y_dims.begin() + batch_ndim,
1LL,
std::multiplies<int64_t>());
if (B_batch_size == 1) {
// Batched Matrix x Broadcasted Matrix
math::Gemm( math::Gemm(
transA_ > 0 ? CblasTrans : CblasNoTrans, CblasNoTrans,
transB_ > 0 ? CblasTrans : CblasNoTrans, CblasNoTrans,
A_batch_size * M,
N,
K,
1.f,
A.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else if (A_batch_size == 1) {
// Broadcasted Matrix x Batched Matrix
math::GemmStridedBatched(
CblasNoTrans,
CblasNoTrans,
Y_batch_size,
M,
N,
K,
0,
K * N,
M * N,
1.f,
A.template data<T, Context>(),
B.template data<T, Context>(),
0.0f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else if (!broadcasting) {
// Batched Matrix x Batched Matrix
math::GemmStridedBatched(
CblasNoTrans,
CblasNoTrans,
Y_batch_size,
M,
N,
K,
M * K,
K * N,
M * N,
1.f,
A.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
ctx());
} else {
// Broadcasted Matrix x Broadcasted Matrix
vector<const T*> A_arr(Y_batch_size);
vector<const T*> B_arr(Y_batch_size);
vector<T*> Y_arr(Y_batch_size);
vec64_t index(batch_ndim, 0);
auto* A_data = A.template data<T, Context>();
auto* B_data = B.template data<T, Context>();
auto* Y_data = Y->Reshape(Y_dims)->template mutable_data<T, Context>();
for (int64_t Y_i = 0; Y_i < Y_batch_size; ++Y_i) {
const auto A_i = math::utils::GetIndexFromDims(
batch_ndim, A_batch_dims.data(), index.data());
const auto B_i = math::utils::GetIndexFromDims(
batch_ndim, B_batch_dims.data(), index.data());
A_arr[Y_i] = A_data + A_i * M * K;
B_arr[Y_i] = B_data + B_i * K * N;
Y_arr[Y_i] = Y_data + Y_i * M * N;
math::utils::IncreaseIndexInDims(batch_ndim, Y_dims.data(), index.data());
}
math::GemmBatched(
CblasNoTrans,
CblasNoTrans,
Y_batch_size,
M, M,
N, N,
K1, K,
1.f, 1.f,
a + i * A_stride, A_arr.data(),
b + i * B_stride, B_arr.data(),
0.f, 0.f,
y + i * Y_stride, Y_arr.data(),
ctx()); ctx());
} }
} }
...@@ -60,95 +221,397 @@ template <typename T> ...@@ -60,95 +221,397 @@ template <typename T>
void MatMulGradientOp<Context>::DoRunWithType() { void MatMulGradientOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), &dY = Input(2); auto &A = Input(0), &B = Input(1), &dY = Input(2);
auto *dA = Output(0), *dB = Output(1); auto *dA = Output(0), *dB = Output(1);
auto A_ndim = A.ndim(), B_ndim = B.ndim();
CHECK_GE(A.ndim(), 2) << "\nTensor(" << A.name() + ") must be a matrix" if (A_ndim == 1 && B_ndim == 1) {
<< "(or rank > 2, representing batches of matrices)."; // Vector x Vector
CHECK_GE(B.ndim(), 2) << "\nTensor(" << B.name() + ") must be a matrix" if (dA->has_name()) {
<< "(or rank > 2, representing batches of matrices)."; math::Mul(
dY.ndim(),
auto M1 = A.dim(-2), N1 = A.dim(-1); dY.dims().data(),
auto M2 = B.dim(-2), N2 = B.dim(-1); B.ndim(),
auto M = transA_ ? N1 : M1, N = transB_ ? M2 : N2; B.dims().data(),
auto K1 = transA_ ? M1 : N1, K2 = transB_ ? N2 : M2; dY.template data<T, Context>(),
auto A_stride = M1 * N1, B_stride = M2 * N2, Y_stride = M * N; B.template data<T, Context>(),
auto batch_size = A.count() / A_stride; dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
CHECK((K1 == K2) && (batch_size == (B.count() / B_stride))) }
<< "\nTensor(" << A.name() << "): " << A.DimString() if (dB->has_name()) {
<< " can not mul with Tensor" math::Mul(
<< "(" << B.name() << "): " << B.DimString(); dY.ndim(),
dY.dims().data(),
if (dA->has_name()) { A.ndim(),
auto* b = B.template data<T, Context>(); A.dims().data(),
auto* dy = dY.template data<T, Context>(); dY.template data<T, Context>(),
auto* da = dA->ReshapeLike(A)->template mutable_data<T, Context>(); A.template data<T, Context>(),
if (transA_ > 0) { dB->ReshapeLike(B)->template mutable_data<T, Context>(),
for (int i = 0; i < batch_size; ++i) { ctx());
math::Gemm( }
transB_ ? CblasTrans : CblasNoTrans, return;
CblasTrans, }
K1,
if (A_ndim == 1) {
const auto N = A.count();
if (dA->has_name()) {
const auto M = B.dim(B_ndim - 1);
const auto batch_size = B.count() / (M * N);
if (batch_size == 1) {
// Vector x Matrix
math::Gemv(
CblasNoTrans,
N,
M,
1.f,
B.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
} else {
// Broadcasted Vector x Batched Matrix
auto* scratch =
ctx()->workspace()->template data<T, Context>({batch_size * N})[0];
math::GemmStridedBatched(
CblasNoTrans,
CblasNoTrans,
batch_size,
N,
1,
M,
M * N,
M, M,
N, N,
1.f, 1.f,
b + i * B_stride, B.template data<T, Context>(),
dy + i * Y_stride, dY.template data<T, Context>(),
0.f, 0.f,
da + i * A_stride, scratch,
ctx());
math::ReduceSum(
2,
vec32_t{int(batch_size), int(N)}.data(),
1,
vec32_t{0}.data(),
1.f,
scratch,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
} else { }
for (int i = 0; i < batch_size; ++i) { if (dB->has_name()) {
const auto M = B.dim(B_ndim - 1);
const auto batch_size = B.count() / (M * N);
if (batch_size == 1) {
// Vector x Matrix
math::Gemm( math::Gemm(
CblasNoTrans, CblasNoTrans,
transB_ ? CblasNoTrans : CblasTrans, CblasNoTrans,
N,
M, M,
K1, 1,
1.f,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
} else {
// Broadcasted Vector x Batched Matrix
math::GemmStridedBatched(
CblasNoTrans,
CblasNoTrans,
batch_size,
N, N,
M,
1,
0,
M,
M * N,
1.f, 1.f,
dy + i * Y_stride, A.template data<T, Context>(),
b + i * B_stride, dY.template data<T, Context>(),
0.f, 0.f,
da + i * A_stride, dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx()); ctx());
} }
} }
return;
}
if (B_ndim == 1) {
const auto N = B.count();
const auto M = A.count() / N;
// Matrix x Vector
if (dA->has_name()) {
math::Gemm(
CblasNoTrans,
CblasNoTrans,
M,
N,
1,
1.f,
dY.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) { if (dB->has_name()) {
auto* a = A.template data<T, Context>(); math::Gemv(
auto* dy = dY.template data<T, Context>(); CblasTrans,
auto* db = dB->ReshapeLike(B)->template mutable_data<T, Context>(); M,
if (transB_) { N,
for (int i = 0; i < batch_size; ++i) { 1.f,
math::Gemm( A.template data<T, Context>(),
CblasTrans, dY.template data<T, Context>(),
transA_ ? CblasTrans : CblasNoTrans, 0.f,
N, dB->ReshapeLike(B)->template mutable_data<T, Context>(),
K1, ctx());
M, }
1.f, return;
dy + i * Y_stride, }
a + i * A_stride,
0.f, // Check matrix A && B
db + i * B_stride, const auto M = A.dim(A_ndim - 2);
ctx()); const auto K = A.dim(A_ndim - 1);
} const auto N = B.dim(B_ndim - 1);
} else {
for (int i = 0; i < batch_size; ++i) { // Check batching && broadcasting
math::Gemm( vec64_t A_dims(A.dims().begin(), A.dims().end() - 2);
transA_ ? CblasNoTrans : CblasTrans, vec64_t B_dims(B.dims().begin(), B.dims().end() - 2);
CblasNoTrans, vec64_t A_batch_dims, B_batch_dims, Y_batch_dims;
K1, vec32_t A_batch_axes, B_batch_axes;
N, if (math::utils::IsBinaryBroadcast(A_dims, B_dims, Y_batch_dims)) {
M, math::utils::ComputeBinaryBroadcastDims(
1.f, A_dims, B_dims, A_batch_dims, B_batch_dims);
a + i * A_stride, math::utils::ComputeBinaryBroadcastAxes(
dy + i * Y_stride, A_batch_dims, B_batch_dims, Y_batch_dims, A_batch_axes, B_batch_axes);
0.f, } else {
db + i * B_stride, LOG(FATAL) << "Could not broadcast together with shapes " << A.DimString()
ctx()); << " " << B.DimString();
} }
const int64_t batch_ndim = A_batch_dims.size();
const bool broadcasting = A_batch_dims != B_batch_dims;
const auto A_batch_size = std::accumulate(
A_batch_dims.begin(),
A_batch_dims.end(),
1LL,
std::multiplies<int64_t>());
const auto B_batch_size = std::accumulate(
B_batch_dims.begin(),
B_batch_dims.end(),
1LL,
std::multiplies<int64_t>());
const auto Y_batch_size = std::accumulate(
Y_batch_dims.begin(),
Y_batch_dims.end(),
1LL,
std::multiplies<int64_t>());
if (B_batch_size == 1) {
// Batched Matrix x Broadcasted Matrix
if (dA->has_name()) {
math::Gemm(
CblasNoTrans,
CblasTrans,
A_batch_size * M,
K,
N,
1.f,
dY.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) {
math::Gemm(
CblasTrans,
CblasNoTrans,
K,
N,
A_batch_size * M,
1.f,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
} else if (A_batch_size == 1) {
// Broadcasted Matrix x Batched Matrix
if (dA->has_name()) {
auto* scratch = ctx()->workspace()->template data<T, Context>(
{Y_batch_size * M * K})[0];
math::GemmStridedBatched(
CblasNoTrans,
CblasTrans,
Y_batch_size,
M,
K,
N,
M * N,
K * N,
M * K,
1.f,
dY.template data<T, Context>(),
B.template data<T, Context>(),
0.0f,
scratch,
ctx());
math::ReduceSum(
2,
vec32_t{int(Y_batch_size), int(M * K)}.data(),
1,
vec32_t{0}.data(),
1.f,
scratch,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) {
math::GemmStridedBatched(
CblasTrans,
CblasNoTrans,
Y_batch_size,
K,
N,
M,
0,
M * N,
K * N,
1.f,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
} else if (!broadcasting) {
// Batched Matrix x Batched Matrix
if (dA->has_name()) {
math::GemmStridedBatched(
CblasNoTrans,
CblasTrans,
Y_batch_size,
M,
K,
N,
M * N,
K * N,
M * K,
1.f,
dY.template data<T, Context>(),
B.template data<T, Context>(),
0.f,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) {
math::GemmStridedBatched(
CblasTrans,
CblasNoTrans,
Y_batch_size,
K,
N,
M,
M * K,
M * N,
K * N,
1.f,
A.template data<T, Context>(),
dY.template data<T, Context>(),
0.f,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
}
} else {
// Broadcasted Matrix x Broadcasted Matrix
vector<const T*> A_arr(Y_batch_size);
vector<const T*> B_arr(Y_batch_size);
vector<const T*> dY_arr(Y_batch_size);
vector<T*> dA_arr(Y_batch_size);
vector<T*> dB_arr(Y_batch_size);
if (dA->has_name()) {
vec64_t index(batch_ndim, 0);
vec32_t scratch_dims({Y_batch_dims.begin(), Y_batch_dims.end()});
scratch_dims.push_back(int(M * K));
auto* dY_data = dY.template data<T, Context>();
auto* B_data = B.template data<T, Context>();
auto* scratch = ctx()->workspace()->template data<T, Context>(
{Y_batch_size * std::max(M * K, K * N)})[0];
for (int64_t Y_i = 0; Y_i < Y_batch_size; ++Y_i) {
const auto B_i = math::utils::GetIndexFromDims(
batch_ndim, B_batch_dims.data(), index.data());
dY_arr[Y_i] = dY_data + Y_i * M * N;
B_arr[Y_i] = B_data + B_i * K * N;
dA_arr[Y_i] = scratch + Y_i * M * K;
math::utils::IncreaseIndexInDims(
batch_ndim, Y_batch_dims.data(), index.data());
}
math::GemmBatched(
CblasNoTrans,
CblasTrans,
Y_batch_size,
M,
K,
N,
1.f,
dY_arr.data(),
B_arr.data(),
0.f,
dA_arr.data(),
ctx());
math::ReduceSum(
scratch_dims.size(),
scratch_dims.data(),
A_batch_axes.size(),
A_batch_axes.data(),
1.f,
scratch,
dA->ReshapeLike(A)->template mutable_data<T, Context>(),
ctx());
}
if (dB->has_name()) {
vec64_t index(batch_ndim, 0);
vec32_t scratch_dims({Y_batch_dims.begin(), Y_batch_dims.end()});
scratch_dims.push_back(int(K * N));
auto* dY_data = dY.template data<T, Context>();
auto* A_data = A.template data<T, Context>();
auto* scratch = ctx()->workspace()->template data<T, Context>(
{Y_batch_size * std::max(M * K, K * N)})[0];
for (int64_t Y_i = 0; Y_i < Y_batch_size; ++Y_i) {
const auto A_i = math::utils::GetIndexFromDims(
batch_ndim, A_batch_dims.data(), index.data());
dY_arr[Y_i] = dY_data + Y_i * M * N;
A_arr[Y_i] = A_data + A_i * M * K;
dB_arr[Y_i] = scratch + Y_i * K * N;
math::utils::IncreaseIndexInDims(
batch_ndim, Y_batch_dims.data(), index.data());
} }
math::GemmBatched(
CblasTrans,
CblasNoTrans,
Y_batch_size,
K,
N,
M,
1.f,
A_arr.data(),
dY_arr.data(),
0.f,
dB_arr.data(),
ctx());
math::ReduceSum(
scratch_dims.size(),
scratch_dims.data(),
B_batch_axes.size(),
B_batch_axes.data(),
1.f,
scratch,
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
} }
} }
} }
......
...@@ -20,37 +20,25 @@ namespace dragon { ...@@ -20,37 +20,25 @@ namespace dragon {
template <class Context> template <class Context>
class MatMulOp final : public Operator<Context> { class MatMulOp final : public Operator<Context> {
public: public:
MatMulOp(const OperatorDef& def, Workspace* ws) SIMPLE_CTOR_DTOR(MatMulOp);
: Operator<Context>(def, ws),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected:
int64_t transA_, transB_;
}; };
template <class Context> template <class Context>
class MatMulGradientOp final : public Operator<Context> { class MatMulGradientOp final : public Operator<Context> {
public: public:
MatMulGradientOp(const OperatorDef& def, Workspace* ws) SIMPLE_CTOR_DTOR(MatMulGradientOp);
: Operator<Context>(def, ws),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected:
int64_t transA_, transB_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -35,6 +35,7 @@ from dragon.core.ops.math_ops import dot ...@@ -35,6 +35,7 @@ from dragon.core.ops.math_ops import dot
from dragon.core.ops.math_ops import equal from dragon.core.ops.math_ops import equal
from dragon.core.ops.math_ops import exp from dragon.core.ops.math_ops import exp
from dragon.core.ops.math_ops import floor from dragon.core.ops.math_ops import floor
from dragon.core.ops.math_ops import gemm
from dragon.core.ops.math_ops import greater from dragon.core.ops.math_ops import greater
from dragon.core.ops.math_ops import greater_equal from dragon.core.ops.math_ops import greater_equal
from dragon.core.ops.math_ops import is_inf from dragon.core.ops.math_ops import is_inf
......
...@@ -33,7 +33,6 @@ from dragon.core.ops.activation_ops import relu6 ...@@ -33,7 +33,6 @@ from dragon.core.ops.activation_ops import relu6
from dragon.core.ops.activation_ops import selu from dragon.core.ops.activation_ops import selu
from dragon.core.ops.activation_ops import softmax from dragon.core.ops.activation_ops import softmax
from dragon.core.ops.activation_ops import swish from dragon.core.ops.activation_ops import swish
from dragon.core.ops.math_ops import fully_connected
from dragon.core.ops.normalization_ops import batch_norm from dragon.core.ops.normalization_ops import batch_norm
from dragon.core.ops.normalization_ops import group_norm from dragon.core.ops.normalization_ops import group_norm
from dragon.core.ops.normalization_ops import instance_norm from dragon.core.ops.normalization_ops import instance_norm
......
...@@ -414,30 +414,34 @@ def flatten_spec(args, inputs, outputs): ...@@ -414,30 +414,34 @@ def flatten_spec(args, inputs, outputs):
return outputs return outputs
@register('FullyConnected') @register('Gemm')
def fully_connected_spec(args, inputs, outputs): def gemm_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
axis, out_channels = args['axis'], args.get('out_channels', None) axis, n = args['axis'], args.get('n', None)
while axis < 0: while axis < 0:
try: try:
axis += len(inputs[0].shape) axis += len(inputs[0].shape)
except TypeError: except TypeError:
return outputs break
out_shape = [None] * (axis + 1) out_shape = [None] * axis if axis >= 0 else None
if out_channels is None: if n is None:
try: try:
if args['transW']: if args['transB']:
out_channels = inputs[1].shape[0] n = inputs[1].shape[0]
else: else:
out_channels = inputs[1].shape[1] n = inputs[1].shape[1]
except (TypeError, IndexError): except (TypeError, IndexError):
out_channels = None n = None
try: try:
out_shape[axis] = out_channels if out_shape is None or inputs[0].shape is not None:
out_shape[:axis] = inputs[0].shape[:axis] out_shape = list(inputs[0].shape[:axis])
if args['transA']:
out_shape.insert(0, n)
else:
out_shape.append(n)
outputs[0].shape = out_shape
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
outputs[0].shape = out_shape
return outputs return outputs
...@@ -510,12 +514,25 @@ def masked_select_spec(args, inputs, outputs): ...@@ -510,12 +514,25 @@ def masked_select_spec(args, inputs, outputs):
@register('MatMul') @register('MatMul')
def matmul_spec(args, inputs, outputs): def matmul_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
ta, tb = args['transA'], args['transB']
try: try:
a_shape = list(inputs[0].shape[:])
b_shape = list(inputs[1].shape[:]) b_shape = list(inputs[1].shape[:])
a_shape = out_shape = list(inputs[0].shape[:]) if len(a_shape) >= 2 and len(b_shape) >= 2:
out_shape[-2] = a_shape[-1] if ta else a_shape[-2] out_shape = [1] * max(len(a_shape), len(b_shape))
out_shape[-1] = b_shape[-2] if tb else b_shape[-1] a_shape = [1] * (len(out_shape) - len(a_shape)) + a_shape
b_shape = [1] * (len(out_shape) - len(b_shape)) + b_shape
for i in range(len(out_shape)):
try:
out_shape[i] = max(a_shape[i], b_shape[i])
except TypeError:
out_shape[i] = None
out_shape[-2] = a_shape[-2]
out_shape[-1] = b_shape[-1]
elif len(a_shape) == 1 and len(b_shape) == 1:
out_shape = []
else:
out_shape = a_shape if len(b_shape) == 1 else b_shape
out_shape.pop(-1 if len(b_shape) == 1 else -2)
except (TypeError, IndexError): except (TypeError, IndexError):
out_shape = None out_shape = None
outputs[0].shape = out_shape outputs[0].shape = out_shape
......
...@@ -498,23 +498,30 @@ def floor(inputs, **kwargs): ...@@ -498,23 +498,30 @@ def floor(inputs, **kwargs):
@OpSchema.num_inputs(2, 3) @OpSchema.num_inputs(2, 3)
def fully_connected(inputs, axis=1, transpose_w=True, **kwargs): def gemm(
r"""Compute the dense matrix multiplication along the given axes. inputs,
alpha=1.0,
.. math:: y = Wx + b beta=1.0,
transpose_a=False,
The column of input matrix is determined by: transpose_b=False,
**kwargs
.. math:: \text{Col} = \text{DimSince}(\text{Input}, \text{Axis}) ):
r"""Compute the general matrix multiplication.
.. math:: \text{out} = \alpha AB + \beta C
Parameters Parameters
---------- ----------
inputs : Sequence[dragon.Tensor] inputs : Sequence[dragon.Tensor]
The tensor :math:`x`, :math:`W` and :math:`b`. The matrix :math:`A`, :math:`B` and optional :math:`C`.
axis : int, optional, default=1 alpha : float, optional, default=1.0
The start axis to compute, can be negative. The value to :math:`\alpha`.
transpose_w : bool, optional, default=True beta : float, optional, default=1.0
**True** to transpose :math:`W` before computation. The value to :math:`\beta`.
transpose_a : bool, optional, default=False
**True** to transpose :math:`A` before computation.
transpose_b : bool, optional, default=False
**True** to transpose :math:`B` before computation.
Returns Returns
------- -------
...@@ -523,15 +530,22 @@ def fully_connected(inputs, axis=1, transpose_w=True, **kwargs): ...@@ -523,15 +530,22 @@ def fully_connected(inputs, axis=1, transpose_w=True, **kwargs):
""" """
args = ArgHelper.parse(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.FullyConnected args['axis'] = kwargs.get('axis', -1)
args['alpha'], args['beta'] = float(alpha), float(beta)
op_lib = math_ops_lib.Gemm
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate(axis=axis, transpose_w=transpose_w) \ .instantiate(
.apply(inputs) axis=args['axis'],
alpha=args['alpha'],
beta=args['beta'],
transpose_a=transpose_a,
transpose_b=transpose_b,
).apply(inputs)
else: else:
args.pop('transpose_w') args['transA'] = args.pop('transpose_a')
args['transW'] = transpose_w args['transB'] = args.pop('transpose_b')
return op_lib.blend('FullyConnected', **args) return op_lib.blend(**args)
@OpSchema.num_inputs(2) @OpSchema.num_inputs(2)
...@@ -812,42 +826,44 @@ def less_equal(inputs, **kwargs): ...@@ -812,42 +826,44 @@ def less_equal(inputs, **kwargs):
@OpSchema.num_inputs(2) @OpSchema.num_inputs(2)
def matmul(inputs, transpose_a=False, transpose_b=False, **kwargs): def matmul(inputs, **kwargs):
r"""Compute the matrix multiplication. r"""Compute the matrix multiplication.
.. math:: y = a \times b .. math:: \text{out} = \text{input1} \times \text{input2}
The rank of ``a`` and ``b`` should be equal and >= 2: The behavior depends on the shape of input tensors:
```python * If both tensors are 1d, computes the vector product.
# Ok, a typical matrix multiplication * If tensors are 1d and >=2d, computes the vector-matrix multiplication.
a = dragon.ones((2, 3), 'float32') * If tensors are >=2d and 1d, computes the matrix-vector multiplication.
b = dragon.ones((3, 3), 'float32') * If both tensors are >= 2d, computes the matrix-matrix multiplication.
print(dragon.math.matmul([a, b])) * If one tensor is >= 3d, applies batching and broadcasting to the computation.
# Compute a batch matrix multiplication if rank > 2 Examples:
aa = dragon.ones((4, 2, 3), 'float32')
bb = dragon.ones((4, 3, 3), 'float32')
print(dragon.math.matmul([aa, bb]))
```
If inputs are transposed, remember to transpose them back:
```python ```python
# Vector x Vector
a = dragon.ones((2,), 'float32')
b = dragon.ones((2,), 'float32')
print(dragon.math.matmul([a, b]))
# Vector x Matrix
a = dragon.ones((2,), 'float32')
b = dragon.ones((2, 3), 'float32')
print(dragon.math.matmul([a, b]))
# Matrix x Vector
a = dragon.ones((3, 2), 'float32') a = dragon.ones((3, 2), 'float32')
b = dragon.ones((3, 3), 'float32') b = dragon.ones((2,), 'float32')
print(dragon.math.matmul([a, b])) # ``a`` takes the wrong dimensions print(dragon.math.matmul([a, b]))
print(dragon.math.matmul([a, b], transpose_a=True)) # Ok # Matrix x Matrix
a = dragon.ones((2, 3), 'float32')
b = dragon.ones((3, 2), 'float32')
print(dragon.math.matmul([a, b]))
``` ```
Parameters Parameters
---------- ----------
inputs : Sequence[dragon.Tensor] inputs : Sequence[dragon.Tensor]
The matrix :math:`a` and :math:`b`. The input tensors.
transpose_a : bool, optional, default=False
**True** to transpose :math:`a` before computation.
transpose_b : bool, optional, default=False
**True** to transpose :math:`b` before computation.
Returns Returns
------- -------
...@@ -858,15 +874,9 @@ def matmul(inputs, transpose_a=False, transpose_b=False, **kwargs): ...@@ -858,15 +874,9 @@ def matmul(inputs, transpose_a=False, transpose_b=False, **kwargs):
args = ArgHelper.parse(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.MatMul op_lib = math_ops_lib.MatMul
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib.instantiate().apply(inputs)
.instantiate(
transpose_a=transpose_a,
transpose_b=transpose_b,
).apply(inputs)
else: else:
args.pop('transpose_a') return op_lib.blend(**args)
args.pop('transpose_b')
return op_lib.blend(transA=transpose_a, transB=transpose_b, **args)
@OpSchema.num_inputs(2) @OpSchema.num_inputs(2)
......
...@@ -80,20 +80,26 @@ class Clip(Operator): ...@@ -80,20 +80,26 @@ class Clip(Operator):
return self.dispatch(inputs, [self.alloc()]) return self.dispatch(inputs, [self.alloc()])
class FullyConnected(Operator): class Gemm(Operator):
"""FullyConnected operator.""" """FullyConnected operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(FullyConnected, self).__init__(key, dev, **kwargs) super(Gemm, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1) self.axis = kwargs.get('axis', -1)
self.transpose_w = kwargs.get('transpose_w', True) self.alpha = kwargs.get('alpha', 1.0)
self.beta = kwargs.get('beta', 1.0)
self.transpose_a = kwargs.get('transpose_a', False)
self.transpose_b = kwargs.get('transpose_b', False)
def attributes(self): def attributes(self):
return { return {
'op_type': 'FullyConnected', 'op_type': 'Gemm',
'arguments': { 'arguments': {
'axis': self.axis, 'axis': self.axis,
'transW': self.transpose_w, 'alpha': self.alpha,
'beta': self.beta,
'transA': self.transpose_a,
'transB': self.transpose_b,
} }
} }
...@@ -104,18 +110,10 @@ class FullyConnected(Operator): ...@@ -104,18 +110,10 @@ class FullyConnected(Operator):
class MatMul(Operator): class MatMul(Operator):
"""MatMul operator.""" """MatMul operator."""
def __init__(self, key, dev, **kwargs):
super(MatMul, self).__init__(key, dev, **kwargs)
self.transpose_a = kwargs.get('transpose_a', False)
self.transpose_b = kwargs.get('transpose_b', False)
def attributes(self): def attributes(self):
return { return {
'op_type': 'MatMul', 'op_type': 'MatMul',
'arguments': { 'arguments': {},
'transA': self.transpose_a,
'transB': self.transpose_b,
}
} }
def forward(self, inputs): def forward(self, inputs):
......
...@@ -136,6 +136,7 @@ def conv( ...@@ -136,6 +136,7 @@ def conv(
data_format=data_format, data_format=data_format,
bias=len(inputs) > 2, bias=len(inputs) > 2,
dtype=inputs[1].dtype, dtype=inputs[1].dtype,
input_shape=inputs[0].shape,
).apply(inputs) ).apply(inputs)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
...@@ -465,6 +466,7 @@ def conv_transpose( ...@@ -465,6 +466,7 @@ def conv_transpose(
data_format=data_format, data_format=data_format,
bias=len(inputs) > 2, bias=len(inputs) > 2,
dtype=inputs[1].dtype, dtype=inputs[1].dtype,
input_shape=inputs[0].shape,
).apply(inputs) ).apply(inputs)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
......
...@@ -44,13 +44,6 @@ def hardsigmoid_exporter(op_def, context): ...@@ -44,13 +44,6 @@ def hardsigmoid_exporter(op_def, context):
return node, const_tensors return node, const_tensors
@export_util.register('PRelu')
def prelu_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
const_tensors = [helper.from_tensor(op_def.input[1], context.ws)]
return node, const_tensors
@export_util.register('Relu') @export_util.register('Relu')
def relu_exporter(op_def, context): def relu_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals()) node, const_tensors = export_util.translate(**locals())
......
...@@ -81,24 +81,14 @@ def clip_exporter_v11(op_def, context): ...@@ -81,24 +81,14 @@ def clip_exporter_v11(op_def, context):
return node, const_tensors return node, const_tensors
@export_util.register('FullyConnected-7') @export_util.register('Gemm-7')
def fully_connected_exporter_v7(op_def, context): def gemm_exporter_v7(op_def, context):
node, const_tensors = export_util.translate(**locals()) return export_util.translate(**locals())
node.op_type = 'Gemm'
helper.add_attribute(node, 'alpha', 1.)
helper.add_attribute(node, 'beta', 1.)
for arg in op_def.arg:
if arg.name == 'transW':
helper.add_attribute(node, 'transB', arg.i)
# Weights and biases
const_tensors = [helper.from_tensor(name, context.ws)
for name in op_def.input[1:]]
return node, const_tensors
@export_util.register('FullyConnected') @export_util.register('Gemm')
def fully_connected_exporter(op_def, context): def gemm_exporter(op_def, context):
node, const_tensors = fully_connected_exporter_v7(op_def, context) node, const_tensors = gemm_exporter_v7(op_def, context)
helper.add_attribute(node, 'broadcast', 1) # Removed since opset 7 helper.add_attribute(node, 'broadcast', 1) # Removed since opset 7
return node, const_tensors return node, const_tensors
......
...@@ -29,8 +29,6 @@ def batch_norm_exporter(op_def, context): ...@@ -29,8 +29,6 @@ def batch_norm_exporter(op_def, context):
elif arg.name == 'momentum_desc': elif arg.name == 'momentum_desc':
momentum = helper.fetch_argument(op_def, arg, context.ws) momentum = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'momentum', float(momentum)) helper.add_attribute(node, 'momentum', float(momentum))
# Weight, bias, running mean and running variance
const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
...@@ -48,8 +46,6 @@ def group_norm_exporter(op_def, context): ...@@ -48,8 +46,6 @@ def group_norm_exporter(op_def, context):
else: else:
helper.add_attribute(node, 'op_type', 'GroupNorm') helper.add_attribute(node, 'op_type', 'GroupNorm')
helper.add_attribute(node, 'group', arg.i) helper.add_attribute(node, 'group', arg.i)
# Weight and bias
const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
......
...@@ -25,7 +25,7 @@ from dragon.vm.onnx.core.exporters import utils as export_util ...@@ -25,7 +25,7 @@ from dragon.vm.onnx.core.exporters import utils as export_util
'ConvTranspose', 'ConvTranspose',
'DepthwiseConv', 'DepthwiseConv',
]) ])
def convolution(op_def, context): def conv_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'ConvTranspose' if 'Transpose' in op_def.type else 'Conv' node.op_type = 'ConvTranspose' if 'Transpose' in op_def.type else 'Conv'
if 'Depthwise' in op_def.type: if 'Depthwise' in op_def.type:
...@@ -58,8 +58,6 @@ def convolution(op_def, context): ...@@ -58,8 +58,6 @@ def convolution(op_def, context):
helper.add_attribute(node, 'output_shape', arg.ints) helper.add_attribute(node, 'output_shape', arg.ints)
elif arg.name == 'output_padding': elif arg.name == 'output_padding':
helper.add_attribute(node, 'output_padding', arg.ints) helper.add_attribute(node, 'output_padding', arg.ints)
# Weights and biases
const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
......
...@@ -203,8 +203,7 @@ DRAGON_API void Gemv<float16, CPUContext>( ...@@ -203,8 +203,7 @@ DRAGON_API void Gemv<float16, CPUContext>(
const float16* x, const float16* x,
const float beta, const float beta,
float16* y, float16* y,
CPUContext* ctx, CPUContext* ctx) {
const std::string math_type) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -219,8 +218,7 @@ DRAGON_API void Gemv<float16, CPUContext>( ...@@ -219,8 +218,7 @@ DRAGON_API void Gemv<float16, CPUContext>(
const T* x, \ const T* x, \
const float beta, \ const float beta, \
T* y, \ T* y, \
CPUContext* ctx, \ CPUContext* ctx) { \
const string math_type) { \
T _alpha_ = alpha, _beta_ = beta; \ T _alpha_ = alpha, _beta_ = beta; \
EigenVectorMap<T> y_vec(y, TransA == CblasNoTrans ? M : N); \ EigenVectorMap<T> y_vec(y, TransA == CblasNoTrans ? M : N); \
if (beta == 0.f) \ if (beta == 0.f) \
...@@ -260,8 +258,7 @@ DRAGON_API void Gemm<float16, CPUContext>( ...@@ -260,8 +258,7 @@ DRAGON_API void Gemm<float16, CPUContext>(
const float16* B, const float16* B,
const float beta, const float beta,
float16* C, float16* C,
CPUContext* ctx, CPUContext* ctx) {
const string math_type) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -278,8 +275,7 @@ DRAGON_API void Gemm<float16, CPUContext>( ...@@ -278,8 +275,7 @@ DRAGON_API void Gemm<float16, CPUContext>(
const T* B, \ const T* B, \
const float beta, \ const float beta, \
T* C, \ T* C, \
CPUContext* ctx, \ CPUContext* ctx) { \
const string math_type) { \
T _alpha_ = alpha, _beta_ = beta; \ T _alpha_ = alpha, _beta_ = beta; \
auto C_mat = EigenMatrixMap<T>(C, N, M); \ auto C_mat = EigenMatrixMap<T>(C, N, M); \
if (beta == 0.f) \ if (beta == 0.f) \
...@@ -328,6 +324,105 @@ DEFINE_GEMM_FUNC(float); ...@@ -328,6 +324,105 @@ DEFINE_GEMM_FUNC(float);
DEFINE_GEMM_FUNC(double); DEFINE_GEMM_FUNC(double);
#undef DEFINE_GEMM_FUNC #undef DEFINE_GEMM_FUNC
template <>
DRAGON_API void GemmBatched<float16, CPUContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const float alpha,
const float16** A,
const float16** B,
const float beta,
float16** C,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_BATCHED_GEMM_FUNC(T) \
template <> \
DRAGON_API void GemmBatched<T, CPUContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int batch_size, \
const int M, \
const int N, \
const int K, \
const float alpha, \
const T** A, \
const T** B, \
const float beta, \
T** C, \
CPUContext* ctx) { \
for (int i = 0; i < batch_size; ++i) { \
Gemm(TransA, TransB, M, N, K, alpha, A[i], B[i], beta, C[i], ctx); \
} \
}
DEFINE_BATCHED_GEMM_FUNC(float);
DEFINE_BATCHED_GEMM_FUNC(double);
#undef DEFINE_BATCHED_GEMM_FUNC
template <>
DRAGON_API void GemmStridedBatched<float16, CPUContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const int A_stride,
const int B_stride,
const int C_stride,
const float alpha,
const float16* A,
const float16* B,
const float beta,
float16* C,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_STRIDED_BATCHED_GEMM_FUNC(T) \
template <> \
DRAGON_API void GemmStridedBatched<T, CPUContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int batch_size, \
const int M, \
const int N, \
const int K, \
const int A_stride, \
const int B_stride, \
const int C_stride, \
const float alpha, \
const T* A, \
const T* B, \
const float beta, \
T* C, \
CPUContext* ctx) { \
for (int i = 0; i < batch_size; ++i) { \
Gemm( \
TransA, \
TransB, \
M, \
N, \
K, \
alpha, \
A + i * A_stride, \
B + i * B_stride, \
beta, \
C + i * C_stride, \
ctx); \
} \
}
DEFINE_STRIDED_BATCHED_GEMM_FUNC(float);
DEFINE_STRIDED_BATCHED_GEMM_FUNC(double);
#undef DEFINE_STRIDED_BATCHED_GEMM_FUNC
} // namespace math } // namespace math
} // namespace dragon } // namespace dragon
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "dragon/core/context_cuda.h" #include "dragon/core/context_cuda.h"
#include "dragon/utils/conversions.h" #include "dragon/utils/conversions.h"
#include "dragon/utils/device/common_thrust.h"
#include "dragon/utils/math/blas.h" #include "dragon/utils/math/blas.h"
namespace dragon { namespace dragon {
...@@ -456,8 +457,7 @@ DRAGON_API void Gemv<float16, CUDAContext>( ...@@ -456,8 +457,7 @@ DRAGON_API void Gemv<float16, CUDAContext>(
const float16* x, const float16* x,
const float beta, const float beta,
float16* y, float16* y,
CUDAContext* ctx, CUDAContext* ctx) {
const string math_type) {
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N; auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
int m = cuTransA == CUBLAS_OP_N ? N : M; int m = cuTransA == CUBLAS_OP_N ? N : M;
int k = cuTransA == CUBLAS_OP_N ? M : N; int k = cuTransA == CUBLAS_OP_N ? M : N;
...@@ -465,53 +465,8 @@ DRAGON_API void Gemv<float16, CUDAContext>( ...@@ -465,53 +465,8 @@ DRAGON_API void Gemv<float16, CUDAContext>(
int LDC = m; int LDC = m;
CUBLAS_CHECK( CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
if (math_type == "float32") { if (TENSOR_CORE_AVAILABLE()) {
#if CUDA_VERSION >= 9000 CUBLAS_CHECK(cublasGemmEx(
if (TENSOR_CORE_AVAILABLE()) {
// GEMV + MATH32 + TENSOR-CORE
CUBLAS_CHECK(cublasGemmEx(
ctx->cublas_handle(),
cuTransA,
CUBLAS_OP_N,
m,
1,
k,
&alpha,
A,
CUDA_R_16F,
LDA,
x,
CUDA_R_16F,
k,
&beta,
y,
CUDA_R_16F,
LDC,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// GEMV + MATH32 + DEFAULT
CUBLAS_CHECK(cublasSgemmEx(
ctx->cublas_handle(),
cuTransA,
CUBLAS_OP_N,
m,
1,
k,
&alpha,
A,
CUDA_R_16F,
LDA,
x,
CUDA_R_16F,
k,
&beta,
y,
CUDA_R_16F,
LDC));
}
#else
CUBLAS_CHECK(cublasSgemmEx(
ctx->cublas_handle(), ctx->cublas_handle(),
cuTransA, cuTransA,
CUBLAS_OP_N, CUBLAS_OP_N,
...@@ -528,124 +483,66 @@ DRAGON_API void Gemv<float16, CUDAContext>( ...@@ -528,124 +483,66 @@ DRAGON_API void Gemv<float16, CUDAContext>(
&beta, &beta,
y, y,
CUDA_R_16F, CUDA_R_16F,
LDC)); LDC,
#endif CUDA_R_32F,
} else if (math_type == "float16") { CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const half alpha_val = convert::To<half>(alpha); } else {
const half beta_val = convert::To<half>(beta); CUBLAS_CHECK(cublasSgemmEx(
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
// GEMV + MATH16 + TENSOR-CORE
CUBLAS_CHECK(cublasGemmEx(
ctx->cublas_handle(),
cuTransA,
CUBLAS_OP_N,
m,
1,
k,
&alpha_val,
A,
CUDA_R_16F,
LDA,
x,
CUDA_R_16F,
k,
&beta_val,
y,
CUDA_R_16F,
LDC,
CUDA_R_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// GEMV + MATH16 + DEFAULT
CUBLAS_CHECK(cublasHgemm(
ctx->cublas_handle(),
cuTransA,
CUBLAS_OP_N,
m,
1,
k,
&alpha_val,
reinterpret_cast<const half*>(A),
LDA,
reinterpret_cast<const half*>(x),
k,
&beta_val,
reinterpret_cast<half*>(y),
LDC));
}
#else
CUBLAS_CHECK(cublasHgemm(
ctx->cublas_handle(), ctx->cublas_handle(),
cuTransA, cuTransA,
CUBLAS_OP_N, CUBLAS_OP_N,
m, m,
1, 1,
k, k,
&alpha_val, &alpha,
reinterpret_cast<const half*>(A), A,
CUDA_R_16F,
LDA, LDA,
reinterpret_cast<const half*>(x), x,
CUDA_R_16F,
k, k,
&beta_val, &beta,
reinterpret_cast<half*>(y), y,
CUDA_R_16F,
LDC)); LDC));
#endif
} else {
LOG(FATAL) << "Unknown math type: " << math_type;
} }
} }
template <> #define DEFINE_GEMV_FUNC(T, cublas_func) \
DRAGON_API void Gemv<float, CUDAContext>( template <> \
const CBLAS_TRANSPOSE TransA, DRAGON_API void Gemv<T, CUDAContext>( \
const int M, const CBLAS_TRANSPOSE TransA, \
const int N, const int M, \
const float alpha, const int N, \
const float* A, const float alpha, \
const float* x, const T* A, \
const float beta, const T* x, \
float* y, const float beta, \
CUDAContext* ctx, T* y, \
const string math_type) { CUDAContext* ctx) { \
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N; auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N; \
CUBLAS_CHECK( const auto alpha_val = static_cast<T>(alpha); \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); const auto beta_val = static_cast<T>(beta); \
CUBLAS_CHECK(cublasSgemv( CUBLAS_CHECK( \
ctx->cublas_handle(), cuTransA, N, M, &alpha, A, N, x, 1, &beta, y, 1)); cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
} CUBLAS_CHECK(cublas_func( \
ctx->cublas_handle(), \
cuTransA, \
N, \
M, \
&alpha_val, \
A, \
N, \
x, \
1, \
&beta_val, \
y, \
1)); \
}
template <> DEFINE_GEMV_FUNC(float, cublasSgemv);
DRAGON_API void Gemv<double, CUDAContext>( DEFINE_GEMV_FUNC(double, cublasDgemv);
const CBLAS_TRANSPOSE TransA, #undef DEFINE_GEMV_FUNC
const int M,
const int N,
const float alpha,
const double* A,
const double* x,
const float beta,
double* y,
CUDAContext* ctx,
const string math_type) {
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto alpha_val = static_cast<double>(alpha);
const auto beta_val = static_cast<double>(beta);
CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasDgemv(
ctx->cublas_handle(),
cuTransA,
N,
M,
&alpha_val,
A,
N,
x,
1,
&beta_val,
y,
1));
}
template <> template <>
DRAGON_API void Gemm<float16, CUDAContext>( DRAGON_API void Gemm<float16, CUDAContext>(
...@@ -659,61 +556,15 @@ DRAGON_API void Gemm<float16, CUDAContext>( ...@@ -659,61 +556,15 @@ DRAGON_API void Gemm<float16, CUDAContext>(
const float16* B, const float16* B,
const float beta, const float beta,
float16* C, float16* C,
CUDAContext* ctx, CUDAContext* ctx) {
const std::string math_type) {
int lda = (TransA == CblasNoTrans) ? K : M; int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K; int ldb = (TransB == CblasNoTrans) ? N : K;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBLAS_CHECK( CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
if (math_type == "float32") { if (TENSOR_CORE_AVAILABLE()) {
#if CUDA_VERSION >= 9000 CUBLAS_CHECK(cublasGemmEx(
if (TENSOR_CORE_AVAILABLE()) {
// GEMM + MATH32 + TENSOR-CORE
CUBLAS_CHECK(cublasGemmEx(
ctx->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
CUDA_R_16F,
ldb,
A,
CUDA_R_16F,
lda,
&beta,
C,
CUDA_R_16F,
N,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// GEMM + MATH32 + DEFAULT
CUBLAS_CHECK(cublasSgemmEx(
ctx->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
CUDA_R_16F,
ldb,
A,
CUDA_R_16F,
lda,
&beta,
C,
CUDA_R_16F,
N));
}
#else
CUBLAS_CHECK(cublasSgemmEx(
ctx->cublas_handle(), ctx->cublas_handle(),
cuTransB, cuTransB,
cuTransA, cuTransA,
...@@ -730,95 +581,99 @@ DRAGON_API void Gemm<float16, CUDAContext>( ...@@ -730,95 +581,99 @@ DRAGON_API void Gemm<float16, CUDAContext>(
&beta, &beta,
C, C,
CUDA_R_16F, CUDA_R_16F,
N)); N,
#endif CUDA_R_32F,
} else if (math_type == "float16") { CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const half alpha_val = convert::To<half>(alpha); } else {
const half beta_val = convert::To<half>(beta); CUBLAS_CHECK(cublasSgemmEx(
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) {
// GEMM + MATH16 + TENSOR-CORE
CUBLAS_CHECK(cublasGemmEx(
ctx->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha_val,
B,
CUDA_R_16F,
ldb,
A,
CUDA_R_16F,
lda,
&beta_val,
C,
CUDA_R_16F,
N,
CUDA_R_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// GEMM + MATH16 + DEFAULT
CUBLAS_CHECK(cublasHgemm(
ctx->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha_val,
reinterpret_cast<const half*>(B),
ldb,
reinterpret_cast<const half*>(A),
lda,
&beta_val,
reinterpret_cast<half*>(C),
N));
}
#else
CUBLAS_CHECK(cublasHgemm(
ctx->cublas_handle(), ctx->cublas_handle(),
cuTransB, cuTransB,
cuTransA, cuTransA,
N, N,
M, M,
K, K,
&alpha_val, &alpha,
reinterpret_cast<const half*>(B), B,
CUDA_R_16F,
ldb, ldb,
reinterpret_cast<const half*>(A), A,
CUDA_R_16F,
lda, lda,
&beta_val, &beta,
reinterpret_cast<half*>(C), C,
CUDA_R_16F,
N)); N));
#endif
} else {
LOG(FATAL) << "Unknown math type: " << math_type;
} }
} }
#define DEFINE_GEMM_FUNC(T, cublas_func) \
template <> \
DRAGON_API void Gemm<T, CUDAContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int M, \
const int N, \
const int K, \
const float alpha, \
const T* A, \
const T* B, \
const float beta, \
T* C, \
CUDAContext* ctx) { \
int lda = TransA == CblasNoTrans ? K : M; \
int ldb = TransB == CblasNoTrans ? N : K; \
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
const auto alpha_val = static_cast<T>(alpha); \
const auto beta_val = static_cast<T>(beta); \
CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
CUBLAS_CHECK(cublas_func( \
ctx->cublas_handle(), \
cuTransB, \
cuTransA, \
N, \
M, \
K, \
&alpha_val, \
B, \
ldb, \
A, \
lda, \
&beta_val, \
C, \
N)); \
}
DEFINE_GEMM_FUNC(float, cublasSgemm);
DEFINE_GEMM_FUNC(double, cublasDgemm);
#undef DEFINE_GEMM_FUNC
template <> template <>
DRAGON_API void Gemm<float, CUDAContext>( DRAGON_API void GemmBatched<float16, CUDAContext>(
const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M, const int M,
const int N, const int N,
const int K, const int K,
const float alpha, const float alpha,
const float* A, const float16** A,
const float* B, const float16** B,
const float beta, const float beta,
float* C, float16** C,
CUDAContext* ctx, CUDAContext* ctx) {
const string math_type) {
int lda = TransA == CblasNoTrans ? K : M; int lda = TransA == CblasNoTrans ? K : M;
int ldb = TransB == CblasNoTrans ? N : K; int ldb = TransB == CblasNoTrans ? N : K;
int ldc = N;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
thrust::device_vector<const void*> A_arr(A, A + batch_size);
thrust::device_vector<const void*> B_arr(B, B + batch_size);
thrust::device_vector<void*> C_arr(C, C + batch_size);
CUBLAS_CHECK( CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasSgemm( CUBLAS_CHECK(cublasGemmBatchedEx(
ctx->cublas_handle(), ctx->cublas_handle(),
cuTransB, cuTransB,
cuTransA, cuTransA,
...@@ -826,54 +681,172 @@ DRAGON_API void Gemm<float, CUDAContext>( ...@@ -826,54 +681,172 @@ DRAGON_API void Gemm<float, CUDAContext>(
M, M,
K, K,
&alpha, &alpha,
B, B_arr.data().get(),
CUDA_R_16F,
ldb, ldb,
A, A_arr.data().get(),
CUDA_R_16F,
lda, lda,
&beta, &beta,
C, C_arr.data().get(),
N)); CUDA_R_16F,
ldc,
batch_size,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} }
#define DEFINE_BATCHED_GEMM_FUNC(T, cublas_func) \
template <> \
DRAGON_API void GemmBatched<T, CUDAContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int batch_size, \
const int M, \
const int N, \
const int K, \
const float alpha, \
const T** A, \
const T** B, \
const float beta, \
T** C, \
CUDAContext* ctx) { \
int lda = TransA == CblasNoTrans ? K : M; \
int ldb = TransB == CblasNoTrans ? N : K; \
int ldc = N; \
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
const auto alpha_val = static_cast<T>(alpha); \
const auto beta_val = static_cast<T>(beta); \
thrust::device_vector<const T*> A_arr(A, A + batch_size); \
thrust::device_vector<const T*> B_arr(B, B + batch_size); \
thrust::device_vector<T*> C_arr(C, C + batch_size); \
CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
CUBLAS_CHECK(cublas_func( \
ctx->cublas_handle(), \
cuTransB, \
cuTransA, \
N, \
M, \
K, \
&alpha_val, \
B_arr.data().get(), \
ldb, \
A_arr.data().get(), \
lda, \
&beta_val, \
C_arr.data().get(), \
ldc, \
batch_size)); \
}
DEFINE_BATCHED_GEMM_FUNC(float, cublasSgemmBatched);
DEFINE_BATCHED_GEMM_FUNC(double, cublasDgemmBatched);
#undef DEFINE_BATCHED_GEMM_FUNC
template <> template <>
DRAGON_API void Gemm<double, CUDAContext>( DRAGON_API void GemmStridedBatched<float16, CUDAContext>(
const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M, const int M,
const int N, const int N,
const int K, const int K,
const int A_stride,
const int B_stride,
const int C_stride,
const float alpha, const float alpha,
const double* A, const float16* A,
const double* B, const float16* B,
const float beta, const float beta,
double* C, float16* C,
CUDAContext* ctx, CUDAContext* ctx) {
const string math_type) { int lda = TransA == CblasNoTrans ? K : M;
int lda = (TransA == CblasNoTrans) ? K : M; int ldb = TransB == CblasNoTrans ? N : K;
int ldb = (TransB == CblasNoTrans) ? N : K; int ldc = N;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
const auto alpha_val = static_cast<double>(alpha);
const auto beta_val = static_cast<double>(beta);
CUBLAS_CHECK( CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasDgemm( CUBLAS_CHECK(cublasGemmStridedBatchedEx(
ctx->cublas_handle(), ctx->cublas_handle(),
cuTransB, cuTransB,
cuTransA, cuTransA,
N, N,
M, M,
K, K,
&alpha_val, &alpha,
B, B,
CUDA_R_16F,
ldb, ldb,
B_stride,
A, A,
CUDA_R_16F,
lda, lda,
&beta_val, A_stride,
&beta,
C, C,
N)); CUDA_R_16F,
ldc,
C_stride,
batch_size,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} }
#define DEFINE_STRIDED_BATCHED_GEMM_FUNC(T, cublas_func) \
template <> \
DRAGON_API void GemmStridedBatched<T, CUDAContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int batch_size, \
const int M, \
const int N, \
const int K, \
const int A_stride, \
const int B_stride, \
const int C_stride, \
const float alpha, \
const T* A, \
const T* B, \
const float beta, \
T* C, \
CUDAContext* ctx) { \
int lda = TransA == CblasNoTrans ? K : M; \
int ldb = TransB == CblasNoTrans ? N : K; \
int ldc = N; \
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T; \
const auto alpha_val = static_cast<T>(alpha); \
const auto beta_val = static_cast<T>(beta); \
CUBLAS_CHECK( \
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \
CUBLAS_CHECK(cublas_func( \
ctx->cublas_handle(), \
cuTransB, \
cuTransA, \
N, \
M, \
K, \
&alpha_val, \
B, \
ldb, \
B_stride, \
A, \
lda, \
A_stride, \
&beta_val, \
C, \
ldc, \
C_stride, \
batch_size)); \
}
DEFINE_STRIDED_BATCHED_GEMM_FUNC(float, cublasSgemmStridedBatched);
DEFINE_STRIDED_BATCHED_GEMM_FUNC(double, cublasDgemmStridedBatched);
#undef DEFINE_STRIDED_BATCHED_GEMM_FUNC
} // namespace math } // namespace math
} // namespace dragon } // namespace dragon
......
...@@ -85,8 +85,7 @@ DRAGON_API void Gemv( ...@@ -85,8 +85,7 @@ DRAGON_API void Gemv(
const T* x, const T* x,
const float beta, const float beta,
T* y, T* y,
Context* ctx, Context* ctx);
const string math_type = "float32");
template <typename T, class Context> template <typename T, class Context>
DRAGON_API void Gemm( DRAGON_API void Gemm(
...@@ -100,8 +99,40 @@ DRAGON_API void Gemm( ...@@ -100,8 +99,40 @@ DRAGON_API void Gemm(
const T* B, const T* B,
const float beta, const float beta,
T* C, T* C,
Context* ctx, Context* ctx);
const string math_type = "float32");
template <typename T, class Context>
DRAGON_API void GemmBatched(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const float alpha,
const T** A,
const T** B,
const float beta,
T** C,
Context* ctx);
template <typename T, class Context>
DRAGON_API void GemmStridedBatched(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const int A_stride,
const int B_stride,
const int C_stride,
const float alpha,
const T* A,
const T* B,
const float beta,
T* C,
Context* ctx);
} // namespace math } // namespace math
......
...@@ -158,15 +158,15 @@ __global__ void _BroadcastWhere( ...@@ -158,15 +158,15 @@ __global__ void _BroadcastWhere(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_SET_FUNC(T1, T2) \ #define DEFINE_SET_FUNC(T, ScalarT) \
template <> \ template <> \
DRAGON_API void Set<T1, CUDAContext>( \ DRAGON_API void Set<T, CUDAContext>( \
const int x_ndim, \ const int x_ndim, \
const int64_t* x_dims, \ const int64_t* x_dims, \
const int y_ndim, \ const int y_ndim, \
const int64_t* y_dims, \ const int64_t* y_dims, \
const T1* x, \ const T* x, \
T1* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
int rows, cols; \ int rows, cols; \
vec64_t X_dims(x_dims, x_dims + x_ndim); \ vec64_t X_dims(x_dims, x_dims + x_ndim); \
...@@ -189,8 +189,8 @@ __global__ void _BroadcastWhere( ...@@ -189,8 +189,8 @@ __global__ void _BroadcastWhere(
ctx->cuda_stream()>>>( \ ctx->cuda_stream()>>>( \
nthreads, \ nthreads, \
cols, \ cols, \
reinterpret_cast<const T2*>(x), \ reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<T2*>(y)); \ reinterpret_cast<ScalarT*>(y)); \
return; \ return; \
} \ } \
if (math::utils::IsColwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \ if (math::utils::IsColwiseBroadcast(X_dims, Y_dims, &rows, &cols)) { \
...@@ -202,8 +202,8 @@ __global__ void _BroadcastWhere( ...@@ -202,8 +202,8 @@ __global__ void _BroadcastWhere(
ctx->cuda_stream()>>>( \ ctx->cuda_stream()>>>( \
nthreads, \ nthreads, \
cols, \ cols, \
reinterpret_cast<const T2*>(x), \ reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<T2*>(y)); \ reinterpret_cast<ScalarT*>(y)); \
return; \ return; \
} \ } \
vec64_t X_broadcast_strides, _; \ vec64_t X_broadcast_strides, _; \
...@@ -226,8 +226,8 @@ __global__ void _BroadcastWhere( ...@@ -226,8 +226,8 @@ __global__ void _BroadcastWhere(
Y_dims.size(), \ Y_dims.size(), \
strides, \ strides, \
dims, \ dims, \
reinterpret_cast<const T2*>(x), \ reinterpret_cast<const ScalarT*>(x), \
reinterpret_cast<T2*>(y)); \ reinterpret_cast<ScalarT*>(y)); \
} }
DEFINE_SET_FUNC(bool, uint8_t); DEFINE_SET_FUNC(bool, uint8_t);
...@@ -235,8 +235,8 @@ DEFINE_SET_FUNC(int8_t, int8_t); ...@@ -235,8 +235,8 @@ DEFINE_SET_FUNC(int8_t, int8_t);
DEFINE_SET_FUNC(uint8_t, uint8_t); DEFINE_SET_FUNC(uint8_t, uint8_t);
DEFINE_SET_FUNC(int, int); DEFINE_SET_FUNC(int, int);
DEFINE_SET_FUNC(int64_t, int64_t); DEFINE_SET_FUNC(int64_t, int64_t);
DEFINE_SET_FUNC(float, float);
DEFINE_SET_FUNC(float16, half); DEFINE_SET_FUNC(float16, half);
DEFINE_SET_FUNC(float, float);
DEFINE_SET_FUNC(double, double); DEFINE_SET_FUNC(double, double);
#undef DEFINE_SET_FUNC #undef DEFINE_SET_FUNC
...@@ -267,13 +267,31 @@ DEFINE_SET_FUNC(double, double); ...@@ -267,13 +267,31 @@ DEFINE_SET_FUNC(double, double);
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \ A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
const auto nthreads = rows * cols; \ const auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \ if (broadcast_1st > 0) { \
_RowwiseBinaryFunc<InputT, OutputT, Functor<InputT>, true> \ _RowwiseBinaryFunc< \
math::ScalarType<InputT>::type, \
math::ScalarType<OutputT>::type, \
Functor<math::ScalarType<InputT>::type>, \
true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ <<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Functor<InputT>(), a, b, y); \ nthreads, \
cols, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} else { \ } else { \
_RowwiseBinaryFunc<InputT, OutputT, Functor<InputT>, false> \ _RowwiseBinaryFunc< \
math::ScalarType<InputT>::type, \
math::ScalarType<OutputT>::type, \
Functor<math::ScalarType<InputT>::type>, \
false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ <<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Functor<InputT>(), a, b, y); \ nthreads, \
cols, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} \ } \
return; \ return; \
} \ } \
...@@ -281,13 +299,31 @@ DEFINE_SET_FUNC(double, double); ...@@ -281,13 +299,31 @@ DEFINE_SET_FUNC(double, double);
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \ A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
const auto nthreads = rows * cols; \ const auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \ if (broadcast_1st > 0) { \
_ColwiseBinaryFunc<InputT, OutputT, Functor<InputT>, true> \ _ColwiseBinaryFunc< \
math::ScalarType<InputT>::type, \
math::ScalarType<OutputT>::type, \
Functor<math::ScalarType<InputT>::type>, \
true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ <<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Functor<InputT>(), a, b, y); \ nthreads, \
cols, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} else { \ } else { \
_ColwiseBinaryFunc<InputT, OutputT, Functor<InputT>, false> \ _ColwiseBinaryFunc< \
math::ScalarType<InputT>::type, \
math::ScalarType<OutputT>::type, \
Functor<math::ScalarType<InputT>::type>, \
false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ <<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Functor<InputT>(), a, b, y); \ nthreads, \
cols, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} \ } \
return; \ return; \
} \ } \
...@@ -304,9 +340,9 @@ DEFINE_SET_FUNC(double, double); ...@@ -304,9 +340,9 @@ DEFINE_SET_FUNC(double, double);
y_dims.data[i] = Y_dims[i]; \ y_dims.data[i] = Y_dims[i]; \
} \ } \
_BroadcastBinaryFunc< \ _BroadcastBinaryFunc< \
InputT, \ math::ScalarType<InputT>::type, \
OutputT, \ math::ScalarType<OutputT>::type, \
Functor<InputT>, \ Functor<math::ScalarType<InputT>::type>, \
CUDA_TENSOR_MAX_DIMS> \ CUDA_TENSOR_MAX_DIMS> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ <<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \ nthreads, \
...@@ -314,108 +350,121 @@ DEFINE_SET_FUNC(double, double); ...@@ -314,108 +350,121 @@ DEFINE_SET_FUNC(double, double);
a_strides, \ a_strides, \
b_strides, \ b_strides, \
y_dims, \ y_dims, \
Functor<InputT>(), \ Functor<math::ScalarType<InputT>::type>(), \
a, \ reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
b, \ reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
y); \ reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} }
DEFINE_BINARY_FUNC(Add, int8_t, int8_t, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, int8_t, int8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int, int, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, int, int, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int64_t, int64_t, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, int64_t, int64_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float16, float16, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float, float, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, float, float, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, double, double, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, double, double, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int, int, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, int, int, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, float16, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float, float, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, float, float, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, double, double, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, double, double, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int, int, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, int, int, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float16, float16, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float, float, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, float, float, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, double, double, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, double, double, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, int8_t, int8_t, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, int8_t, int8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int, int, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, int, int, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int64_t, int64_t, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, int64_t, int64_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float16, float16, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float, float, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, float, float, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, float16, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor); DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor); DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, float16, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float, float, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, float, float, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, double, double, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, double, double, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int, int, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, int, int, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, float16, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float, float, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, float, float, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, double, double, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, double, double, math::MaxFunctor);
DEFINE_BINARY_FUNC(Equal, int8_t, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, int8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, uint8_t, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, uint8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, int, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int64_t, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, int64_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, float16, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, float, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, float, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, double, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, double, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int8_t, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, int8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, uint8_t, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, uint8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, int, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int64_t, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, int64_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, float16, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, float, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, float, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, double, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, double, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, int8_t, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, int8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, uint8_t, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, uint8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, int, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int64_t, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, int64_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, float16, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, float, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, float, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, double, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, double, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, int8_t, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, int8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, uint8_t, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, uint8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, int, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int64_t, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, int64_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, float16, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, float, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, float, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, double, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, double, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, int8_t, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, int8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, uint8_t, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, uint8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, int, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int64_t, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, int64_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, float16, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, float, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, float, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, double, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, double, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int8_t, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, int8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, uint8_t, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, uint8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, int, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int64_t, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, int64_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, float16, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, float, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, float, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, double, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, double, bool, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC #undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, T, dtype) \ #define DEFINE_BINARY_FUNC(name, T, ScalarT) \
template <> \ template <> \
DRAGON_API void name<T, CUDAContext>( \ DRAGON_API void name<T, CUDAContext>( \
const int a_ndim, \ const int a_ndim, \
const int64_t* a_dims, \ const int64_t* a_dims, \
const int b_ndim, \ const int b_ndim, \
const int64_t* b_dims, \ const int64_t* b_dims, \
const T* a, \ const T* a, \
const T* b, \ const T* b, \
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
name( \ name( \
a_ndim, \ a_ndim, \
a_dims, \ a_dims, \
b_ndim, \ b_ndim, \
b_dims, \ b_dims, \
reinterpret_cast<const dtype*>(a), \ reinterpret_cast<const ScalarT*>(a), \
reinterpret_cast<const dtype*>(b), \ reinterpret_cast<const ScalarT*>(b), \
reinterpret_cast<dtype*>(y), \ reinterpret_cast<ScalarT*>(y), \
ctx); \ ctx); \
} }
DEFINE_BINARY_FUNC(Add, bool, uint8_t); // Or DEFINE_BINARY_FUNC(Add, bool, uint8_t); // Or
...@@ -423,130 +472,19 @@ DEFINE_BINARY_FUNC(Sub, bool, uint8_t); // Xor ...@@ -423,130 +472,19 @@ DEFINE_BINARY_FUNC(Sub, bool, uint8_t); // Xor
DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
#undef DEFINE_BINARY_FUNC #undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, OutputT1, OutputT2, Functor) \ #define DEFINE_WHERE_FUNC(T, ScalarT) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int a_ndim, \
const int64_t* a_dims, \
const int b_ndim, \
const int64_t* b_dims, \
const float16* a, \
const float16* b, \
OutputT1* y, \
CUDAContext* ctx) { \
int rows, cols, broadcast_1st; \
vec64_t A_dims(a_dims, a_dims + a_ndim); \
vec64_t B_dims(b_dims, b_dims + b_ndim); \
vec64_t A_broadcast_dims, B_broadcast_dims; \
math::utils::ComputeBinaryBroadcastDims( \
A_dims, B_dims, A_broadcast_dims, B_broadcast_dims); \
if (A_broadcast_dims == B_broadcast_dims) { \
auto count = std::accumulate( \
a_dims, a_dims + a_ndim, 1, std::multiplies<int64_t>()); \
name(count, a, b, y, ctx); \
return; \
} \
if (math::utils::IsRowwiseBroadcast( \
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
_RowwiseBinaryFunc<half, OutputT2, Functor<half>, true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<OutputT2*>(y)); \
} else { \
_RowwiseBinaryFunc<half, OutputT2, Functor<half>, false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<OutputT2*>(y)); \
} \
return; \
} \
if (math::utils::IsColwiseBroadcast( \
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
_ColwiseBinaryFunc<half, OutputT2, Functor<half>, true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<OutputT2*>(y)); \
} else { \
_ColwiseBinaryFunc<half, OutputT2, Functor<half>, false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<OutputT2*>(y)); \
} \
return; \
} \
vec64_t A_broadcast_strides, B_broadcast_strides, Y_dims; \
math::utils::ComputeBinaryBroadcastStrides( \
A_dims, B_dims, A_broadcast_strides, B_broadcast_strides, Y_dims); \
CUDA_TENSOR_DIMS_CHECK((int)Y_dims.size()); \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> a_strides, b_strides, y_dims; \
const auto nthreads = std::accumulate( \
Y_dims.begin(), Y_dims.end(), 1, std::multiplies<int64_t>()); \
for (int i = 0; i < Y_dims.size(); ++i) { \
a_strides.data[i] = A_broadcast_strides[i]; \
b_strides.data[i] = B_broadcast_strides[i]; \
y_dims.data[i] = Y_dims[i]; \
} \
_BroadcastBinaryFunc<half, OutputT2, Functor<half>, CUDA_TENSOR_MAX_DIMS> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
Y_dims.size(), \
a_strides, \
b_strides, \
y_dims, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<OutputT2*>(y)); \
}
DEFINE_BINARY_FUNC(Add, float16, half, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, half, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, float16, half, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, float16, half, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, half, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, half, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, half, math::MaxFunctor);
DEFINE_BINARY_FUNC(Equal, bool, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, bool, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, bool, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, bool, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, bool, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, bool, bool, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_WHERE_FUNC(T1, T2) \
template <> \ template <> \
DRAGON_API void Where<T1, CUDAContext>( \ DRAGON_API void Where<T, CUDAContext>( \
const int a_ndim, \ const int a_ndim, \
const int64_t* a_dims, \ const int64_t* a_dims, \
const int b_ndim, \ const int b_ndim, \
const int64_t* b_dims, \ const int64_t* b_dims, \
const int c_ndim, \ const int c_ndim, \
const int64_t* c_dims, \ const int64_t* c_dims, \
const T1* a, \ const T* a, \
const T1* b, \ const T* b, \
const bool* c, \ const bool* c, \
T1* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
vec64_t A_dims(a_dims, a_dims + a_ndim); \ vec64_t A_dims(a_dims, a_dims + a_ndim); \
vec64_t B_dims(b_dims, b_dims + b_ndim); \ vec64_t B_dims(b_dims, b_dims + b_ndim); \
...@@ -597,10 +535,10 @@ DEFINE_BINARY_FUNC(GreaterEqual, bool, bool, math::GreaterEqualFunctor); ...@@ -597,10 +535,10 @@ DEFINE_BINARY_FUNC(GreaterEqual, bool, bool, math::GreaterEqualFunctor);
b_strides, \ b_strides, \
c_strides, \ c_strides, \
y_dims, \ y_dims, \
reinterpret_cast<const T2*>(a), \ reinterpret_cast<const ScalarT*>(a), \
reinterpret_cast<const T2*>(b), \ reinterpret_cast<const ScalarT*>(b), \
reinterpret_cast<const uint8_t*>(c), \ reinterpret_cast<const uint8_t*>(c), \
reinterpret_cast<T2*>(y)); \ reinterpret_cast<ScalarT*>(y)); \
} }
DEFINE_WHERE_FUNC(bool, uint8_t); DEFINE_WHERE_FUNC(bool, uint8_t);
......
...@@ -24,21 +24,21 @@ void _Cast(const int n, const InputT* x, OutputT* y) { ...@@ -24,21 +24,21 @@ void _Cast(const int n, const InputT* x, OutputT* y) {
#define DEFINE_CAST_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_CAST_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void Cast<InputT, OutputT, CPUContext>( \ DRAGON_API void Cast<InputT, OutputT, CPUContext>( \
const int n, const InputT* x, OutputT* y, CPUContext* ctx) { \ const int n, const InputT* x, OutputT* y, CPUContext* ctx) { \
_Cast(n, x, y); \ _Cast(n, x, y); \
} }
#define DEFINE_UNSUPPORTED_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_UNSUPPORTED_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void Cast<InputT, OutputT, CPUContext>( \ DRAGON_API void Cast<InputT, OutputT, CPUContext>( \
const int n, const InputT* x, OutputT* y, CPUContext* ctx) { \ const int n, const InputT* x, OutputT* y, CPUContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \ LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<InputT>()) << " -> " \ << types::to_string(TypeMeta::Make<InputT>()) << " -> " \
<< types::to_string(TypeMeta::Make<OutputT>()); \ << types::to_string(TypeMeta::Make<OutputT>()); \
} \ } \
template <> \ template <> \
void Cast<OutputT, InputT, CPUContext>( \ DRAGON_API void Cast<OutputT, InputT, CPUContext>( \
const int n, const OutputT* x, InputT* y, CPUContext* ctx) { \ const int n, const OutputT* x, InputT* y, CPUContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \ LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<OutputT>()) << " -> " \ << types::to_string(TypeMeta::Make<OutputT>()) << " -> " \
......
...@@ -23,7 +23,7 @@ __global__ void _Cast(const int nthreads, const InputT* x, OutputT* y) { ...@@ -23,7 +23,7 @@ __global__ void _Cast(const int nthreads, const InputT* x, OutputT* y) {
#define DEFINE_CAST_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_CAST_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void Cast<InputT, OutputT, CUDAContext>( \ DRAGON_API void Cast<InputT, OutputT, CUDAContext>( \
const int n, const InputT* x, OutputT* y, CUDAContext* ctx) { \ const int n, const InputT* x, OutputT* y, CUDAContext* ctx) { \
_Cast<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _Cast<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, \ n, \
...@@ -33,14 +33,14 @@ __global__ void _Cast(const int nthreads, const InputT* x, OutputT* y) { ...@@ -33,14 +33,14 @@ __global__ void _Cast(const int nthreads, const InputT* x, OutputT* y) {
#define DEFINE_UNSUPPORTED_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_UNSUPPORTED_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void Cast<InputT, OutputT, CUDAContext>( \ DRAGON_API void Cast<InputT, OutputT, CUDAContext>( \
const int n, const InputT* x, OutputT* y, CUDAContext* ctx) { \ const int n, const InputT* x, OutputT* y, CUDAContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \ LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<InputT>()) << " -> " \ << types::to_string(TypeMeta::Make<InputT>()) << " -> " \
<< types::to_string(TypeMeta::Make<OutputT>()); \ << types::to_string(TypeMeta::Make<OutputT>()); \
} \ } \
template <> \ template <> \
void Cast<OutputT, InputT, CUDAContext>( \ DRAGON_API void Cast<OutputT, InputT, CUDAContext>( \
const int n, const OutputT* x, InputT* y, CUDAContext* ctx) { \ const int n, const OutputT* x, InputT* y, CUDAContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \ LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<OutputT>()) << " -> " \ << types::to_string(TypeMeta::Make<OutputT>()) << " -> " \
......
...@@ -599,14 +599,14 @@ DEFINE_POWX_FUNC(double); ...@@ -599,14 +599,14 @@ DEFINE_POWX_FUNC(double);
#define DEFINE_NOT_ZERO_FUNC(T) \ #define DEFINE_NOT_ZERO_FUNC(T) \
template <> \ template <> \
void NotZero<T, CUDAContext>( \ DRAGON_API void NotZero<T, CUDAContext>( \
const int count, const T* x, bool* y, CUDAContext* ctx) { \ const int count, const T* x, bool* y, CUDAContext* ctx) { \
_NotZero<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _NotZero<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, x, y); \ count, x, y); \
} }
template <> template <>
void NotZero<float16, CUDAContext>( DRAGON_API void NotZero<float16, CUDAContext>(
const int count, const int count,
const float16* x, const float16* x,
bool* y, bool* y,
...@@ -742,106 +742,124 @@ DEFINE_BIAS_FUNC(float); ...@@ -742,106 +742,124 @@ DEFINE_BIAS_FUNC(float);
DEFINE_BIAS_FUNC(double); DEFINE_BIAS_FUNC(double);
#undef DEFINE_BIAS_FUNC #undef DEFINE_BIAS_FUNC
#define DEFINE_BINARY_FUNC(name, InputT, OutputT, Op) \ #define DEFINE_BINARY_FUNC(name, InputT, OutputT, Functor) \
template <> \ template <> \
DRAGON_API void name<InputT, CUDAContext>( \ DRAGON_API void name<InputT, CUDAContext>( \
const int n, \ const int n, \
const InputT* a, \ const InputT* a, \
const InputT* b, \ const InputT* b, \
OutputT* y, \ OutputT* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
_SimpleBinaryFunc<<< \ _SimpleBinaryFunc<<< \
CUDA_BLOCKS(n), \ CUDA_BLOCKS(n), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>(n, Op<InputT>(), a, b, y); \ ctx->cuda_stream()>>>( \
n, \
Functor<math::ScalarType<InputT>::type>(), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(a), \
reinterpret_cast<const math::ScalarType<InputT>::type*>(b), \
reinterpret_cast<math::ScalarType<OutputT>::type*>(y)); \
} }
DEFINE_BINARY_FUNC(Add, int8_t, int8_t, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, int8_t, int8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int, int, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, int, int, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int64_t, int64_t, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, int64_t, int64_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float16, float16, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float, float, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, float, float, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, double, double, math::PlusFunctor); DEFINE_BINARY_FUNC(Add, double, double, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int, int, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, int, int, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, float16, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float, float, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, float, float, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, double, double, math::MinusFunctor); DEFINE_BINARY_FUNC(Sub, double, double, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int, int, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, int, int, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float16, float16, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float, float, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, float, float, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, double, double, math::MultipliesFunctor); DEFINE_BINARY_FUNC(Mul, double, double, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, int8_t, int8_t, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, int8_t, int8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int, int, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, int, int, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int64_t, int64_t, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, int64_t, int64_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float16, float16, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float, float, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, float, float, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor); DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, float16, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor); DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor); DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, float16, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float, float, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, float, float, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, double, double, math::MinFunctor); DEFINE_BINARY_FUNC(Minimum, double, double, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int, int, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, int, int, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, float16, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float, float, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, float, float, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, double, double, math::MaxFunctor); DEFINE_BINARY_FUNC(Maximum, double, double, math::MaxFunctor);
DEFINE_BINARY_FUNC(Equal, int8_t, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, int8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, uint8_t, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, uint8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, int, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int64_t, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, int64_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, float16, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, float, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, float, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, double, bool, math::EqualFunctor); DEFINE_BINARY_FUNC(Equal, double, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int8_t, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, int8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, uint8_t, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, uint8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, int, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int64_t, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, int64_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, float16, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, float, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, float, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, double, bool, math::NotEqualFunctor); DEFINE_BINARY_FUNC(NotEqual, double, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, int8_t, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, int8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, uint8_t, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, uint8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, int, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int64_t, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, int64_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, float16, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, float, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, float, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, double, bool, math::LessFunctor); DEFINE_BINARY_FUNC(Less, double, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, int8_t, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, int8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, uint8_t, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, uint8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, int, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int64_t, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, int64_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, float16, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, float, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, float, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, double, bool, math::LessEqualFunctor); DEFINE_BINARY_FUNC(LessEqual, double, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, int8_t, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, int8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, uint8_t, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, uint8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, int, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int64_t, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, int64_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, float16, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, float, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, float, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, double, bool, math::GreaterFunctor); DEFINE_BINARY_FUNC(Greater, double, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int8_t, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, int8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, uint8_t, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, uint8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, int, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int64_t, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, int64_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, float16, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, float, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, float, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, double, bool, math::GreaterEqualFunctor); DEFINE_BINARY_FUNC(GreaterEqual, double, bool, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC #undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, T, dtype) \ #define DEFINE_BINARY_FUNC(name, T, ScalarT) \
template <> \ template <> \
DRAGON_API void name<T, CUDAContext>( \ DRAGON_API void name<T, CUDAContext>( \
const int n, const T* a, const T* b, T* y, CUDAContext* ctx) { \ const int n, const T* a, const T* b, T* y, CUDAContext* ctx) { \
name( \ name( \
n, \ n, \
reinterpret_cast<const dtype*>(a), \ reinterpret_cast<const ScalarT*>(a), \
reinterpret_cast<const dtype*>(b), \ reinterpret_cast<const ScalarT*>(b), \
reinterpret_cast<dtype*>(y), \ reinterpret_cast<ScalarT*>(y), \
ctx); \ ctx); \
} }
...@@ -850,76 +868,6 @@ DEFINE_BINARY_FUNC(Sub, bool, uint8_t); // Xor ...@@ -850,76 +868,6 @@ DEFINE_BINARY_FUNC(Sub, bool, uint8_t); // Xor
DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
#undef DEFINE_BINARY_FUNC #undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, Functor) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int n, \
const float16* a, \
const float16* b, \
float16* y, \
CUDAContext* ctx) { \
if ((n & 1) == 0) { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n >> 1), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n >> 1, \
Functor<half2>(), \
reinterpret_cast<const half2*>(a), \
reinterpret_cast<const half2*>(b), \
reinterpret_cast<half2*>(y)); \
} else { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<half*>(y)); \
} \
}
DEFINE_BINARY_FUNC(Add, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, math::MaxFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, Functor) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int n, \
const float16* a, \
const float16* b, \
bool* y, \
CUDAContext* ctx) { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
y); \
}
DEFINE_BINARY_FUNC(Equal, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_WHERE_FUNC(T) \ #define DEFINE_WHERE_FUNC(T) \
template <> \ template <> \
DRAGON_API void Where<T, CUDAContext>( \ DRAGON_API void Where<T, CUDAContext>( \
......
...@@ -217,18 +217,18 @@ DEFINE_REDUCE_FUNC(Sum); ...@@ -217,18 +217,18 @@ DEFINE_REDUCE_FUNC(Sum);
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name) \ #define DEFINE_KERNEL_LAUNCHER(name) \
template <> \ template <> \
void Reduce##name<float16, CPUContext>( \ DRAGON_API void Reduce##name<float16, CPUContext>( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
const int* axes, \ const int* axes, \
const float scale, \ const float scale, \
const float16* x, \ const float16* x, \
float16* y, \ float16* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
CPU_FP16_NOT_SUPPORTED; \ CPU_FP16_NOT_SUPPORTED; \
} }
DEFINE_KERNEL_LAUNCHER(Max); DEFINE_KERNEL_LAUNCHER(Max);
...@@ -258,7 +258,7 @@ DRAGON_API float16 Sum<float16, CPUContext>( ...@@ -258,7 +258,7 @@ DRAGON_API float16 Sum<float16, CPUContext>(
#define DEFINE_KERNEL_LAUNCHER(name, T) \ #define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \ template <> \
void Reduce##name<T, CPUContext>( \ DRAGON_API void Reduce##name<T, CPUContext>( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
...@@ -298,7 +298,7 @@ DEFINE_KERNEL_LAUNCHER(Sum, double); ...@@ -298,7 +298,7 @@ DEFINE_KERNEL_LAUNCHER(Sum, double);
*y = val * T(scale); \ *y = val * T(scale); \
} \ } \
template <> \ template <> \
T Sum<T, CPUContext>( \ DRAGON_API T Sum<T, CPUContext>( \
const int n, const float scale, const T* x, CPUContext* ctx) { \ const int n, const float scale, const T* x, CPUContext* ctx) { \
T val = ConstEigenVectorArrayMap<T>(x, n).sum(); \ T val = ConstEigenVectorArrayMap<T>(x, n).sum(); \
return val * T(scale); \ return val * T(scale); \
......
...@@ -174,7 +174,7 @@ DEFINE_REDUCE_DISPATCHER(Sum); ...@@ -174,7 +174,7 @@ DEFINE_REDUCE_DISPATCHER(Sum);
// We found that FP16 accumulator drops too many small values in // We found that FP16 accumulator drops too many small values in
// empirical experiments. // empirical experiments.
template <> template <>
void ReduceSum<float16, CUDAContext>( DRAGON_API void ReduceSum<float16, CUDAContext>(
const int num_dims, const int num_dims,
const int* dims, const int* dims,
const int num_axes, const int num_axes,
...@@ -199,7 +199,7 @@ void ReduceSum<float16, CUDAContext>( ...@@ -199,7 +199,7 @@ void ReduceSum<float16, CUDAContext>(
#define DEFINE_KERNEL_LAUNCHER(name, T, AccT, Reducer, kInit) \ #define DEFINE_KERNEL_LAUNCHER(name, T, AccT, Reducer, kInit) \
template <> \ template <> \
void Reduce##name<T, CUDAContext>( \ DRAGON_API void Reduce##name<T, CUDAContext>( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
......
...@@ -174,7 +174,8 @@ void ComputeBinaryBroadcastDims( ...@@ -174,7 +174,8 @@ void ComputeBinaryBroadcastDims(
const vec64_t& A_dims, const vec64_t& A_dims,
const vec64_t& B_dims, const vec64_t& B_dims,
vec64_t& A_broadcast_dims, vec64_t& A_broadcast_dims,
vec64_t& B_broadcast_dims) { vec64_t& B_broadcast_dims,
int64_t* C_broadcast_dims) {
auto num_dims = std::max(A_dims.size(), B_dims.size()); auto num_dims = std::max(A_dims.size(), B_dims.size());
A_broadcast_dims.resize(num_dims); A_broadcast_dims.resize(num_dims);
B_broadcast_dims.resize(num_dims); B_broadcast_dims.resize(num_dims);
...@@ -194,6 +195,16 @@ void ComputeBinaryBroadcastDims( ...@@ -194,6 +195,16 @@ void ComputeBinaryBroadcastDims(
B_dims.begin(), B_dims.begin(),
B_dims.end(), B_dims.end(),
B_broadcast_dims.begin() + num_dims - B_dims.size()); B_broadcast_dims.begin() + num_dims - B_dims.size());
if (C_broadcast_dims != nullptr) {
for (int i = 0; i < num_dims; ++i) {
if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) {
C_broadcast_dims[i] = 0;
} else {
C_broadcast_dims[i] =
std::max(A_broadcast_dims[i], B_broadcast_dims[i]);
}
}
}
} }
void ComputeBinaryBroadcastStrides( void ComputeBinaryBroadcastStrides(
......
...@@ -304,7 +304,8 @@ DRAGON_API void ComputeBinaryBroadcastDims( ...@@ -304,7 +304,8 @@ DRAGON_API void ComputeBinaryBroadcastDims(
const vec64_t& A_dims, const vec64_t& A_dims,
const vec64_t& B_dims, const vec64_t& B_dims,
vec64_t& A_broadcast_dims, vec64_t& A_broadcast_dims,
vec64_t& B_broadcast_dims); vec64_t& B_broadcast_dims,
int64_t* C_broadcast_dims = nullptr);
DRAGON_API void ComputeBinaryBroadcastStrides( DRAGON_API void ComputeBinaryBroadcastStrides(
const vec64_t& A_dims, const vec64_t& A_dims,
...@@ -326,22 +327,22 @@ DRAGON_API void TransposeAxesForReduce( ...@@ -326,22 +327,22 @@ DRAGON_API void TransposeAxesForReduce(
const int* reduce_axes, const int* reduce_axes,
int* transpose_axes); int* transpose_axes);
template <typename dim_t, typename stride_t> template <typename DimT, typename StrideT>
inline void inline void
ComputeStrides(const int num_dims, const dim_t* dims, stride_t* strides) { ComputeStrides(const int num_dims, const DimT* dims, StrideT* strides) {
int64_t cur_stride = 1; int64_t cur_stride = 1;
for (int i = num_dims - 1; i >= 0; --i) { for (int i = num_dims - 1; i >= 0; --i) {
strides[i] = stride_t(cur_stride); strides[i] = StrideT(cur_stride);
cur_stride *= int64_t(dims[i]); cur_stride *= int64_t(dims[i]);
} }
} }
template <typename dim_t, typename axis_t, typename stride_t> template <typename DimT, typename AxisT, typename StrideT>
inline void ComputeTransposeStrides( inline void ComputeTransposeStrides(
const int num_dims, const int num_dims,
const dim_t* dims, const DimT* dims,
const axis_t* axes, const AxisT* axes,
stride_t* strides) { StrideT* strides) {
vec64_t buf(num_dims); vec64_t buf(num_dims);
int64_t cur_stride = 1; int64_t cur_stride = 1;
for (int i = num_dims - 1; i >= 0; --i) { for (int i = num_dims - 1; i >= 0; --i) {
...@@ -349,13 +350,25 @@ inline void ComputeTransposeStrides( ...@@ -349,13 +350,25 @@ inline void ComputeTransposeStrides(
cur_stride *= int64_t(dims[i]); cur_stride *= int64_t(dims[i]);
} }
for (int i = 0; i < num_dims; ++i) { for (int i = 0; i < num_dims; ++i) {
strides[i] = stride_t(buf[axes[i]]); strides[i] = StrideT(buf[axes[i]]);
} }
} }
template <typename dim_t, typename index_t> template <typename DimT, typename IndexT>
inline IndexT
GetIndexFromDims(const int num_dims, const DimT* dims, IndexT* index) {
IndexT ret = 0;
for (int i = 0; i < num_dims; ++i) {
if (dims[i] > 1) {
ret = ret * dims[i] + index[i];
}
}
return ret;
}
template <typename DimT, typename IndexT>
inline void inline void
IncreaseIndexInDims(const int num_dims, const dim_t* dims, index_t* index) { IncreaseIndexInDims(const int num_dims, const DimT* dims, IndexT* index) {
for (int i = num_dims - 1; i >= 0; --i) { for (int i = num_dims - 1; i >= 0; --i) {
++index[i]; ++index[i];
if (index[i] >= dims[i]) { if (index[i] >= dims[i]) {
......
...@@ -116,11 +116,9 @@ class Dense(Layer): ...@@ -116,11 +116,9 @@ class Dense(Layer):
self.built = True self.built = True
def call(self, inputs): def call(self, inputs):
outputs = math_ops.fully_connected( outputs = math_ops.gemm(
[inputs, self.kernel] + [self.bias] [inputs, self.kernel] +
if self.use_bias else [], ([self.bias] if self.use_bias else []),
axis=-1,
transW=False,
) )
if self.activation is not None: if self.activation is not None:
return self.activation(outputs) return self.activation(outputs)
......
...@@ -703,38 +703,38 @@ def log(x, name=None): ...@@ -703,38 +703,38 @@ def log(x, name=None):
return math_ops.log(x, name=name) return math_ops.log(x, name=name)
def matmul( def matmul(a, b, name=None):
a,
b,
transpose_a=False,
transpose_b=False,
name=None,
):
r"""Compute the matrix multiplication. r"""Compute the matrix multiplication.
.. math:: y = a \times b .. math:: \text{out} = a \times b
The rank of ``a`` and ``b`` should be equal and >= 2: The behavior depends on the shape of input tensors:
```python * If both tensors are 1d, computes the vector product.
# Ok, a typical matrix multiplication * If tensors are 1d and >=2d, computes the vector-matrix multiplication.
a = tf.ones((2, 3), 'float32') * If tensors are >=2d and 1d, computes the matrix-vector multiplication.
b = tf.ones((3, 3), 'float32') * If both tensors are >= 2d, computes the matrix-matrix multiplication.
print(tf.linalg.matmul(a, b)) * If one tensor is >= 3d, applies batching and broadcasting to the computation.
# Compute a batch matrix multiplication if rank > 2 Examples:
aa = tf.ones((4, 2, 3), 'float32')
bb = tf.ones((4, 3, 3), 'float32')
print(tf.linalg.matmul(aa, bb))
```
If inputs are transposed, remember to transpose them back:
```python ```python
# Vector x Vector
a = tf.ones((2,), 'float32')
b = tf.ones((2,), 'float32')
print(tf.linalg.matmul(a, b))
# Vector x Matrix
a = tf.ones((2,), 'float32')
b = tf.ones((2, 3), 'float32')
print(tf.linalg.matmul(a, b))
# Matrix x Vector
a = tf.ones((3, 2), 'float32') a = tf.ones((3, 2), 'float32')
b = tf.ones((3, 3), 'float32') b = tf.ones((2,), 'float32')
print(tf.linalg.matmul(a, b)) # ``a`` takes the wrong dimensions print(tf.linalg.matmul(a, b))
print(tf.linalg.matmul(a, b, transpose_a=True)) # Ok # Matrix x Matrix
a = tf.ones((2, 3), 'float32')
b = tf.ones((3, 2), 'float32')
print(tf.linalg.matmul(a, b))
``` ```
Parameters Parameters
...@@ -743,10 +743,6 @@ def matmul( ...@@ -743,10 +743,6 @@ def matmul(
The matrix :math:`a`. The matrix :math:`a`.
b : dragon.Tensor b : dragon.Tensor
The matrix :math:`b`. The matrix :math:`b`.
transpose_a : bool, optional, default=False
**True** to transpose :math:`a` before computing.
transpose_b : bool, optional, default=False
**True** to transpose :math:`b` before computing.
name : str, optional name : str, optional
The operation name. The operation name.
...@@ -756,12 +752,7 @@ def matmul( ...@@ -756,12 +752,7 @@ def matmul(
The output tensor. The output tensor.
""" """
return math_ops.matmul( return math_ops.matmul([a, b], name=name)
[a, b],
transpose_a=transpose_a,
transpose_b=transpose_b,
name=name,
)
def multiply(x, y, name=None): def multiply(x, y, name=None):
......
...@@ -85,25 +85,25 @@ class Dense(layer.Layer): ...@@ -85,25 +85,25 @@ class Dense(layer.Layer):
raise AssertionError('The input dimension must be rank 2.' raise AssertionError('The input dimension must be rank 2.'
'Please reshape or flatten it.') 'Please reshape or flatten it.')
if self.in_channels: if self.in_channels:
shape = [self.n_units, self.in_channels] shape = [self.in_channels, self.n_units]
else: else:
self.in_channels = inputs_shape[1] self.in_channels = inputs_shape[1]
shape = [self.n_units, inputs_shape[1]] shape = [inputs_shape[1], self.n_units]
self.W = self.add_weight( self.W = self.add_weight(
name="weights", name='weights',
shape=shape, shape=shape,
init=self.W_init, init=self.W_init,
) )
if self.b_init: if self.b_init:
self.b = self.add_weight( self.b = self.add_weight(
name="biases", name='biases',
shape=[self.n_units], shape=[self.n_units],
init=self.b_init, init=self.b_init,
) )
def forward(self, inputs): def forward(self, inputs):
outputs = math_ops.fully_connected( outputs = math_ops.gemm(
[inputs, self.W] + ([self.b] if self.b_init else []), axis=1) [inputs, self.W] + ([self.b] if self.b_init else []))
if self.act: if self.act:
outputs = self.act(outputs) outputs = self.act(outputs)
return outputs return outputs
...@@ -281,17 +281,15 @@ class TestOpSpec(unittest.TestCase): ...@@ -281,17 +281,15 @@ class TestOpSpec(unittest.TestCase):
self.assertEqual(dragon.flatten( self.assertEqual(dragon.flatten(
self.sym4, axis=1, num_axes=-1).shape, (1, None)) self.sym4, axis=1, num_axes=-1).shape, (1, None))
def test_fully_connected(self): def test_gemm(self):
w = dragon.Tensor((3, 2)) w = dragon.Tensor((3, 2))
with dragon.graph_mode(): with dragon.graph_mode():
self.assertEqual(dragon.nn.fully_connected( self.assertEqual(dragon.math.gemm(
[self.sym1, w]).shape, (None, 3)) [self.sym1, w]).shape, None)
self.assertEqual(dragon.nn.fully_connected( self.assertEqual(dragon.math.gemm(
[self.sym1, w], transpose_w=False).shape, (None, 2)) [self.sym1, w], axis=1).shape, (None, 2))
self.assertEqual(dragon.nn.fully_connected( self.assertEqual(dragon.math.gemm(
[self.sym1, w], axis=-1).shape, None) [self.sym1, self.sym1]).shape, None)
self.assertEqual(dragon.nn.fully_connected(
[self.sym1, self.sym1]).shape, (None, None))
def test_index_select(self): def test_index_select(self):
with dragon.graph_mode(): with dragon.graph_mode():
...@@ -325,7 +323,9 @@ class TestOpSpec(unittest.TestCase): ...@@ -325,7 +323,9 @@ class TestOpSpec(unittest.TestCase):
self.assertEqual(dragon.math.matmul( self.assertEqual(dragon.math.matmul(
[self.sym1, self.sym3]).shape, None) [self.sym1, self.sym3]).shape, None)
self.assertEqual(dragon.math.matmul( self.assertEqual(dragon.math.matmul(
[self.sym2, self.sym3]).shape, None) [self.sym2, self.sym3]).shape, (None,))
self.assertEqual(dragon.math.matmul(
[self.sym3, self.sym2]).shape, (1,))
self.assertEqual(dragon.math.matmul( self.assertEqual(dragon.math.matmul(
[self.sym3, self.sym3]).shape, (1, None)) [self.sym3, self.sym3]).shape, (1, None))
self.assertEqual(dragon.math.matmul( self.assertEqual(dragon.math.matmul(
......
...@@ -1868,22 +1868,22 @@ class TestMathOps(OpTestCase): ...@@ -1868,22 +1868,22 @@ class TestMathOps(OpTestCase):
with dragon.device('cuda'): with dragon.device('cuda'):
self.test_floor() self.test_floor()
def test_fully_connected(self): def test_gemm(self):
entries = [((2, 3), (3, 4), (4,), False), entries = [((2, 3), (3, 4), (4,), False),
((2, 3), (4, 3), (4,), True)] ((2, 3), (4, 3), (4,), True)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution): with execution_context().mode(execution):
for x_shape, w_shape, b_shape, trans_w in entries: for x_shape, w_shape, b_shape, trans_b in entries:
data1, data2, data3 = arange(x_shape), arange(w_shape), arange(b_shape) data1, data2, data3 = arange(x_shape), arange(w_shape), arange(b_shape)
x, w, b = new_tensor(data1), new_tensor(data2), new_tensor(data3) x, w, b = new_tensor(data1), new_tensor(data2), new_tensor(data3)
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch([x, w, b]) tape.watch([x, w, b])
y = dragon.nn.fully_connected([x, w, b], transpose_w=trans_w) y = dragon.math.gemm([x, w, b], transpose_b=trans_b)
data4 = arange(y.shape) data4 = arange(y.shape)
dy = new_tensor(data4) dy = new_tensor(data4)
dx, dw, db = tape.gradient(y, [x, w, b], output_gradients=[dy]) dx, dw, db = tape.gradient(y, [x, w, b], output_gradients=[dy])
result = np.matmul(data1, data2.T if trans_w else data2) + data3 result = np.matmul(data1, data2.T if trans_b else data2) + data3
if trans_w: if trans_b:
grad1 = np.matmul(data4, data2) grad1 = np.matmul(data4, data2)
grad2 = np.matmul(data4.T, data1) grad2 = np.matmul(data4.T, data1)
else: else:
...@@ -1894,9 +1894,9 @@ class TestMathOps(OpTestCase): ...@@ -1894,9 +1894,9 @@ class TestMathOps(OpTestCase):
[result, grad1, grad2, reduce_like(data4, data3)]) [result, grad1, grad2, reduce_like(data4, data3)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable') @unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_fully_connected_cuda(self): def test_gemm_cuda(self):
with dragon.device('cuda'): with dragon.device('cuda'):
self.test_fully_connected() self.test_gemm()
def test_greater(self): def test_greater(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
...@@ -1997,40 +1997,62 @@ class TestMathOps(OpTestCase): ...@@ -1997,40 +1997,62 @@ class TestMathOps(OpTestCase):
self.test_log() self.test_log()
def test_matmul(self): def test_matmul(self):
entries = [ entries = [((2, 3), (3, 4)),
((2, 3), (3, 4), False, False), ((1, 2, 3), (2, 3, 4)),
((2, 3), (4, 3), False, True), ((2, 2, 3), (1, 3, 4)),
((3, 2), (3, 4), True, False), ((2, 2, 3), (2, 3, 4)),
((3, 2), (4, 3), True, True)] ((2, 1, 2, 3), (2, 3, 4)),
((1, 2, 3), (2, 2, 3, 4)),
((2, 1, 2, 3), (1, 2, 3, 4))]
for execution in ('EAGER_MODE', 'GRAPH_MODE',):
with execution_context().mode(execution):
for a_shape, b_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape)
a, b = new_tensor(data1), new_tensor(data2)
with dragon.GradientTape() as tape:
tape.watch([a, b])
y = dragon.math.matmul([a, b])
data3 = arange(y.shape)
dy = new_tensor(data3)
da, db = tape.gradient(y, [a, b], output_gradients=[dy])
grad1 = np.matmul(data3, transpose_last(data2, 2))
grad2 = np.matmul(transpose_last(data1, 2), data3)
self.assertEqual(
[y, da, db],
[np.matmul(data1, data2),
reduce_like(grad1, data1),
reduce_like(grad2, data2)])
entries = [((2,), (2,), (2, 1), (2, 1), (1, 1)),
((2,), (2, 3), (2, 1), (2, 3), (1, 3)),
((2, 3), (3,), (2, 3), (1, 3), (2, 1)),
((2,), (4, 2, 3), (1, 2, 1), (4, 2, 3), (4, 1, 3)),
((4, 2, 3), (3,), (4, 2, 3), (1, 1, 3), (4, 2, 1))]
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution): with execution_context().mode(execution):
for a_shape, b_shape, trans_a, trans_b in entries: for a_shape, b_shape, da_shape, db_shape, dy_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape) data1, data2 = arange(a_shape), arange(b_shape)
data4 = data1 if len(a_shape) > len(b_shape) else data2
a, b = new_tensor(data1), new_tensor(data2) a, b = new_tensor(data1), new_tensor(data2)
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch([a, b]) tape.watch([a, b])
y = dragon.math.matmul([a, b], trans_a, trans_b) y = dragon.math.matmul([a, b])
data3 = arange(y.shape) data3 = arange(y.shape)
dy = new_tensor(data3) dy = new_tensor(data3)
da, db = tape.gradient(y, [a, b], output_gradients=[dy]) da, db = tape.gradient(y, [a, b], output_gradients=[dy])
if trans_a: grad1 = data3.reshape(dy_shape) * data2.reshape(db_shape)
if trans_b: grad2 = data1.reshape(da_shape) * data3.reshape(dy_shape)
grad1 = np.matmul(data2.T, data3.T) grad1_axes, grad2_axes = [], []
grad2 = np.matmul(data3.T, data1.T) for i in range(len(dy_shape)):
else: if da_shape[i] != db_shape[i]:
grad1 = np.matmul(data2, data3.T) if da_shape[i] == 1:
grad2 = np.matmul(data1, data3) grad1_axes.append(i)
else: if db_shape[i] == 1:
if trans_b: grad2_axes.append(i)
grad1 = np.matmul(data3, data2)
grad2 = np.matmul(data3.T, data1)
else:
grad1 = np.matmul(data3, data2.T)
grad2 = np.matmul(data1.T, data3)
self.assertEqual( self.assertEqual(
[y, da, db], [y, da, db],
[np.matmul(data1.T if trans_a else data1, [np.matmul(data1, data2),
data2.T if trans_b else data2), grad1, grad2]) reduce(grad1, tuple(grad1_axes)).reshape(data1.shape),
reduce(grad2, tuple(grad2_axes)).reshape(data2.shape)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable') @unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_matmul_cuda(self): def test_matmul_cuda(self):
...@@ -4145,6 +4167,16 @@ def reduce_like(data, other, reduction='sum'): ...@@ -4145,6 +4167,16 @@ def reduce_like(data, other, reduction='sum'):
return data return data
def transpose_last(data, num_axes=None, axes=None):
"""Transpose the last axes of data."""
if axes is None and num_axes is not None:
axes = list(range(num_axes))[::-1]
perm = list(range(len(data.shape)))
start_axis = len(perm) - len(axes)
perm[start_axis:] = [v + start_axis for v in axes]
return np.transpose(data, perm)
def uniform(shape, dtype='float32'): def uniform(shape, dtype='float32'):
"""Return the uniform data with given shape.""" """Return the uniform data with given shape."""
return np.random.uniform(-1., 1., size=shape).astype(dtype) return np.random.uniform(-1., 1., size=shape).astype(dtype)
......
...@@ -619,39 +619,54 @@ class TestModules(OpTestCase): ...@@ -619,39 +619,54 @@ class TestModules(OpTestCase):
self.assertEqual(m4(x), np.pad(data, pads, 'constant')) self.assertEqual(m4(x), np.pad(data, pads, 'constant'))
def test_pool1d(self): def test_pool1d(self):
entries = [((2, 2, 2,), (2,), 2, 1, 'MAX'), entries = [((2, 2, 2,), (2,), 2, 1, 'MaxPool1d'),
((2, 2, 2,), (2,), 2, 1, 'AVG')] ((2, 2, 2,), (2,), 2, 1, 'AvgPool1d'),
((2, 2, 2,), (1,), 1, 0, 'AdaptiveMaxPool1d'),
((2, 2, 2,), (1,), 1, 0, 'AdaptiveAvgPool1d')]
for x_shape, kernel_shape, strides, pads, mode in entries: for x_shape, kernel_shape, strides, pads, mode in entries:
data = arange(x_shape) * .1 data = arange(x_shape) * .1
module_cls = torch.nn.AvgPool1d if mode == 'AVG' else torch.nn.MaxPool1d module_cls = getattr(torch.nn, mode)
x = new_tensor(data) x = new_tensor(data)
m = module_cls(kernel_shape, strides, pads) if 'Adaptive' in mode:
m = module_cls(x_shape[-1])
else:
m = module_cls(kernel_shape, strides, pads)
y, _ = m(x), repr(m) y, _ = m(x), repr(m)
result = data / (np.prod(kernel_shape) if mode == 'AVG' else 1.) result = data / (np.prod(kernel_shape) if 'Avg' in mode else 1.)
self.assertEqual(y, result) self.assertEqual(y, result)
def test_pool2d(self): def test_pool2d(self):
entries = [((2, 2, 2, 2), (2, 2), 2, 1, 'MAX'), entries = [((2, 2, 2, 2), (2, 2), 2, 1, 'MaxPool2d'),
((2, 2, 2, 2), (2, 2), 2, 1, 'AVG')] ((2, 2, 2, 2), (2, 2), 2, 1, 'AvgPool2d'),
((2, 2, 2, 2), (1, 1), 1, 0, 'AdaptiveMaxPool2d'),
((2, 2, 2, 2), (1, 1), 1, 0, 'AdaptiveAvgPool2d')]
for x_shape, kernel_shape, strides, pads, mode in entries: for x_shape, kernel_shape, strides, pads, mode in entries:
data = arange(x_shape) * .1 data = arange(x_shape) * .1
module_cls = torch.nn.AvgPool2d if mode == 'AVG' else torch.nn.MaxPool2d module_cls = getattr(torch.nn, mode)
x = new_tensor(data) x = new_tensor(data)
m = module_cls(kernel_shape, strides, pads) if 'Adaptive' in mode:
m = module_cls(x_shape[-1])
else:
m = module_cls(kernel_shape, strides, pads)
y, _ = m(x), repr(m) y, _ = m(x), repr(m)
result = data / (np.prod(kernel_shape) if mode == 'AVG' else 1.) result = data / (np.prod(kernel_shape) if 'Avg' in mode else 1.)
self.assertEqual(y, result) self.assertEqual(y, result)
def test_pool3d(self): def test_pool3d(self):
entries = [((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'MAX'), entries = [((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'MaxPool3d'),
((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'AVG')] ((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'AvgPool3d'),
((2, 2, 2, 2, 2), (1, 1, 1), 1, 0, 'AdaptiveMaxPool3d'),
((2, 2, 2, 2, 2), (1, 1, 1), 1, 0, 'AdaptiveAvgPool3d')]
for x_shape, kernel_shape, strides, pads, mode in entries: for x_shape, kernel_shape, strides, pads, mode in entries:
data = arange(x_shape) * .1 data = arange(x_shape) * .1
module_cls = torch.nn.AvgPool3d if mode == 'AVG' else torch.nn.MaxPool3d module_cls = getattr(torch.nn, mode)
x = new_tensor(data) x = new_tensor(data)
m = module_cls(kernel_shape, strides, pads) if 'Adaptive' in mode:
m = module_cls(x_shape[-1])
else:
m = module_cls(kernel_shape, strides, pads)
y, _ = m(x), repr(m) y, _ = m(x), repr(m)
result = data / (np.prod(kernel_shape) if mode == 'AVG' else 1.) result = data / (np.prod(kernel_shape) if 'Avg' in mode else 1.)
self.assertEqual(y, result) self.assertEqual(y, result)
def test_prelu(self): def test_prelu(self):
......
...@@ -95,6 +95,16 @@ class TestTensorOps(OpTestCase): ...@@ -95,6 +95,16 @@ class TestTensorOps(OpTestCase):
a += b a += b
self.assertEqual(a, data1 + data2) self.assertEqual(a, data1 + data2)
def test_addmm(self):
entries = [((2, 3), (3, 4), (2, 4))]
for a_shape, b_shape, c_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape)
data3 = arange(c_shape)
a, b = new_tensor(data1), new_tensor(data2)
c = new_tensor(data3)
y = c.addmm(a, b)
self.assertEqual(y, np.matmul(data1, data2) + data3)
def test_argmax(self): def test_argmax(self):
entries = [(0, True), (0, False), (1, True), (1, False), (None, False)] entries = [(0, True), (0, False), (1, True), (1, False), (None, False)]
for axis, keepdims in entries: for axis, keepdims in entries:
...@@ -115,6 +125,18 @@ class TestTensorOps(OpTestCase): ...@@ -115,6 +125,18 @@ class TestTensorOps(OpTestCase):
result = np.expand_dims(result, axis) result = np.expand_dims(result, axis)
self.assertEqual(x.argmin(axis, keepdims), result) self.assertEqual(x.argmin(axis, keepdims), result)
def test_baddbmm(self):
entries = [((2, 2, 3), (2, 3, 4), (2, 2, 4))]
for a_shape, b_shape, c_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape)
data3 = arange(c_shape)
a, b = new_tensor(data1), new_tensor(data2)
c = new_tensor(data3)
y = c.baddbmm(a, b)
self.assertEqual(y, np.matmul(data1, data2) + data3)
c.baddbmm_(a, b)
self.assertEqual(c, np.matmul(data1, data2) + data3)
def test_bitwise_not(self): def test_bitwise_not(self):
for shape in self.unary_test_shapes: for shape in self.unary_test_shapes:
data = np.random.binomial(1, 0.5, shape).astype('bool') data = np.random.binomial(1, 0.5, shape).astype('bool')
...@@ -132,6 +154,18 @@ class TestTensorOps(OpTestCase): ...@@ -132,6 +154,18 @@ class TestTensorOps(OpTestCase):
a.bitwise_xor_(b) a.bitwise_xor_(b)
self.assertEqual(a, np.bitwise_xor(data1, data2)) self.assertEqual(a, np.bitwise_xor(data1, data2))
def test_bmm(self):
test_shapes = [((1, 2, 3), (2, 3, 4)),
((2, 2, 3), (1, 3, 4)),
((2, 2, 3), (2, 3, 4)),
((2, 1, 2, 3), (2, 3, 4)),
((1, 2, 3), (2, 2, 3, 4)),
((2, 1, 2, 3), (1, 2, 3, 4))]
for a_shape, b_shape in test_shapes:
data1, data2 = arange(a_shape), arange(b_shape, 1)
a, b = new_tensor(data1, False), new_tensor(data2, False)
self.assertEqual(a.bmm(b), np.matmul(data1, data2))
def test_ceil(self): def test_ceil(self):
data = np.array([1.4, 1.7, 2.0]) data = np.array([1.4, 1.7, 2.0])
x = new_tensor(data) x = new_tensor(data)
...@@ -334,6 +368,24 @@ class TestTensorOps(OpTestCase): ...@@ -334,6 +368,24 @@ class TestTensorOps(OpTestCase):
data[data > 2] = 0 data[data > 2] = 0
self.assertEqual(x, data) self.assertEqual(x, data)
def test_matmul(self):
test_shapes = [((2,), (2,)),
((2,), (2, 3)),
((2, 3), (3,)),
((2, 3), (3, 4)),
((2,), (4, 2, 3)),
((4, 2, 3), (3,)),
((1, 2, 3), (2, 3, 4)),
((2, 2, 3), (1, 3, 4)),
((2, 2, 3), (2, 3, 4)),
((2, 1, 2, 3), (2, 3, 4)),
((1, 2, 3), (2, 2, 3, 4)),
((2, 1, 2, 3), (1, 2, 3, 4))]
for a_shape, b_shape in test_shapes:
data1, data2 = arange(a_shape), arange(b_shape, 1)
a, b = new_tensor(data1, False), new_tensor(data2, False)
self.assertEqual(a.matmul(b), np.matmul(data1, data2))
def test_max(self): def test_max(self):
entries = [(0, True), (0, False), entries = [(0, True), (0, False),
(1, True), (1, False), (1, True), (1, False),
...@@ -382,20 +434,12 @@ class TestTensorOps(OpTestCase): ...@@ -382,20 +434,12 @@ class TestTensorOps(OpTestCase):
self.assertEqual(y, np.minimum(data1, data2)) self.assertEqual(y, np.minimum(data1, data2))
def test_mm(self): def test_mm(self):
entries = [ entries = [((2, 3), (3, 4))]
((2, 3), (3, 4), False, False), for a_shape, b_shape in entries:
((2, 3), (4, 3), False, True),
((3, 2), (3, 4), True, False),
((3, 2), (4, 3), True, True)]
for a_shape, b_shape, trans_a, trans_b in entries:
data1, data2 = arange(a_shape), arange(b_shape) data1, data2 = arange(a_shape), arange(b_shape)
a, b = new_tensor(data1), new_tensor(data2) a, b = new_tensor(data1), new_tensor(data2)
if trans_a or trans_b: y = a.mm(b)
y = torch.mm(a, b, trans_a, trans_b) self.assertEqual(y, np.matmul(data1, data2))
else:
y = a.mm(b)
self.assertEqual(y, np.matmul(data1.T if trans_a else data1,
data2.T if trans_b else data2))
def test_mul(self): def test_mul(self):
for a_shape, b_shape in self.binary_test_shapes: for a_shape, b_shape in self.binary_test_shapes:
......
...@@ -94,9 +94,12 @@ from dragon.vm.torch.core.ops.init.functional import zeros ...@@ -94,9 +94,12 @@ from dragon.vm.torch.core.ops.init.functional import zeros
from dragon.vm.torch.core.ops.init.functional import zeros_like from dragon.vm.torch.core.ops.init.functional import zeros_like
from dragon.vm.torch.core.ops.math.functional import abs from dragon.vm.torch.core.ops.math.functional import abs
from dragon.vm.torch.core.ops.math.functional import add from dragon.vm.torch.core.ops.math.functional import add
from dragon.vm.torch.core.ops.math.functional import addmm
from dragon.vm.torch.core.ops.math.functional import axpby from dragon.vm.torch.core.ops.math.functional import axpby
from dragon.vm.torch.core.ops.math.functional import baddbmm
from dragon.vm.torch.core.ops.math.functional import bitwise_not from dragon.vm.torch.core.ops.math.functional import bitwise_not
from dragon.vm.torch.core.ops.math.functional import bitwise_xor from dragon.vm.torch.core.ops.math.functional import bitwise_xor
from dragon.vm.torch.core.ops.math.functional import bmm
from dragon.vm.torch.core.ops.math.functional import ceil from dragon.vm.torch.core.ops.math.functional import ceil
from dragon.vm.torch.core.ops.math.functional import clamp from dragon.vm.torch.core.ops.math.functional import clamp
from dragon.vm.torch.core.ops.math.functional import cos from dragon.vm.torch.core.ops.math.functional import cos
...@@ -112,6 +115,7 @@ from dragon.vm.torch.core.ops.math.functional import le ...@@ -112,6 +115,7 @@ from dragon.vm.torch.core.ops.math.functional import le
from dragon.vm.torch.core.ops.math.functional import log from dragon.vm.torch.core.ops.math.functional import log
from dragon.vm.torch.core.ops.math.functional import logsumexp from dragon.vm.torch.core.ops.math.functional import logsumexp
from dragon.vm.torch.core.ops.math.functional import lt from dragon.vm.torch.core.ops.math.functional import lt
from dragon.vm.torch.core.ops.math.functional import matmul
from dragon.vm.torch.core.ops.math.functional import maximum from dragon.vm.torch.core.ops.math.functional import maximum
from dragon.vm.torch.core.ops.math.functional import minimum from dragon.vm.torch.core.ops.math.functional import minimum
from dragon.vm.torch.core.ops.math.functional import mm from dragon.vm.torch.core.ops.math.functional import mm
......
...@@ -76,6 +76,12 @@ from dragon.vm.torch.core.nn.modules.padding import ReplicationPad1d ...@@ -76,6 +76,12 @@ from dragon.vm.torch.core.nn.modules.padding import ReplicationPad1d
from dragon.vm.torch.core.nn.modules.padding import ReplicationPad2d from dragon.vm.torch.core.nn.modules.padding import ReplicationPad2d
from dragon.vm.torch.core.nn.modules.padding import ReplicationPad3d from dragon.vm.torch.core.nn.modules.padding import ReplicationPad3d
from dragon.vm.torch.core.nn.modules.padding import ZeroPad2d from dragon.vm.torch.core.nn.modules.padding import ZeroPad2d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool1d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool2d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool3d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveMaxPool1d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveMaxPool2d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveMaxPool3d
from dragon.vm.torch.core.nn.modules.pooling import AvgPool1d from dragon.vm.torch.core.nn.modules.pooling import AvgPool1d
from dragon.vm.torch.core.nn.modules.pooling import AvgPool2d from dragon.vm.torch.core.nn.modules.pooling import AvgPool2d
from dragon.vm.torch.core.nn.modules.pooling import AvgPool3d from dragon.vm.torch.core.nn.modules.pooling import AvgPool3d
......
...@@ -14,6 +14,12 @@ from __future__ import absolute_import as _absolute_import ...@@ -14,6 +14,12 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division from __future__ import division as _division
from __future__ import print_function as _print_function from __future__ import print_function as _print_function
from dragon.vm.torch.core.nn.functional import adaptive_avg_pool1d
from dragon.vm.torch.core.nn.functional import adaptive_avg_pool2d
from dragon.vm.torch.core.nn.functional import adaptive_avg_pool3d
from dragon.vm.torch.core.nn.functional import adaptive_max_pool1d
from dragon.vm.torch.core.nn.functional import adaptive_max_pool2d
from dragon.vm.torch.core.nn.functional import adaptive_max_pool3d
from dragon.vm.torch.core.nn.functional import avg_pool1d from dragon.vm.torch.core.nn.functional import avg_pool1d
from dragon.vm.torch.core.nn.functional import avg_pool2d from dragon.vm.torch.core.nn.functional import avg_pool2d
from dragon.vm.torch.core.nn.functional import avg_pool3d from dragon.vm.torch.core.nn.functional import avg_pool3d
......
...@@ -76,7 +76,7 @@ class Function(object): ...@@ -76,7 +76,7 @@ class Function(object):
Parameters Parameters
---------- ----------
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
......
...@@ -18,16 +18,166 @@ from dragon.core.util import nest ...@@ -18,16 +18,166 @@ from dragon.core.util import nest
from dragon.vm.torch.core.nn.modules import _functions from dragon.vm.torch.core.nn.modules import _functions
from dragon.vm.torch.core.nn import _reduction from dragon.vm.torch.core.nn import _reduction
from dragon.vm.torch.core.nn.modules import utils from dragon.vm.torch.core.nn.modules import utils
from dragon.vm.torch.core.ops.math import _functions as _math_functions
from dragon.vm.torch.core.ops.math import functional as math_funcs from dragon.vm.torch.core.ops.math import functional as math_funcs
def adaptive_avg_pool1d(input, output_size):
"""Apply the 1d adaptive average pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveAvgPool1d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-1:], utils._single(output_size))
return _pool(input, _pool_mode='AVG', _nd_util=utils._single, **kwargs)
def adaptive_avg_pool2d(input, output_size):
"""Apply the 2d adaptive average pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveAvgPool2d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-2:], utils._pair(output_size))
return _pool(input, _pool_mode='AVG', _nd_util=utils._pair, **kwargs)
def adaptive_avg_pool3d(input, output_size):
"""Apply the 3d adaptive average pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveAvgPool3d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-3:], utils._triple(output_size))
return _pool(input, _pool_mode='AVG', _nd_util=utils._triple, **kwargs)
def adaptive_max_pool1d(input, output_size):
"""Apply the 1d adaptive max pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveMaxPool1d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-1:], utils._single(output_size))
return _pool(input, _pool_mode='MAX', _nd_util=utils._single, **kwargs)
def adaptive_max_pool2d(input, output_size):
"""Apply the 2d adaptive max pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveMaxPool2d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-2:], utils._pair(output_size))
return _pool(input, _pool_mode='MAX', _nd_util=utils._pair, **kwargs)
def adaptive_max_pool3d(input, output_size):
"""Apply the 3d adaptive max pooling to input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
output_size : Union[int, Sequence[int]]
The target output size.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.AdaptiveMaxPool3d(...)`_
"""
kwargs = utils._get_adaptive_pool_kwargs(
input.size()[-3:], utils._triple(output_size))
return _pool(input, _pool_mode='MAX', _nd_util=utils._triple, **kwargs)
def avg_pool1d( def avg_pool1d(
input, input,
kernel_size, kernel_size,
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
r"""Apply the 1d average pooling to input. r"""Apply the 1d average pooling to input.
...@@ -36,15 +186,13 @@ def avg_pool1d( ...@@ -36,15 +186,13 @@ def avg_pool1d(
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
kernel_size : Union[int, Sequence[int]] kernel_size : Union[int, Sequence[int]]
The size of sliding window. The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns Returns
------- -------
...@@ -65,7 +213,6 @@ def avg_pool2d( ...@@ -65,7 +213,6 @@ def avg_pool2d(
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
r"""Apply the 2d average pooling to input. r"""Apply the 2d average pooling to input.
...@@ -74,15 +221,13 @@ def avg_pool2d( ...@@ -74,15 +221,13 @@ def avg_pool2d(
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
kernel_size : Union[int, Sequence[int]] kernel_size : Union[int, Sequence[int]]
The size of sliding window. The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns Returns
------- -------
...@@ -103,7 +248,6 @@ def avg_pool3d( ...@@ -103,7 +248,6 @@ def avg_pool3d(
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
r"""Apply the 3d average pooling to input. r"""Apply the 3d average pooling to input.
...@@ -112,15 +256,13 @@ def avg_pool3d( ...@@ -112,15 +256,13 @@ def avg_pool3d(
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
kernel_size : Union[int, Sequence[int]] kernel_size : Union[int, Sequence[int]]
The size of sliding window. The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns Returns
------- -------
...@@ -262,9 +404,9 @@ def conv1d( ...@@ -262,9 +404,9 @@ def conv1d(
weight : dragon.vm.torch.Tensor weight : dragon.vm.torch.Tensor
The weight tensor. The weight tensor.
bias : dragon.vm.torch.Tensor, optional bias : dragon.vm.torch.Tensor, optional
The optional bias tensor. The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
dilation : Union[int, Sequence[int]], optional, default=1 dilation : Union[int, Sequence[int]], optional, default=1
...@@ -303,9 +445,9 @@ def conv2d( ...@@ -303,9 +445,9 @@ def conv2d(
weight : dragon.vm.torch.Tensor weight : dragon.vm.torch.Tensor
The weight tensor. The weight tensor.
bias : dragon.vm.torch.Tensor, optional bias : dragon.vm.torch.Tensor, optional
The optional bias tensor. The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
dilation : Union[int, Sequence[int]], optional, default=1 dilation : Union[int, Sequence[int]], optional, default=1
...@@ -344,9 +486,9 @@ def conv3d( ...@@ -344,9 +486,9 @@ def conv3d(
weight : dragon.vm.torch.Tensor weight : dragon.vm.torch.Tensor
The weight tensor. The weight tensor.
bias : dragon.vm.torch.Tensor, optional bias : dragon.vm.torch.Tensor, optional
The optional bias tensor. The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
dilation : Union[int, Sequence[int]], optional, default=1 dilation : Union[int, Sequence[int]], optional, default=1
...@@ -386,9 +528,9 @@ def conv_transpose1d( ...@@ -386,9 +528,9 @@ def conv_transpose1d(
weight : dragon.vm.torch.Tensor weight : dragon.vm.torch.Tensor
The weight tensor. The weight tensor.
bias : dragon.vm.torch.Tensor, optional bias : dragon.vm.torch.Tensor, optional
The optional bias. The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of slidaing window. The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
output_padding : int, optional, default=1 output_padding : int, optional, default=1
...@@ -430,9 +572,9 @@ def conv_transpose2d( ...@@ -430,9 +572,9 @@ def conv_transpose2d(
weight : dragon.vm.torch.Tensor weight : dragon.vm.torch.Tensor
The weight tensor. The weight tensor.
bias : dragon.vm.torch.Tensor, optional bias : dragon.vm.torch.Tensor, optional
The optional bias. The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
output_padding : int, optional, default=1 output_padding : int, optional, default=1
...@@ -474,9 +616,9 @@ def conv_transpose3d( ...@@ -474,9 +616,9 @@ def conv_transpose3d(
weight : dragon.vm.torch.Tensor weight : dragon.vm.torch.Tensor
The weight tensor. The weight tensor.
bias : dragon.vm.torch.Tensor, optional bias : dragon.vm.torch.Tensor, optional
The optional bias. The bias tensor.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of convolution window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
output_padding : int, optional, default=1 output_padding : int, optional, default=1
...@@ -604,9 +746,9 @@ def depthwise_conv2d( ...@@ -604,9 +746,9 @@ def depthwise_conv2d(
weight : dragon.vm.torch.Tensor weight : dragon.vm.torch.Tensor
The weight tensor. The weight tensor.
bias : dragon.vm.torch.Tensor, optional bias : dragon.vm.torch.Tensor, optional
The optional bias. The bias tensor.
stride : Sequence[int], default=1 stride : Sequence[int], default=1
The stride of sliding window. The stride of convolution window.
padding : Sequence[int], default=0 padding : Sequence[int], default=0
The zero padding size. The zero padding size.
dilation : Sequence[int], default=1 dilation : Sequence[int], default=1
...@@ -1093,7 +1235,7 @@ def leaky_relu(input, negative_slope=0.01, inplace=False): ...@@ -1093,7 +1235,7 @@ def leaky_relu(input, negative_slope=0.01, inplace=False):
def linear(input, weight, bias=None): def linear(input, weight, bias=None):
r"""Apply the linear transformation to input. r"""Apply the linear transformation to input.
.. math:: y = Wx + b .. math:: \text{out} = \text{input} \times \text{weight}^{T} + \text{bias}
Parameters Parameters
---------- ----------
...@@ -1102,7 +1244,7 @@ def linear(input, weight, bias=None): ...@@ -1102,7 +1244,7 @@ def linear(input, weight, bias=None):
weight : dragon.vm.torch.Tensor weight : dragon.vm.torch.Tensor
The weight tensor. The weight tensor.
bias : dragon.vm.torch.Tensor, optional bias : dragon.vm.torch.Tensor, optional
The optional bias. The bias tensor.
Returns Returns
------- -------
...@@ -1114,7 +1256,9 @@ def linear(input, weight, bias=None): ...@@ -1114,7 +1256,9 @@ def linear(input, weight, bias=None):
`torch.nn.Linear(...)`_ `torch.nn.Linear(...)`_
""" """
return _functions.Linear.instantiate(input.device).apply(input, weight, bias) return _math_functions.Gemm \
.instantiate(input.device, transB=True) \
.apply(input, weight, bias)
def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
...@@ -1217,7 +1361,6 @@ def max_pool1d( ...@@ -1217,7 +1361,6 @@ def max_pool1d(
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
r"""Apply the 1d max pooling to input. r"""Apply the 1d max pooling to input.
...@@ -1226,15 +1369,13 @@ def max_pool1d( ...@@ -1226,15 +1369,13 @@ def max_pool1d(
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
kernel_size : Union[int, Sequence[int]] kernel_size : Union[int, Sequence[int]]
The size of sliding window. The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns Returns
------- -------
...@@ -1255,7 +1396,6 @@ def max_pool2d( ...@@ -1255,7 +1396,6 @@ def max_pool2d(
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
r"""Apply the 2d max pooling to input. r"""Apply the 2d max pooling to input.
...@@ -1264,15 +1404,13 @@ def max_pool2d( ...@@ -1264,15 +1404,13 @@ def max_pool2d(
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
kernel_size : Union[int, Sequence[int]] kernel_size : Union[int, Sequence[int]]
The size of sliding window. The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns Returns
------- -------
...@@ -1293,7 +1431,6 @@ def max_pool3d( ...@@ -1293,7 +1431,6 @@ def max_pool3d(
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
r"""Apply the 3d max pooling to input. r"""Apply the 3d max pooling to input.
...@@ -1302,15 +1439,13 @@ def max_pool3d( ...@@ -1302,15 +1439,13 @@ def max_pool3d(
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
kernel_size : Union[int, Sequence[int]] kernel_size : Union[int, Sequence[int]]
The size of sliding window. The size of pooling window.
stride : Union[int, Sequence[int]], optional, default=1 stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window. The stride of pooling window.
padding : Union[int, Sequence[int]], optional, default=0 padding : Union[int, Sequence[int]], optional, default=0
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
Returns Returns
------- -------
...@@ -1442,7 +1577,7 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None): ...@@ -1442,7 +1577,7 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None):
eps : float, optional, default=1e-12 eps : float, optional, default=1e-12
The value to :math:`\epsilon`. The value to :math:`\epsilon`.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -2095,6 +2230,7 @@ def _conv( ...@@ -2095,6 +2230,7 @@ def _conv(
group=groups, group=groups,
bias=bias is not None, bias=bias is not None,
dtype=weight.dtype, dtype=weight.dtype,
input_shape=input.shape,
).apply(input, weight, bias) ).apply(input, weight, bias)
...@@ -2124,6 +2260,7 @@ def _conv_transpose( ...@@ -2124,6 +2260,7 @@ def _conv_transpose(
output_padding=_nd_util(output_padding), output_padding=_nd_util(output_padding),
bias=bias is not None, bias=bias is not None,
dtype=weight.dtype, dtype=weight.dtype,
input_shape=input.shape,
).apply(input, weight, bias) ).apply(input, weight, bias)
...@@ -2133,7 +2270,6 @@ def _pool( ...@@ -2133,7 +2270,6 @@ def _pool(
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
_pool_mode='MAX', _pool_mode='MAX',
_nd_util=utils._pair, _nd_util=utils._pair,
_pool_fn=_functions.Pool, _pool_fn=_functions.Pool,
...@@ -2145,5 +2281,4 @@ def _pool( ...@@ -2145,5 +2281,4 @@ def _pool(
pads=_nd_util(padding), pads=_nd_util(padding),
mode=_pool_mode, mode=_pool_mode,
ceil_mode=ceil_mode, ceil_mode=ceil_mode,
global_pool=global_pool,
).apply(input) ).apply(input)
...@@ -86,7 +86,6 @@ class Pool(function.Function): ...@@ -86,7 +86,6 @@ class Pool(function.Function):
self.pads = kwargs.get('pads', 0) self.pads = kwargs.get('pads', 0)
self.ceil_mode = kwargs.get('ceil_mode', False) self.ceil_mode = kwargs.get('ceil_mode', False)
self.mode = kwargs.get('mode', 'MAX') self.mode = kwargs.get('mode', 'MAX')
self.global_pool = kwargs.get('global_pool', False)
def attributes(self): def attributes(self):
return { return {
...@@ -98,7 +97,6 @@ class Pool(function.Function): ...@@ -98,7 +97,6 @@ class Pool(function.Function):
'ceil_mode': self.ceil_mode, 'ceil_mode': self.ceil_mode,
'mode': self.mode, 'mode': self.mode,
'data_format': 'NCHW', 'data_format': 'NCHW',
'global_pool': self.global_pool,
} }
} }
...@@ -316,24 +314,6 @@ class L2Loss(Loss): ...@@ -316,24 +314,6 @@ class L2Loss(Loss):
} }
class Linear(function.Function):
"""Linear function."""
def attributes(self):
return {
'op_type': 'FullyConnected',
'arguments': {
'axis': -1,
'transW': True,
},
}
def forward(self, input, weight, bias=None, out=None):
inputs = [input, weight] + ([bias] if bias else [])
outputs = [out] if out else [self.alloc()]
return self.dispatch(inputs, outputs)
class LocalResponseNorm(function.Function): class LocalResponseNorm(function.Function):
"""LocalResponseNorm function.""" """LocalResponseNorm function."""
......
...@@ -48,7 +48,7 @@ class Identity(Module): ...@@ -48,7 +48,7 @@ class Identity(Module):
class Linear(Module): class Linear(Module):
r"""Apply the linear transformation. r"""Apply the linear transformation.
.. math:: y = Wx + b .. math:: \text{out} = \text{input} \times \text{weight}^{T} + \text{bias}
Examples: Examples:
......
...@@ -18,6 +18,17 @@ from dragon.vm.torch.core.nn import functional as F ...@@ -18,6 +18,17 @@ from dragon.vm.torch.core.nn import functional as F
from dragon.vm.torch.core.nn.modules.module import Module from dragon.vm.torch.core.nn.modules.module import Module
class _AdaptivePoolNd(Module):
"""Apply the n-dimension adaptive pooling."""
def __init__(self, output_size):
super(_AdaptivePoolNd, self).__init__()
self.output_size = output_size
def extra_repr(self):
return 'output_size={}'.format(self.output_size)
class _PoolNd(Module): class _PoolNd(Module):
"""Apply the n-dimension pooling.""" """Apply the n-dimension pooling."""
...@@ -27,24 +38,243 @@ class _PoolNd(Module): ...@@ -27,24 +38,243 @@ class _PoolNd(Module):
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
super(_PoolNd, self).__init__() super(_PoolNd, self).__init__()
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.stride = stride self.stride = stride
self.padding = padding self.padding = padding
self.ceil_mode = ceil_mode self.ceil_mode = ceil_mode
self.global_pool = global_pool
def extra_repr(self): def extra_repr(self):
return 'kernel_size={kernel_size}, ' \ return 'kernel_size={kernel_size}, ' \
'stride={stride}, ' \ 'stride={stride}, ' \
'padding={padding}, ' \ 'padding={padding}, ' \
'ceil_mode={ceil_mode}, ' \ 'ceil_mode={ceil_mode}' \
'global_pool={global_pool}' \
.format(**self.__dict__) .format(**self.__dict__)
class AdaptiveAvgPool1d(_AdaptivePoolNd):
r"""Apply the 1d adaptive average pooling.
This module excepts the input size :math:`(N, C, H)`,
and output size is :math:`(N, C, H_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`H` is the height of data.
Examples:
```python
m = torch.nn.AdaptiveAvgPool1d(1)
x = torch.ones(2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_avg_pool1d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveAvgPool1d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveAvgPool1d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_avg_pool1d(input, self.output_size)
class AdaptiveAvgPool2d(_AdaptivePoolNd):
r"""Apply the 2d adaptive average pooling.
This module excepts the input size :math:`(N, C, H, W)`,
and output size is :math:`(N, C, H_{\text{out}}, W_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`H` and :math:`W` are the height and width of data.
Examples:
```python
m = torch.nn.AdaptiveAvgPool2d(1)
x = torch.ones(2, 2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_avg_pool2d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveAvgPool2d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveAvgPool2d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_avg_pool2d(input, self.output_size)
class AdaptiveAvgPool3d(_AdaptivePoolNd):
r"""Apply the 3d adaptive average pooling.
This module excepts the input size :math:`(N, C, D, H, W)`,
and output size is :math:`(N, C, D_{\text{out}}, H_{\text{out}}, W_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`D`, :math:`H` and :math:`W` are the depth, height and width of data.
Examples:
```python
m = torch.nn.AdaptiveAvgPool3d(1)
x = torch.ones(2, 2, 2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_avg_pool3d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveAvgPool3d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveAvgPool3d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_avg_pool3d(input, self.output_size)
class AdaptiveMaxPool1d(_AdaptivePoolNd):
r"""Apply the 1d adaptive max pooling.
This module excepts the input size :math:`(N, C, H)`,
and output size is :math:`(N, C, H_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`H` is the height of data.
Examples:
```python
m = torch.nn.AdaptiveMaxPool1d(1)
x = torch.ones(2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_max_pool1d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveMaxPool1d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveMaxPool1d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_max_pool1d(input, self.output_size)
class AdaptiveMaxPool2d(_AdaptivePoolNd):
r"""Apply the 2d adaptive max pooling.
This module excepts the input size :math:`(N, C, H, W)`,
and output size is :math:`(N, C, H_{\text{out}}, W_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`H` and :math:`W` are the height and width of data.
Examples:
```python
m = torch.nn.AdaptiveMaxPool2d(1)
x = torch.ones(2, 2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_max_pool2d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveMaxPool2d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveMaxPool2d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_max_pool2d(input, self.output_size)
class AdaptiveMaxPool3d(_AdaptivePoolNd):
r"""Apply the 3d adaptive max pooling.
This module excepts the input size :math:`(N, C, D, H, W)`,
and output size is :math:`(N, C, D_{\text{out}}, H_{\text{out}}, W_{\text{out}})`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
:math:`D`, :math:`H` and :math:`W` are the depth, height and width of data.
Examples:
```python
m = torch.nn.AdaptiveMaxPool3d(1)
x = torch.ones(2, 2, 2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.adaptive_max_pool3d(...)`_
"""
def __init__(self, output_size):
"""Create a ``AdaptiveMaxPool3d`` module.
Parameters
----------
output_size : Union[int, Sequence[int]]
The target output size.
"""
super(AdaptiveMaxPool3d, self).__init__(output_size=output_size)
def forward(self, input):
return F.adaptive_max_pool3d(input, self.output_size)
class AvgPool1d(_PoolNd): class AvgPool1d(_PoolNd):
r"""Apply the 1d average pooling. r"""Apply the 1d average pooling.
...@@ -73,7 +303,6 @@ class AvgPool1d(_PoolNd): ...@@ -73,7 +303,6 @@ class AvgPool1d(_PoolNd):
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
"""Create a ``AvgPool1d`` module. """Create a ``AvgPool1d`` module.
...@@ -87,8 +316,6 @@ class AvgPool1d(_PoolNd): ...@@ -87,8 +316,6 @@ class AvgPool1d(_PoolNd):
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
""" """
super(AvgPool1d, self).__init__( super(AvgPool1d, self).__init__(
...@@ -96,7 +323,6 @@ class AvgPool1d(_PoolNd): ...@@ -96,7 +323,6 @@ class AvgPool1d(_PoolNd):
stride=stride, stride=stride,
padding=padding, padding=padding,
ceil_mode=ceil_mode, ceil_mode=ceil_mode,
global_pool=global_pool,
) )
def forward(self, input): def forward(self, input):
...@@ -106,7 +332,6 @@ class AvgPool1d(_PoolNd): ...@@ -106,7 +332,6 @@ class AvgPool1d(_PoolNd):
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
ceil_mode=self.ceil_mode, ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
) )
...@@ -138,7 +363,6 @@ class AvgPool2d(_PoolNd): ...@@ -138,7 +363,6 @@ class AvgPool2d(_PoolNd):
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
"""Create a ``AvgPool2d`` module. """Create a ``AvgPool2d`` module.
...@@ -152,8 +376,6 @@ class AvgPool2d(_PoolNd): ...@@ -152,8 +376,6 @@ class AvgPool2d(_PoolNd):
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
""" """
super(AvgPool2d, self).__init__( super(AvgPool2d, self).__init__(
...@@ -161,7 +383,6 @@ class AvgPool2d(_PoolNd): ...@@ -161,7 +383,6 @@ class AvgPool2d(_PoolNd):
stride=stride, stride=stride,
padding=padding, padding=padding,
ceil_mode=ceil_mode, ceil_mode=ceil_mode,
global_pool=global_pool,
) )
def forward(self, input): def forward(self, input):
...@@ -171,7 +392,6 @@ class AvgPool2d(_PoolNd): ...@@ -171,7 +392,6 @@ class AvgPool2d(_PoolNd):
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
ceil_mode=self.ceil_mode, ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
) )
...@@ -203,7 +423,6 @@ class AvgPool3d(_PoolNd): ...@@ -203,7 +423,6 @@ class AvgPool3d(_PoolNd):
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
"""Create a ``AvgPool3d`` module. """Create a ``AvgPool3d`` module.
...@@ -217,8 +436,6 @@ class AvgPool3d(_PoolNd): ...@@ -217,8 +436,6 @@ class AvgPool3d(_PoolNd):
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
""" """
super(AvgPool3d, self).__init__( super(AvgPool3d, self).__init__(
...@@ -226,7 +443,6 @@ class AvgPool3d(_PoolNd): ...@@ -226,7 +443,6 @@ class AvgPool3d(_PoolNd):
stride=stride, stride=stride,
padding=padding, padding=padding,
ceil_mode=ceil_mode, ceil_mode=ceil_mode,
global_pool=global_pool,
) )
def forward(self, input): def forward(self, input):
...@@ -236,7 +452,6 @@ class AvgPool3d(_PoolNd): ...@@ -236,7 +452,6 @@ class AvgPool3d(_PoolNd):
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
ceil_mode=self.ceil_mode, ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
) )
...@@ -268,7 +483,6 @@ class MaxPool1d(_PoolNd): ...@@ -268,7 +483,6 @@ class MaxPool1d(_PoolNd):
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
"""Create a ``MaxPool1d`` module. """Create a ``MaxPool1d`` module.
...@@ -282,8 +496,6 @@ class MaxPool1d(_PoolNd): ...@@ -282,8 +496,6 @@ class MaxPool1d(_PoolNd):
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
""" """
super(MaxPool1d, self).__init__( super(MaxPool1d, self).__init__(
...@@ -291,7 +503,6 @@ class MaxPool1d(_PoolNd): ...@@ -291,7 +503,6 @@ class MaxPool1d(_PoolNd):
stride=stride, stride=stride,
padding=padding, padding=padding,
ceil_mode=ceil_mode, ceil_mode=ceil_mode,
global_pool=global_pool,
) )
def forward(self, input): def forward(self, input):
...@@ -301,7 +512,6 @@ class MaxPool1d(_PoolNd): ...@@ -301,7 +512,6 @@ class MaxPool1d(_PoolNd):
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
ceil_mode=self.ceil_mode, ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
) )
...@@ -333,7 +543,6 @@ class MaxPool2d(_PoolNd): ...@@ -333,7 +543,6 @@ class MaxPool2d(_PoolNd):
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
"""Create a ``MaxPool2d`` module. """Create a ``MaxPool2d`` module.
...@@ -347,8 +556,6 @@ class MaxPool2d(_PoolNd): ...@@ -347,8 +556,6 @@ class MaxPool2d(_PoolNd):
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional
Apply the global pooling or not.
""" """
super(MaxPool2d, self).__init__( super(MaxPool2d, self).__init__(
...@@ -356,7 +563,6 @@ class MaxPool2d(_PoolNd): ...@@ -356,7 +563,6 @@ class MaxPool2d(_PoolNd):
stride=stride, stride=stride,
padding=padding, padding=padding,
ceil_mode=ceil_mode, ceil_mode=ceil_mode,
global_pool=global_pool,
) )
def forward(self, input): def forward(self, input):
...@@ -366,7 +572,6 @@ class MaxPool2d(_PoolNd): ...@@ -366,7 +572,6 @@ class MaxPool2d(_PoolNd):
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
ceil_mode=self.ceil_mode, ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
) )
...@@ -398,7 +603,6 @@ class MaxPool3d(_PoolNd): ...@@ -398,7 +603,6 @@ class MaxPool3d(_PoolNd):
stride=1, stride=1,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
global_pool=False,
): ):
"""Create a ``MaxPool3d`` module. """Create a ``MaxPool3d`` module.
...@@ -412,8 +616,6 @@ class MaxPool3d(_PoolNd): ...@@ -412,8 +616,6 @@ class MaxPool3d(_PoolNd):
The zero padding size. The zero padding size.
ceil_mode : bool, optional, default=False ceil_mode : bool, optional, default=False
Ceil or floor the boundary. Ceil or floor the boundary.
global_pool : bool, optional, default=False
Apply the global pooling or not.
""" """
super(MaxPool3d, self).__init__( super(MaxPool3d, self).__init__(
...@@ -421,7 +623,6 @@ class MaxPool3d(_PoolNd): ...@@ -421,7 +623,6 @@ class MaxPool3d(_PoolNd):
stride=stride, stride=stride,
padding=padding, padding=padding,
ceil_mode=ceil_mode, ceil_mode=ceil_mode,
global_pool=global_pool,
) )
def forward(self, input): def forward(self, input):
...@@ -431,5 +632,4 @@ class MaxPool3d(_PoolNd): ...@@ -431,5 +632,4 @@ class MaxPool3d(_PoolNd):
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
ceil_mode=self.ceil_mode, ceil_mode=self.ceil_mode,
global_pool=self.global_pool,
) )
...@@ -22,6 +22,18 @@ import itertools ...@@ -22,6 +22,18 @@ import itertools
from dragon.core.util import six from dragon.core.util import six
def _get_adaptive_pool_kwargs(input_sizes, output_sizes):
stride, kernel_size = [], []
for input_size, output_size in zip(input_sizes, output_sizes):
if output_size == 1:
stride.append(1)
kernel_size.append(input_size)
else:
stride.append(input_size // output_size)
kernel_size.append(input_size - (output_size - 1) * stride[-1])
return {'kernel_size': kernel_size, 'stride': stride}
def _ntuple(n): def _ntuple(n):
def parse(x): def parse(x):
if isinstance(x, six.collections_abc.Sequence): if isinstance(x, six.collections_abc.Sequence):
......
...@@ -315,12 +315,16 @@ class OneHot(function.Function): ...@@ -315,12 +315,16 @@ class OneHot(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(OneHot, self).__init__(key, dev, **kwargs) super(OneHot, self).__init__(key, dev, **kwargs)
self.depth = kwargs.get('depth', 1) self.depth = kwargs.get('depth', 1)
self.on_value = kwargs.get('on_value', 1)
self.off_value = kwargs.get('off_value', 0)
def attributes(self): def attributes(self):
return { return {
'op_type': 'OneHot', 'op_type': 'OneHot',
'arguments': { 'arguments': {
'depth': self.depth, 'depth': self.depth,
'on_value': self.on_value,
'off_value': self.off_value,
}, },
} }
......
...@@ -46,7 +46,7 @@ def argmax(input, dim=None, keepdim=False, out=None): ...@@ -46,7 +46,7 @@ def argmax(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimension or not. Keep the reduced dimension or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -81,7 +81,7 @@ def argmin(input, dim=None, keepdim=False, out=None): ...@@ -81,7 +81,7 @@ def argmin(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimension or not. Keep the reduced dimension or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -174,7 +174,7 @@ def cat(seq, dim=0, out=None): ...@@ -174,7 +174,7 @@ def cat(seq, dim=0, out=None):
dim : int, optional dim : int, optional
The dim to concatenate. The dim to concatenate.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -197,11 +197,11 @@ def channel_affine(input, weight, bias=None, dim=0, out=None): ...@@ -197,11 +197,11 @@ def channel_affine(input, weight, bias=None, dim=0, out=None):
weight : dragon.vm.torch.Tensor weight : dragon.vm.torch.Tensor
The weight tensor. The weight tensor.
bias : dragon.vm.torch.Tensor, optional bias : dragon.vm.torch.Tensor, optional
The optional bias. The bias tensor.
dim : int, optional, default=0 dim : int, optional, default=0
The start dimension to transform. The start dimension to transform.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -369,7 +369,7 @@ def cumsum(input, dim, out=None): ...@@ -369,7 +369,7 @@ def cumsum(input, dim, out=None):
dim : int dim : int
The cumulative dimension. The cumulative dimension.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -429,7 +429,7 @@ def flatten(input, start_dim=0, end_dim=-1, out=None): ...@@ -429,7 +429,7 @@ def flatten(input, start_dim=0, end_dim=-1, out=None):
end_dim : int, optional, default=-1 end_dim : int, optional, default=-1
The end dimension to flatten. The end dimension to flatten.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -465,7 +465,7 @@ def index_select(input, dim, index, out=None): ...@@ -465,7 +465,7 @@ def index_select(input, dim, index, out=None):
index : dragon.vm.torch.Tensor index : dragon.vm.torch.Tensor
The index tensor. The index tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -523,7 +523,7 @@ def masked_select(input, mask, out=None): ...@@ -523,7 +523,7 @@ def masked_select(input, mask, out=None):
mask : dragon.vm.torch.Tensor mask : dragon.vm.torch.Tensor
The mask for selecting. The mask for selecting.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -566,7 +566,7 @@ def max(input, dim=None, keepdim=False, out=None): ...@@ -566,7 +566,7 @@ def max(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimensions or not. Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -606,7 +606,7 @@ def mean(input, dim=None, keepdim=False, out=None): ...@@ -606,7 +606,7 @@ def mean(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimensions or not. Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -646,7 +646,7 @@ def min(input, dim=None, keepdim=False, out=None): ...@@ -646,7 +646,7 @@ def min(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimensions or not. Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -721,7 +721,7 @@ def nonzero(input, out=None): ...@@ -721,7 +721,7 @@ def nonzero(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -732,7 +732,7 @@ def nonzero(input, out=None): ...@@ -732,7 +732,7 @@ def nonzero(input, out=None):
return _functions.NonZero.instantiate(input.device).apply(input, out) return _functions.NonZero.instantiate(input.device).apply(input, out)
def one_hot(input, depth): def one_hot(input, depth, on_value=1, off_value=0):
r"""Return the one-hot representation for input. r"""Return the one-hot representation for input.
.. math:: .. math::
...@@ -748,6 +748,10 @@ def one_hot(input, depth): ...@@ -748,6 +748,10 @@ def one_hot(input, depth):
The input tensor. The input tensor.
depth : int depth : int
The depth of channels. The depth of channels.
on_value : int, optional, default=1
The value for equal branch.
off_value : int, optional, default=0
The value for not-equal branch.
Returns Returns
------- -------
...@@ -755,7 +759,12 @@ def one_hot(input, depth): ...@@ -755,7 +759,12 @@ def one_hot(input, depth):
The output tensor. The output tensor.
""" """
return _functions.OneHot.instantiate(input.device, depth=depth).apply(input) return _functions.OneHot.instantiate(
input.device,
depth=depth,
on_value=on_value,
off_value=off_value,
).apply(input)
def permute(input, dims): def permute(input, dims):
...@@ -812,7 +821,7 @@ def reshape(input, shape, out=None): ...@@ -812,7 +821,7 @@ def reshape(input, shape, out=None):
shape : Sequence[int] shape : Sequence[int]
The new shape. The new shape.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -986,7 +995,7 @@ def stack(seq, dim=0, out=None): ...@@ -986,7 +995,7 @@ def stack(seq, dim=0, out=None):
dim : int, optional, default=0 dim : int, optional, default=0
The dim to stack. The dim to stack.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -1030,7 +1039,7 @@ def sum(input, dim=None, keepdim=False, out=None): ...@@ -1030,7 +1039,7 @@ def sum(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimensions or not. Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
......
...@@ -59,7 +59,7 @@ def arange( ...@@ -59,7 +59,7 @@ def arange(
step : number, optional, default=1 step : number, optional, default=1
The spacing between two elements. The spacing between two elements.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -113,7 +113,7 @@ def eye( ...@@ -113,7 +113,7 @@ def eye(
m : int, optional m : int, optional
The number output cols. The number output cols.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -175,7 +175,7 @@ def full( ...@@ -175,7 +175,7 @@ def full(
fill_value : number fill_value : number
The scalar to fill. The scalar to fill.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -216,7 +216,7 @@ def full_like( ...@@ -216,7 +216,7 @@ def full_like(
fill_value : number fill_value : number
The scalar to fill. The scalar to fill.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -268,7 +268,7 @@ def linspace( ...@@ -268,7 +268,7 @@ def linspace(
steps : int, optional, default=100 steps : int, optional, default=100
The number of values to generate. The number of values to generate.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
dim : int, optional, default=0 dim : int, optional, default=0
...@@ -326,7 +326,7 @@ def ones(*size, **kwargs): ...@@ -326,7 +326,7 @@ def ones(*size, **kwargs):
size : int... size : int...
The size(s) indicating the out shape. The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -378,7 +378,7 @@ def rand(*size, **kwargs): ...@@ -378,7 +378,7 @@ def rand(*size, **kwargs):
size : int... size : int...
The size(s) indicating the out shape. The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -404,7 +404,7 @@ def randn(*size, **kwargs): ...@@ -404,7 +404,7 @@ def randn(*size, **kwargs):
size : int... size : int...
The size(s) indicating the out shape. The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -436,7 +436,7 @@ def randperm(n, out=None, dtype='int64', device=None, requires_grad=False): ...@@ -436,7 +436,7 @@ def randperm(n, out=None, dtype='int64', device=None, requires_grad=False):
n: number n: number
The end of interval. The end of interval.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -479,7 +479,7 @@ def zeros(*size, **kwargs): ...@@ -479,7 +479,7 @@ def zeros(*size, **kwargs):
size : int... size : int...
The size(s) indicating the out shape. The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
......
...@@ -77,36 +77,42 @@ class Clip(function.Function): ...@@ -77,36 +77,42 @@ class Clip(function.Function):
return self.dispatch([input], [self.alloc(out)]) return self.dispatch([input], [self.alloc(out)])
class UnaryFunc(function.Function): class Gemm(function.Function):
"""Unary function.""" """Gemm function."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(UnaryFunc, self).__init__(key, dev, **kwargs) super(Gemm, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '') self.alpha = kwargs.get('alpha', 1.0)
self.beta = kwargs.get('beta', 1.0)
self.transA = kwargs.get('transA', False)
self.transB = kwargs.get('transB', False)
def attributes(self): def attributes(self):
return {'op_type': self.op_type, 'arguments': {}} return {
'op_type': 'Gemm',
'arguments': {
'axis': -1,
'alpha': self.alpha,
'beta': self.beta,
'transA': self.transA,
'transB': self.transB,
},
}
def forward(self, input, out=None): def forward(self, mat1, mat2, mat3=None, out=None):
return self.dispatch([input], [self.alloc(out)]) inputs = [mat1, mat2] + ([mat3] if mat3 else [])
return self.dispatch(inputs, [self.alloc(out)])
class MatMul(function.Function): class UnaryFunc(function.Function):
"""MatMul function.""" """Unary function."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(MatMul, self).__init__(key, dev, **kwargs) super(UnaryFunc, self).__init__(key, dev, **kwargs)
self.transpose_a = kwargs.get('transpose_a', False) self.op_type = kwargs.get('op_type', '')
self.transpose_b = kwargs.get('transpose_b', False)
def attributes(self): def attributes(self):
return { return {'op_type': self.op_type, 'arguments': {}}
'op_type': 'MatMul',
'arguments': {
'transA': self.transpose_a,
'transB': self.transpose_b,
},
}
def forward(self, mat1, mat2, out=None): def forward(self, input, out=None):
return self.dispatch([mat1, mat2], [self.alloc(out)]) return self.dispatch([input], [self.alloc(out)])
...@@ -34,7 +34,7 @@ def abs(input, out=None): ...@@ -34,7 +34,7 @@ def abs(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -59,7 +59,7 @@ def axpby(input, alpha=1., beta=1., out=None): ...@@ -59,7 +59,7 @@ def axpby(input, alpha=1., beta=1., out=None):
beta : float, optional, default=1. beta : float, optional, default=1.
The value to :math:`\beta`. The value to :math:`\beta`.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -87,7 +87,7 @@ def add(input, other, out=None): ...@@ -87,7 +87,7 @@ def add(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The tensor to add. The tensor to add.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -98,6 +98,74 @@ def add(input, other, out=None): ...@@ -98,6 +98,74 @@ def add(input, other, out=None):
return _binary_func(input, other, 'Add', out) return _binary_func(input, other, 'Add', out)
def addmm(input, mat1, mat2, beta=1, alpha=1, out=None):
r"""Add input to the result of matrix-matrix multiplication.
.. math:: \text{out} = \alpha (\text{mat1} \times \text{mat2}) + \beta \text{input}
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
mat1 : dragon.vm.torch.Tensor
The first matrix.
mat2 : dragon.vm.torch.Tensor
The second matrix.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return _functions.Gemm \
.instantiate(
input.device,
alpha=float(alpha),
beta=float(beta),
).apply(mat1, mat2, input, out=out)
def baddbmm(input, batch1, batch2, beta=1, alpha=1, out=None):
r"""Add input to the result of batched matrix-matrix multiplication.
.. math::
\text{out}_{i} = \alpha (\text{mat1}_{i} \times \text{mat2}_{i}) +
\beta \text{input}_{i}
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
batch1 : dragon.vm.torch.Tensor
The first batch of matrices.
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
input1 = bmm(batch1, batch2)
input2 = input * beta if beta != 1 else input
input1 = input1 * alpha if alpha != 1 else input1
return add(input1, input2, out)
def bitwise_not(input, out=None): def bitwise_not(input, out=None):
r"""Compute the element-wise NOT bitwise operation. r"""Compute the element-wise NOT bitwise operation.
...@@ -120,7 +188,7 @@ def bitwise_not(input, out=None): ...@@ -120,7 +188,7 @@ def bitwise_not(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -152,7 +220,7 @@ def bitwise_xor(input, other, out=None): ...@@ -152,7 +220,7 @@ def bitwise_xor(input, other, out=None):
other : dragon.vm.torch.Tensor other : dragon.vm.torch.Tensor
The second input tensor. The second input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -163,6 +231,31 @@ def bitwise_xor(input, other, out=None): ...@@ -163,6 +231,31 @@ def bitwise_xor(input, other, out=None):
return _binary_func(input, other, 'Sub', out) return _binary_func(input, other, 'Sub', out)
def bmm(input, mat2, out=None):
r"""Compute the batched matrix-matrix multiplication.
.. math:: \text{out}_{i} = \text{input}_{i} \times \text{mat2}_{i}
Parameters
----------
input : dragon.vm.torch.Tensor
The first batch of matrices.
mat2 : dragon.vm.torch.Tensor
The second batch of matrices.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return _functions.BinaryFunc \
.instantiate(input.device, op_type='MatMul') \
.apply(input, mat2, out=out)
def ceil(input, out=None): def ceil(input, out=None):
r"""Compute the smallest integer not less than input. r"""Compute the smallest integer not less than input.
...@@ -180,7 +273,7 @@ def ceil(input, out=None): ...@@ -180,7 +273,7 @@ def ceil(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -205,7 +298,7 @@ def clamp(input, min=None, max=None, out=None): ...@@ -205,7 +298,7 @@ def clamp(input, min=None, max=None, out=None):
max : number, optional max : number, optional
The max value. The max value.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -238,7 +331,7 @@ def cos(input, out=None): ...@@ -238,7 +331,7 @@ def cos(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -261,7 +354,7 @@ def div(input, other, out=None): ...@@ -261,7 +354,7 @@ def div(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The tensor to divide. The tensor to divide.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -284,7 +377,7 @@ def eq(input, other, out=None): ...@@ -284,7 +377,7 @@ def eq(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare. The tensor to compare.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -305,7 +398,7 @@ def exp(input, out=None): ...@@ -305,7 +398,7 @@ def exp(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -333,7 +426,7 @@ def floor(input, out=None): ...@@ -333,7 +426,7 @@ def floor(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -356,7 +449,7 @@ def ge(input, other, out=None): ...@@ -356,7 +449,7 @@ def ge(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare. The tensor to compare.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -379,7 +472,7 @@ def gt(input, other, out=None): ...@@ -379,7 +472,7 @@ def gt(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare. The tensor to compare.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -454,7 +547,7 @@ def le(input, other, out=None): ...@@ -454,7 +547,7 @@ def le(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare. The tensor to compare.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -475,7 +568,7 @@ def log(input, out=None): ...@@ -475,7 +568,7 @@ def log(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -523,7 +616,7 @@ def lt(input, other, out=None): ...@@ -523,7 +616,7 @@ def lt(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare. The tensor to compare.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -534,6 +627,60 @@ def lt(input, other, out=None): ...@@ -534,6 +627,60 @@ def lt(input, other, out=None):
return _binary_func(input, other, 'Less', out) return _binary_func(input, other, 'Less', out)
def matmul(input, other, out=None):
r"""Compute the matrix multiplication.
.. math:: \text{out} = \text{input} \times \text{other}
The behavior depends on the shape of input tensors:
* If both tensors are 1d, computes the vector product.
* If tensors are 1d and >=2d, computes the vector-matrix multiplication.
* If tensors are >=2d and 1d, computes the matrix-vector multiplication.
* If both tensors are >= 2d, computes the matrix-matrix multiplication.
* If one tensor is >= 3d, applies batching and broadcasting to the computation.
Examples:
```python
# Vector x Vector
a = torch.ones(2)
b = torch.ones(2)
print(torch.matmul(a, b))
# Vector x Matrix
a = torch.ones(2)
b = torch.ones(2, 3)
print(torch.matmul(a, b))
# Matrix x Vector
a = torch.ones(3, 2)
b = torch.ones(2)
print(torch.matmul(a, b))
# Matrix x Matrix
a = torch.ones(2, 3)
b = torch.ones(3, 2)
print(torch.matmul(a, b))
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
other : dragon.vm.torch.Tensor
The tensor to multiply.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.Tensor
The output tensor.
"""
return _functions.BinaryFunc \
.instantiate(input.device, op_type='MatMul') \
.apply(input, other, out=out)
def maximum(input, other, out=None): def maximum(input, other, out=None):
r"""Compute the maximum value of inputs. r"""Compute the maximum value of inputs.
...@@ -546,7 +693,7 @@ def maximum(input, other, out=None): ...@@ -546,7 +693,7 @@ def maximum(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The second input tensor. The second input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -575,7 +722,7 @@ def minimum(input, other, out=None): ...@@ -575,7 +722,7 @@ def minimum(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The second input tensor. The second input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -586,13 +733,11 @@ def minimum(input, other, out=None): ...@@ -586,13 +733,11 @@ def minimum(input, other, out=None):
input, other = utils \ input, other = utils \
.remove_binary_scalar(input, other) .remove_binary_scalar(input, other)
return _functions.BinaryFunc \ return _functions.BinaryFunc \
.instantiate( .instantiate(input.device, op_type='Minimum') \
input.device, .apply(input, other, out)
op_type='Minimum',
).apply(input, other, out)
def mm(input, mat2, transpose_a=False, transpose_b=False, out=None): def mm(input, mat2, out=None):
r"""Compute the matrix-matrix multiplication. r"""Compute the matrix-matrix multiplication.
.. math:: \text{out} = \text{input} \times \text{mat2} .. math:: \text{out} = \text{input} \times \text{mat2}
...@@ -603,12 +748,8 @@ def mm(input, mat2, transpose_a=False, transpose_b=False, out=None): ...@@ -603,12 +748,8 @@ def mm(input, mat2, transpose_a=False, transpose_b=False, out=None):
The first matrix. The first matrix.
mat2 : dragon.vm.torch.Tensor mat2 : dragon.vm.torch.Tensor
The second matrix. The second matrix.
transpose_a : bool, optional, default=False
Transpose the first matrix before computation or not.
transpose_b : bool, optional, default=False
Transpose the second matrix before computation or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output. The output tensor.
Returns Returns
------- -------
...@@ -616,12 +757,9 @@ def mm(input, mat2, transpose_a=False, transpose_b=False, out=None): ...@@ -616,12 +757,9 @@ def mm(input, mat2, transpose_a=False, transpose_b=False, out=None):
The output tensor. The output tensor.
""" """
return _functions.MatMul \ return _functions.Gemm \
.instantiate( .instantiate(input.device) \
utils.unify_devices([input, mat2]), .apply(input, mat2, out=out)
transpose_a=transpose_a,
transpose_b=transpose_b,
).apply(input, mat2, out)
def mul(input, other, out=None): def mul(input, other, out=None):
...@@ -636,7 +774,7 @@ def mul(input, other, out=None): ...@@ -636,7 +774,7 @@ def mul(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The tensor to multiply. The tensor to multiply.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -659,7 +797,7 @@ def ne(input, other, out=None): ...@@ -659,7 +797,7 @@ def ne(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The tensor to compare. The tensor to compare.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -680,7 +818,7 @@ def neg(input, out=None): ...@@ -680,7 +818,7 @@ def neg(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -712,7 +850,7 @@ def pow(input, exponent, out=None): ...@@ -712,7 +850,7 @@ def pow(input, exponent, out=None):
exponent : Union[dragon.vm.torch.Tensor, number] exponent : Union[dragon.vm.torch.Tensor, number]
The exponent tensor. The exponent tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -740,7 +878,7 @@ def reciprocal(input, out=None): ...@@ -740,7 +878,7 @@ def reciprocal(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -768,7 +906,7 @@ def round(input, out=None): ...@@ -768,7 +906,7 @@ def round(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -796,7 +934,7 @@ def rsqrt(input, out=None): ...@@ -796,7 +934,7 @@ def rsqrt(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -830,7 +968,7 @@ def sign(input, out=None): ...@@ -830,7 +968,7 @@ def sign(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -858,7 +996,7 @@ def sin(input, out=None): ...@@ -858,7 +996,7 @@ def sin(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -886,7 +1024,7 @@ def sqrt(input, out=None): ...@@ -886,7 +1024,7 @@ def sqrt(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -909,7 +1047,7 @@ def sub(input, other, out=None): ...@@ -909,7 +1047,7 @@ def sub(input, other, out=None):
other : Union[dragon.vm.torch.Tensor, number] other : Union[dragon.vm.torch.Tensor, number]
The tensor to subtract. The tensor to subtract.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
......
...@@ -85,6 +85,35 @@ def add_(self, other): ...@@ -85,6 +85,35 @@ def add_(self, other):
return math_funcs.add(self, other, self) return math_funcs.add(self, other, self)
def addmm(self, mat1, mat2, beta=1, alpha=1):
r"""Add the result of matrix-matrix multiplication.
.. math:: \text{out} = \alpha (\text{mat1} \times \text{mat2}) + \beta \text{self}
Parameters
----------
mat1 : dragon.vm.torch.Tensor
The first matrix.
mat2 : dragon.vm.torch.Tensor
The second matrix.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.addmm(...)`_
"""
return math_funcs.addmm(self, mat1, mat2, beta=beta, alpha=alpha)
def argmax(self, dim=None, keepdim=False): def argmax(self, dim=None, keepdim=False):
"""Return the index of maximum elements. """Return the index of maximum elements.
...@@ -154,6 +183,71 @@ def argsort(self, dim=-1, descending=False): ...@@ -154,6 +183,71 @@ def argsort(self, dim=-1, descending=False):
return array_funcs.argsort(self, dim, descending) return array_funcs.argsort(self, dim, descending)
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
r"""Add the result of batched matrix-matrix multiplication.
.. math::
\text{out}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
\beta \text{self}_{i}
Parameters
----------
batch1 : dragon.vm.torch.Tensor
The first batch of matrices.
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.baddbmm(...)`_
"""
return math_funcs.baddbmm(self, batch1, batch2, beta=beta, alpha=alpha)
def baddbmm_(self, batch1, batch2, beta=1, alpha=1):
r"""Add the result of batched matrix-matrix multiplication.
.. math::
\text{self}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
\beta \text{self}_{i}
Parameters
----------
batch1 : dragon.vm.torch.Tensor
The first batch of matrices.
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.baddbmm(...)`_
"""
return math_funcs.baddbmm(
self, batch1, batch2,
beta=beta, alpha=alpha, out=self,
)
def backward(self, gradient=None, retain_graph=False): def backward(self, gradient=None, retain_graph=False):
"""Compute the derivatives of this tensor w.r.t. graph leaves. """Compute the derivatives of this tensor w.r.t. graph leaves.
...@@ -254,6 +348,29 @@ def bitwise_xor_(self, other): ...@@ -254,6 +348,29 @@ def bitwise_xor_(self, other):
return math_funcs.bitwise_xor(self, other, self) return math_funcs.bitwise_xor(self, other, self)
def bmm(self, batch2):
r"""Compute the batched matrix multiplication.
.. math:: \text{out}_{i} = \text{self}_{i} \times \text{batch2}_{i}
Parameters
----------
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`torch.bmm(...)`_
"""
return math_funcs.bmm(self, batch2)
def bool(self): def bool(self):
"""Return a bool tensor with the same data. """Return a bool tensor with the same data.
...@@ -719,50 +836,6 @@ def floor_(self): ...@@ -719,50 +836,6 @@ def floor_(self):
return math_funcs.floor(self, self) return math_funcs.floor(self, self)
def new_full(
self,
size,
fill_value,
dtype=None,
device=None,
requires_grad=False,
):
"""Return a tensor filled with a scalar.
Refer this tensor if ``dtype`` and ``device`` not provided.
Parameters
----------
size : Sequence[int]
The size of output tensor.
fill_value : number
The scalar to fill.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.full(...)`_
"""
return init_funcs.full(
size,
fill_value,
dtype=self.dtype if dtype is None else dtype,
device=self.device if device is None else device,
requires_grad=requires_grad,
)
def ge(self, other): def ge(self, other):
r"""Compute the element-wise greater-equal comparison. r"""Compute the element-wise greater-equal comparison.
...@@ -1104,6 +1177,29 @@ def masked_select(self, mask): ...@@ -1104,6 +1177,29 @@ def masked_select(self, mask):
return array_funcs.masked_select(self, mask) return array_funcs.masked_select(self, mask)
def matmul(self, tensor2):
r"""Compute the matrix multiplication.
.. math:: \text{out} = \text{self} \times \text{tensor2}
Parameters
----------
tensor2 : dragon.vm.torch.Tensor
The tensor to multiply.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`torch.matmul(...)`_
"""
return math_funcs.matmul(self, tensor2)
def max(self, dim=None, keepdim=False): def max(self, dim=None, keepdim=False):
"""Compute the max value of elements along the given dimension. """Compute the max value of elements along the given dimension.
...@@ -1383,6 +1479,50 @@ def neg_(self): ...@@ -1383,6 +1479,50 @@ def neg_(self):
return math_funcs.neg(self, self) return math_funcs.neg(self, self)
def new_full(
self,
size,
fill_value,
dtype=None,
device=None,
requires_grad=False,
):
"""Return a tensor filled with a scalar.
Refer this tensor if ``dtype`` and ``device`` not provided.
Parameters
----------
size : Sequence[int]
The size of output tensor.
fill_value : number
The scalar to fill.
dtype : str, optional
The optional data type.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.full(...)`_
"""
return init_funcs.full(
size,
fill_value,
dtype=self.dtype if dtype is None else dtype,
device=self.device if device is None else device,
requires_grad=requires_grad,
)
def nonzero(self): def nonzero(self):
r"""Return the index of non-zero elements. r"""Return the index of non-zero elements.
...@@ -1735,7 +1875,7 @@ def sort(self, dim=-1, descending=False): ...@@ -1735,7 +1875,7 @@ def sort(self, dim=-1, descending=False):
def split(self, split_size_or_sections, dim=0): def split(self, split_size_or_sections, dim=0):
"""Return the splited chunks along the given dimension. """Return the split chunks along the given dimension.
Parameters Parameters
---------- ----------
...@@ -2132,14 +2272,18 @@ def _process_index(item): ...@@ -2132,14 +2272,18 @@ def _process_index(item):
Tensor.abs = abs Tensor.abs = abs
Tensor.add = add Tensor.add = add
Tensor.add_ = add_ Tensor.add_ = add_
Tensor.addmm = addmm
Tensor.argmax = argmax Tensor.argmax = argmax
Tensor.argmin = argmin Tensor.argmin = argmin
Tensor.argsort = argsort Tensor.argsort = argsort
Tensor.backward = backward Tensor.backward = backward
Tensor.baddbmm = baddbmm
Tensor.baddbmm_ = baddbmm_
Tensor.bitwise_not = bitwise_not Tensor.bitwise_not = bitwise_not
Tensor.bitwise_not_ = bitwise_not_ Tensor.bitwise_not_ = bitwise_not_
Tensor.bitwise_xor = bitwise_xor Tensor.bitwise_xor = bitwise_xor
Tensor.bitwise_xor_ = bitwise_xor_ Tensor.bitwise_xor_ = bitwise_xor_
Tensor.bmm = bmm
Tensor.bool = bool Tensor.bool = bool
Tensor.bool_ = bool_ Tensor.bool_ = bool_
Tensor.byte = byte Tensor.byte = byte
...@@ -2184,6 +2328,7 @@ Tensor.logsumexp = logsumexp ...@@ -2184,6 +2328,7 @@ Tensor.logsumexp = logsumexp
Tensor.lt = lt Tensor.lt = lt
Tensor.masked_fill_ = masked_fill_ Tensor.masked_fill_ = masked_fill_
Tensor.masked_select = masked_select Tensor.masked_select = masked_select
Tensor.matmul = matmul
Tensor.max = max Tensor.max = max
Tensor.maximum = maximum Tensor.maximum = maximum
Tensor.mean = mean Tensor.mean = mean
......
...@@ -270,6 +270,33 @@ class Tensor(object): ...@@ -270,6 +270,33 @@ class Tensor(object):
""" """
def addmm(self, mat1, mat2, beta=1, alpha=1):
r"""Add the result of matrix-matrix multiplication.
.. math:: \text{out} = \alpha (\text{mat1} \times \text{mat2}) + \beta \text{self}
Parameters
----------
mat1 : dragon.vm.torch.Tensor
The first matrix.
mat2 : dragon.vm.torch.Tensor
The second matrix.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.addmm(...)`_
"""
def argmax(self, dim=None, keepdim=False): def argmax(self, dim=None, keepdim=False):
"""Return the index of maximum elements. """Return the index of maximum elements.
...@@ -345,6 +372,64 @@ class Tensor(object): ...@@ -345,6 +372,64 @@ class Tensor(object):
""" """
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
r"""Add the result of batched matrix-matrix multiplication.
.. math::
\text{out}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
\beta \text{self}_{i}
Parameters
----------
batch1 : dragon.vm.torch.Tensor
The first batch of matrices.
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.baddbmm(...)`_
"""
def baddbmm_(self, batch1, batch2, beta=1, alpha=1):
r"""Add the result of batched matrix-matrix multiplication.
.. math::
\text{self}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
\beta \text{self}_{i}
Parameters
----------
batch1 : dragon.vm.torch.Tensor
The first batch of matrices.
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
beta : float, optional, default=1
The value to :math:`\beta`.
alpha : float, optional, default=1
The value to :math:`\alpha`.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.baddbmm(...)`_
"""
def bitwise_not(self): def bitwise_not(self):
r"""Compute the element-wise NOT bitwise operation. r"""Compute the element-wise NOT bitwise operation.
...@@ -419,6 +504,27 @@ class Tensor(object): ...@@ -419,6 +504,27 @@ class Tensor(object):
""" """
def bmm(self, batch2):
r"""Compute the batched matrix multiplication.
.. math:: \text{out}_{i} = \text{self}_{i} \times \text{batch2}_{i}
Parameters
----------
batch2 : dragon.vm.torch.Tensor
The second batch of matrices.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`torch.bmm(...)`_
"""
def bool(self): def bool(self):
"""Return a bool tensor with the same data. """Return a bool tensor with the same data.
...@@ -1192,6 +1298,27 @@ class Tensor(object): ...@@ -1192,6 +1298,27 @@ class Tensor(object):
""" """
def matmul(self, tensor2):
r"""Compute the matrix multiplication.
.. math:: \text{out} = \text{self} \times \text{tensor2}
Parameters
----------
tensor2 : dragon.vm.torch.Tensor
The tensor to multiply.
Returns
-------
dragon.Tensor
The output tensor.
See Also
--------
`torch.matmul(...)`_
"""
def max(self, dim=None, keepdim=False): def max(self, dim=None, keepdim=False):
"""Compute the max value of elements along the given dimension. """Compute the max value of elements along the given dimension.
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!