Commit 6bfe3e73 by Ting PAN

Reimplement the general matrix multiplication

Summary:
This commit generalizes the fully-connected operation into GEMM,
and enhances the matmul operation via batched Dot, GEMV and GEMM.
New representations and attributes have been consistent with ONNX.
1 parent 73ed1b96
Showing with 1525 additions and 647 deletions
...@@ -313,8 +313,8 @@ class InnerProduct(Layer): ...@@ -313,8 +313,8 @@ class InnerProduct(Layer):
param = layer_param.inner_product_param param = layer_param.inner_product_param
self.arguments = { self.arguments = {
'axis': param.axis, 'axis': param.axis,
'out_channels': param.num_output, 'n': param.num_output,
'transpose_w': not param.transpose, 'transpose_b': not param.transpose,
} }
self.add_blob(filler=self.get_filler(param, 'weight_filler')) self.add_blob(filler=self.get_filler(param, 'weight_filler'))
if param.bias_term: if param.bias_term:
...@@ -322,7 +322,7 @@ class InnerProduct(Layer): ...@@ -322,7 +322,7 @@ class InnerProduct(Layer):
def __call__(self, bottom): def __call__(self, bottom):
inputs = [bottom] + [blob['data'] for blob in self._blobs] inputs = [bottom] + [blob['data'] for blob in self._blobs]
return math_ops.fully_connected(inputs, **self.arguments) return math_ops.gemm(inputs, **self.arguments)
class Input(Layer): class Input(Layer):
...@@ -409,7 +409,7 @@ class Normalize(Layer): ...@@ -409,7 +409,7 @@ class Normalize(Layer):
def __call__(self, bottom): def __call__(self, bottom):
norm_out = [normalization_ops.lp_normalize(bottom, **self.l2norm_arguments)] norm_out = [normalization_ops.lp_normalize(bottom, **self.l2norm_arguments)]
norm_out += [blob['data'] for blob in self._blobs] norm_out += [blob['data'] for blob in self._blobs]
return math_ops.affine(norm_out, **self.affine_arguments) return array_ops.channel_affine(norm_out, **self.affine_arguments)
class Permute(Layer): class Permute(Layer):
...@@ -583,7 +583,7 @@ class Scale(Layer): ...@@ -583,7 +583,7 @@ class Scale(Layer):
def __call__(self, bottom): def __call__(self, bottom):
inputs = [bottom] + [blob['data'] for blob in self._blobs] inputs = [bottom] + [blob['data'] for blob in self._blobs]
return math_ops.affine(inputs, **self.arguments) return array_ops.channel_affine(inputs, **self.arguments)
class Slice(Layer): class Slice(Layer):
......
...@@ -48,6 +48,9 @@ dragon.math ...@@ -48,6 +48,9 @@ dragon.math
`floor(...) <math/floor.html>`_ `floor(...) <math/floor.html>`_
: Compute the largest integer not greater than input. : Compute the largest integer not greater than input.
`gemm(...) <math/gemm.html>`_
: Compute the general matrix multiplication.
`greater(...) <math/greater.html>`_ `greater(...) <math/greater.html>`_
: Compute the element-wise greater comparison. : Compute the element-wise greater comparison.
...@@ -158,6 +161,7 @@ dragon.math ...@@ -158,6 +161,7 @@ dragon.math
math/equal math/equal
math/exp math/exp
math/floor math/floor
math/gemm
math/greater math/greater
math/greater_equal math/greater_equal
math/is_inf math/is_inf
......
fully_connected gemm
=============== ====
.. autofunction:: dragon.nn.fully_connected .. autofunction:: dragon.math.gemm
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "dragon.nn."; content: "dragon.math.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -74,9 +74,6 @@ dragon.nn ...@@ -74,9 +74,6 @@ dragon.nn
: Apply the exponential linear unit. : Apply the exponential linear unit.
`[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_. `[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
`fully_connected(...) <nn/fully_connected.html>`_
: Compute the dense matrix multiplication along the given axes.
`group_norm(...) <nn/group_norm.html>`_ `group_norm(...) <nn/group_norm.html>`_
: Apply the group normalization. : Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_. `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
...@@ -167,7 +164,6 @@ dragon.nn ...@@ -167,7 +164,6 @@ dragon.nn
nn/drop_block2d nn/drop_block2d
nn/drop_path nn/drop_path
nn/elu nn/elu
nn/fully_connected
nn/group_norm nn/group_norm
nn/hardsigmoid nn/hardsigmoid
nn/hardswish nn/hardswish
......
...@@ -79,7 +79,7 @@ Name Supported Reference ...@@ -79,7 +79,7 @@ Name Supported Reference
`Gather`_ |v| :func:`dragon.index_select` `Gather`_ |v| :func:`dragon.index_select`
`GatherElements`_ `GatherElements`_
`GatherND`_ `GatherND`_
`Gemm`_ |v| :func:`dragon.nn.fully_connected` `Gemm`_ |v| :func:`dragon.math.gemm`
`GlobalAveragePool`_ |v| :func:`dragon.nn.pool2d` `GlobalAveragePool`_ |v| :func:`dragon.nn.pool2d`
`GlobalLpPool`_ `GlobalLpPool`_
`GlobalMaxPool`_ |v| :func:`dragon.nn.pool2d` `GlobalMaxPool`_ |v| :func:`dragon.nn.pool2d`
......
...@@ -36,6 +36,9 @@ vm.torch ...@@ -36,6 +36,9 @@ vm.torch
`add(...) <torch/add.html>`_ `add(...) <torch/add.html>`_
: Compute the element-wise addition. : Compute the element-wise addition.
`addmm(...) <torch/addmm.html>`_
: Add input to the result of matrix-matrix multiplication.
`arange(...) <torch/arange.html>`_ `arange(...) <torch/arange.html>`_
: Return a tensor of evenly spaced values within an interval. : Return a tensor of evenly spaced values within an interval.
...@@ -51,12 +54,18 @@ vm.torch ...@@ -51,12 +54,18 @@ vm.torch
`axpby(...) <torch/axpby.html>`_ `axpby(...) <torch/axpby.html>`_
: Compute the element-wise addition from input to output. : Compute the element-wise addition from input to output.
`baddbmm(...) <torch/baddbmm.html>`_
: Add input to the result of batched matrix-matrix multiplication.
`bitwise_not(...) <torch/bitwise_not.html>`_ `bitwise_not(...) <torch/bitwise_not.html>`_
: Compute the element-wise NOT bitwise operation. : Compute the element-wise NOT bitwise operation.
`bitwise_xor(...) <torch/bitwise_xor.html>`_ `bitwise_xor(...) <torch/bitwise_xor.html>`_
: Compute the element-wise XOR bitwise operation. : Compute the element-wise XOR bitwise operation.
`bmm(...) <torch/bmm.html>`_
: Compute the batched matrix-matrix multiplication.
`cat(...) <torch/cat.html>`_ `cat(...) <torch/cat.html>`_
: Concatenate the inputs along the given dimension. : Concatenate the inputs along the given dimension.
...@@ -148,6 +157,9 @@ vm.torch ...@@ -148,6 +157,9 @@ vm.torch
`masked_select(...) <torch/masked_select.html>`_ `masked_select(...) <torch/masked_select.html>`_
: Select the input elements where mask is 1. : Select the input elements where mask is 1.
`matmul(...) <torch/matmul.html>`_
: Compute the matrix multiplication.
`max(...) <torch/max.html>`_ `max(...) <torch/max.html>`_
: Compute the max value of elements along the given dimension. : Compute the max value of elements along the given dimension.
...@@ -281,13 +293,16 @@ vm.torch ...@@ -281,13 +293,16 @@ vm.torch
torch/Tensor_ torch/Tensor_
torch/abs torch/abs
torch/add torch/add
torch/addmm
torch/arange torch/arange
torch/argmax torch/argmax
torch/argmin torch/argmin
torch/argsort torch/argsort
torch/axpby torch/axpby
torch/baddbmm
torch/bitwise_not torch/bitwise_not
torch/bitwise_xor torch/bitwise_xor
torch/bmm
torch/cat torch/cat
torch/ceil torch/ceil
torch/channel_affine torch/channel_affine
...@@ -321,6 +336,7 @@ vm.torch ...@@ -321,6 +336,7 @@ vm.torch
torch/logsumexp torch/logsumexp
torch/lt torch/lt
torch/masked_select torch/masked_select
torch/matmul
torch/max torch/max
torch/maximum torch/maximum
torch/mean torch/mean
......
...@@ -53,6 +53,10 @@ add\_ ...@@ -53,6 +53,10 @@ add\_
##### #####
.. automethod:: dragon.vm.torch.Tensor.add_ .. automethod:: dragon.vm.torch.Tensor.add_
addmm
#####
.. automethod:: dragon.vm.torch.Tensor.addmm
argmax argmax
###### ######
.. automethod:: dragon.vm.torch.Tensor.argmax .. automethod:: dragon.vm.torch.Tensor.argmax
...@@ -69,6 +73,14 @@ backward ...@@ -69,6 +73,14 @@ backward
######## ########
.. automethod:: dragon.vm.torch.Tensor.backward .. automethod:: dragon.vm.torch.Tensor.backward
baddbmm
#######
.. automethod:: dragon.vm.torch.Tensor.baddbmm
baddbmm\_
#########
.. automethod:: dragon.vm.torch.Tensor.baddbmm_
bitwise_not bitwise_not
########### ###########
.. automethod:: dragon.vm.torch.Tensor.bitwise_not .. automethod:: dragon.vm.torch.Tensor.bitwise_not
...@@ -85,6 +97,10 @@ bitwise_xor\_ ...@@ -85,6 +97,10 @@ bitwise_xor\_
############# #############
.. automethod:: dragon.vm.torch.Tensor.bitwise_xor_ .. automethod:: dragon.vm.torch.Tensor.bitwise_xor_
bmm
###
.. automethod:: dragon.vm.torch.Tensor.bmm
bool bool
#### ####
.. automethod:: dragon.vm.torch.Tensor.bool .. automethod:: dragon.vm.torch.Tensor.bool
...@@ -285,6 +301,14 @@ masked_fill\_ ...@@ -285,6 +301,14 @@ masked_fill\_
############# #############
.. automethod:: dragon.vm.torch.Tensor.masked_fill_ .. automethod:: dragon.vm.torch.Tensor.masked_fill_
masked_select
#############
.. automethod:: dragon.vm.torch.Tensor.masked_select
matmul
######
.. automethod:: dragon.vm.torch.Tensor.matmul
max max
### ###
.. automethod:: dragon.vm.torch.Tensor.max .. automethod:: dragon.vm.torch.Tensor.max
...@@ -293,10 +317,6 @@ maximum ...@@ -293,10 +317,6 @@ maximum
####### #######
.. automethod:: dragon.vm.torch.Tensor.maximum .. automethod:: dragon.vm.torch.Tensor.maximum
masked_select
#############
.. automethod:: dragon.vm.torch.Tensor.masked_select
mean mean
#### ####
.. automethod:: dragon.vm.torch.Tensor.mean .. automethod:: dragon.vm.torch.Tensor.mean
...@@ -535,11 +555,14 @@ zero\_ ...@@ -535,11 +555,14 @@ zero\_
.. _torch.abs(...): abs.html .. _torch.abs(...): abs.html
.. _torch.add(...): add.html .. _torch.add(...): add.html
.. _torch.addmm(...): addmm.html
.. _torch.argmax(...): argmax.html .. _torch.argmax(...): argmax.html
.. _torch.argmin(...): argmin.html .. _torch.argmin(...): argmin.html
.. _torch.argsort(...): argsort.html .. _torch.argsort(...): argsort.html
.. _torch.baddbmm(...): baddbmm.html
.. _torch.bitwise_not(...): bitwise_not.html .. _torch.bitwise_not(...): bitwise_not.html
.. _torch.bitwise_xor(...): bitwise_xor.html .. _torch.bitwise_xor(...): bitwise_xor.html
.. _torch.bmm(...): bmm.html
.. _torch.ceil(...): ceil.html .. _torch.ceil(...): ceil.html
.. _torch.clamp(...): clamp.html .. _torch.clamp(...): clamp.html
.. _torch.cos(...): cos.html .. _torch.cos(...): cos.html
...@@ -557,6 +580,7 @@ zero\_ ...@@ -557,6 +580,7 @@ zero\_
.. _torch.isnan(...): isnan.html .. _torch.isnan(...): isnan.html
.. _torch.le(...): le.html .. _torch.le(...): le.html
.. _torch.lt(...): lt.html .. _torch.lt(...): lt.html
.. _torch.matmul(...): matmul.html
.. _torch.max(...): max.html .. _torch.max(...): max.html
.. _torch.maximum(...): maximum.html .. _torch.maximum(...): maximum.html
.. _torch.min(...): min.html .. _torch.min(...): min.html
......
addmm
=====
.. autofunction:: dragon.vm.torch.addmm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
baddbmm
=======
.. autofunction:: dragon.vm.torch.baddbmm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
bmm
===
.. autofunction:: dragon.vm.torch.bmm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
matmul
======
.. autofunction:: dragon.vm.torch.matmul
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
...@@ -6,6 +6,24 @@ vm.torch.nn ...@@ -6,6 +6,24 @@ vm.torch.nn
Classes Classes
------- -------
`class AdaptiveAvgPool1d <nn/AdaptiveAvgPool1d.html>`_
: Apply the 1d adaptive average pooling.
`class AdaptiveAvgPool2d <nn/AdaptiveAvgPool2d.html>`_
: Apply the 2d adaptive average pooling.
`class AdaptiveAvgPool3d <nn/AdaptiveAvgPool3d.html>`_
: Apply the 3d adaptive average pooling.
`class AdaptiveMaxPool1d <nn/AdaptiveMaxPool1d.html>`_
: Apply the 1d adaptive max pooling.
`class AdaptiveMaxPool2d <nn/AdaptiveMaxPool2d.html>`_
: Apply the 2d adaptive max pooling.
`class AdaptiveMaxPool3d <nn/AdaptiveMaxPool3d.html>`_
: Apply the 3d adaptive max pooling.
`class AffineChannel <nn/AffineChannel.html>`_ `class AffineChannel <nn/AffineChannel.html>`_
: Apply affine transformation along the channels. : Apply affine transformation along the channels.
...@@ -238,6 +256,12 @@ vm.torch.nn ...@@ -238,6 +256,12 @@ vm.torch.nn
.. toctree:: .. toctree::
:hidden: :hidden:
nn/AdaptiveAvgPool1d
nn/AdaptiveAvgPool2d
nn/AdaptiveAvgPool3d
nn/AdaptiveMaxPool1d
nn/AdaptiveMaxPool2d
nn/AdaptiveMaxPool3d
nn/AffineChannel nn/AffineChannel
nn/AvgPool1d nn/AvgPool1d
nn/AvgPool2d nn/AvgPool2d
......
AdaptiveAvgPool1d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveAvgPool1d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveAvgPool1d.__init__
.. _torch.nn.functional.adaptive_avg_pool1d(...): functional/adaptive_avg_pool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveAvgPool2d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveAvgPool2d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveAvgPool2d.__init__
.. _torch.nn.functional.adaptive_avg_pool2d(...): functional/adaptive_avg_pool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveAvgPool3d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveAvgPool3d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveAvgPool3d.__init__
.. _torch.nn.functional.adaptive_avg_pool3d(...): functional/adaptive_avg_pool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveMaxPool1d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveMaxPool1d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveMaxPool1d.__init__
.. _torch.nn.functional.adaptive_max_pool1d(...): functional/adaptive_max_pool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveMaxPool2d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveMaxPool2d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveMaxPool2d.__init__
.. _torch.nn.functional.adaptive_max_pool2d(...): functional/adaptive_max_pool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
AdaptiveMaxPool3d
=================
.. autoclass:: dragon.vm.torch.nn.AdaptiveMaxPool3d
__init__
--------
.. automethod:: dragon.vm.torch.nn.AdaptiveMaxPool3d.__init__
.. _torch.nn.functional.adaptive_max_pool3d(...): functional/adaptive_max_pool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
...@@ -6,6 +6,24 @@ vm.torch.nn.functional ...@@ -6,6 +6,24 @@ vm.torch.nn.functional
Functions Functions
--------- ---------
`adaptive_avg_pool1d(...) <functional/adaptive_avg_pool1d.html>`_
: Apply the 1d adaptive average pooling to input.
`adaptive_avg_pool2d(...) <functional/adaptive_avg_pool2d.html>`_
: Apply the 2d adaptive average pooling to input.
`adaptive_avg_pool3d(...) <functional/adaptive_avg_pool3d.html>`_
: Apply the 3d adaptive average pooling to input.
`adaptive_max_pool1d(...) <functional/adaptive_max_pool1d.html>`_
: Apply the 1d adaptive max pooling to input.
`adaptive_max_pool2d(...) <functional/adaptive_max_pool2d.html>`_
: Apply the 2d adaptive max pooling to input.
`adaptive_max_pool3d(...) <functional/adaptive_max_pool3d.html>`_
: Apply the 3d adaptive max pooling to input.
`avg_pool1d(...) <functional/avg_pool1d.html>`_ `avg_pool1d(...) <functional/avg_pool1d.html>`_
: Apply the 1d average pooling to input. : Apply the 1d average pooling to input.
...@@ -167,6 +185,12 @@ vm.torch.nn.functional ...@@ -167,6 +185,12 @@ vm.torch.nn.functional
.. toctree:: .. toctree::
:hidden: :hidden:
functional/adaptive_avg_pool1d
functional/adaptive_avg_pool2d
functional/adaptive_avg_pool3d
functional/adaptive_max_pool1d
functional/adaptive_max_pool2d
functional/adaptive_max_pool3d
functional/avg_pool1d functional/avg_pool1d
functional/avg_pool2d functional/avg_pool2d
functional/avg_pool3d functional/avg_pool3d
......
adaptive_avg_pool1d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_avg_pool1d
.. _torch.nn.AdaptiveAvgPool1d(...): ../AdaptiveAvgPool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_avg_pool2d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_avg_pool2d
.. _torch.nn.AdaptiveAvgPool2d(...): ../AdaptiveAvgPool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_avg_pool3d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_avg_pool3d
.. _torch.nn.AdaptiveAvgPool3d(...): ../AdaptiveAvgPool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_max_pool1d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_max_pool1d
.. _torch.nn.AdaptiveMaxPool1d(...): ../AdaptiveMaxPool1d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_max_pool2d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_max_pool2d
.. _torch.nn.AdaptiveMaxPool2d(...): ../AdaptiveMaxPool2d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
adaptive_max_pool3d
===================
.. autofunction:: dragon.vm.torch.nn.functional.adaptive_max_pool3d
.. _torch.nn.AdaptiveMaxPool3d(...): ../AdaptiveMaxPool3d.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
...@@ -185,7 +185,6 @@ const Map<string, Map<string, string>>& ONNXBackend::get_node_renamed_attrs() ...@@ -185,7 +185,6 @@ const Map<string, Map<string, string>>& ONNXBackend::get_node_renamed_attrs()
const { const {
const static Map<string, Map<string, string>> kPerNodeRenamedAttrs = { const static Map<string, Map<string, string>> kPerNodeRenamedAttrs = {
{"DepthToSpace", {{"blocksize", "block_size"}}}, {"DepthToSpace", {{"blocksize", "block_size"}}},
{"Gemm", {{"transB", "transW"}}},
{"RoiAlign", {"RoiAlign",
{ {
{"output_height", "pooled_h"}, {"output_height", "pooled_h"},
......
...@@ -180,19 +180,7 @@ ONNXImporterReturns ONNXBackend::GemmImporter( ...@@ -180,19 +180,7 @@ ONNXImporterReturns ONNXBackend::GemmImporter(
ONNXNode* onnx_node, ONNXNode* onnx_node,
const ConversionContext& ctx) { const ConversionContext& ctx) {
auto& attributes = onnx_node->attributes; auto& attributes = onnx_node->attributes;
auto alpha = attributes.get<float>("alpha", 1.f); attributes.AddRewrittenAttribute("axis")->set_i(-1);
auto beta = attributes.get<float>("beta", 1.f);
auto trans_a = attributes.get<int64_t>("transA", 0L);
// Remove the unsupported attributes
if (alpha != 1.f || beta != 1.f) {
LOG(FATAL) << "alpha/beta can not be set currently.";
}
if (trans_a) {
LOG(FATAL) << "Tranposed A is not supported currently.";
}
attributes.remove("alpha");
attributes.remove("beta");
attributes.remove("transA");
return GenericImporter(onnx_node, ctx); return GenericImporter(onnx_node, ctx);
} }
......
...@@ -98,7 +98,6 @@ DECLARE_ELEMENTWISE_OP(SignGradient); ...@@ -98,7 +98,6 @@ DECLARE_ELEMENTWISE_OP(SignGradient);
DECLARE_ELEMENTWISE_OP(SinGradient); DECLARE_ELEMENTWISE_OP(SinGradient);
DECLARE_ELEMENTWISE_OP(SqrtGradient); DECLARE_ELEMENTWISE_OP(SqrtGradient);
DECLARE_ELEMENTWISE_OP(SquareGradient); DECLARE_ELEMENTWISE_OP(SquareGradient);
// Binary ElementwiseOp // Binary ElementwiseOp
DECLARE_ELEMENTWISE_OP(Add); DECLARE_ELEMENTWISE_OP(Add);
DECLARE_ELEMENTWISE_OP(Sub); DECLARE_ELEMENTWISE_OP(Sub);
...@@ -122,7 +121,6 @@ DECLARE_ELEMENTWISE_OP(PowGradient); ...@@ -122,7 +121,6 @@ DECLARE_ELEMENTWISE_OP(PowGradient);
DECLARE_ELEMENTWISE_OP(DotGradient); DECLARE_ELEMENTWISE_OP(DotGradient);
DECLARE_ELEMENTWISE_OP(MinimumGradient); DECLARE_ELEMENTWISE_OP(MinimumGradient);
DECLARE_ELEMENTWISE_OP(MaximumGradient); DECLARE_ELEMENTWISE_OP(MaximumGradient);
#undef DECLARE_ELEMENTWISE_OP #undef DECLARE_ELEMENTWISE_OP
} // namespace dragon } // namespace dragon
......
#include "dragon/operators/math/fully_connected_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/filler.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
// Forward pass of the fully-connected operator.
//
// Views X as an (M, K) matrix split at `axis` (M = X.count(0, axis),
// K = X.count(axis)), multiplies it with the weights W and, when a third
// input is present, adds it as a bias of shape (N,).
//
// N is taken from the "out_channels" argument when it is positive,
// otherwise it is inferred as W.count() / K.
//
// NOTE(review): the math assumes math::Gemm follows the BLAS convention
// Y = alpha * op(A) * op(B) + beta * Y, and that (CBLAS_TRANSPOSE)transW_
// maps 0/1 onto CblasNoTrans/CblasTrans -- confirm against the project's
// math_functions header.
template <class Context>
template <typename T>
void FullyConnectedOp<Context>::DoRunWithType() {
  auto &X = Input(0), &W = Input(1), *Y = Output(0);
  // The macro introduces the canonicalized local `axis` used below.
  CANONICALIZE_AXIS_WITH_TENSOR(X);

  // Determine the number of output channels
  int64_t M = X.count(0, axis), K = X.count(axis), N;
  if (out_channels_ <= 0) {
    // "N" can only be inferred by dividing by "K"; reject K == 0 explicitly
    // instead of running into an integer division by zero.
    CHECK_GT(K, 0) << "\nFailed to infer the N from "
                   << "the weights shape: " << W.DimString()
                   << ", as the inner dimension K of X is zero.";
    // Infer the "N" from the weights shape
    N = W.count() / K;
    CHECK_GT(N, 0) << "\nFailed to infer the N from "
                   << "the weights shape: " << W.DimString();
  } else {
    // Use a fixed "N" from the argument
    N = out_channels_;
  }

  // Output keeps the leading dims of X before `axis` and appends N.
  vec64_t Y_dims(axis + 1);
  for (int i = 0; i < axis + 1; i++) {
    Y_dims[i] = i < axis ? X.dim(i) : N;
  }

  // TENSOR_FILL presumably initializes W with the given shape when a filler
  // is pending (see dragon/utils/filler.h) -- TODO confirm.
  if (transW_ > 0) {
    TENSOR_FILL(W, vec64_t({N, K}));
    CHECK(W.ndim() == 2 && W.dim(1) == K)
        << "\nWeights dimensions should be [N, K].\n"
        << "Got X as (" << M << ", " << K << "), "
        << "and W as " << W.DimString();
  } else {
    TENSOR_FILL(W, vec64_t({K, N}));
    CHECK(W.ndim() == 2 && W.dim(0) == K)
        << "\nWeights dimensions should be [K, N].\n"
        << "Got X as (" << M << ", " << K << "), "
        << "and W as " << W.DimString();
  }

  // Y = X * op(W), written into the freshly reshaped output.
  math::Gemm(
      CblasNoTrans,
      (CBLAS_TRANSPOSE)transW_,
      M,
      N,
      K,
      1.f,
      X.template data<T, Context>(),
      W.template data<T, Context>(),
      0.f,
      Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
      ctx());

  // Optional bias: added in place on Y (Y is both input and output).
  if (InputSize() > 2) {
    TENSOR_FILL(Input(2), vec64_t({N}));
    kernel::BiasAdd(
        M,
        1,
        N,
        Y->template data<T, Context>(),
        Input(2).template data<T, Context>(),
        Y->template mutable_data<T, Context>(),
        ctx());
  }
}
// Entry point of the operator: dispatches on the data type of Input(0),
// restricted to FloatingTensorTypes.
// NOTE(review): DispatchHelper presumably forwards to DoRunWithType<T>() -- confirm.
template <class Context>
void FullyConnectedOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
// Backward pass of the fully-connected operator.
//
// Inputs are X, W and dY; outputs are dX, dW and dB. Each gradient is only
// computed when the corresponding output has a name.
//
// M, K and N are derived exactly as in the forward operator: X is viewed as
// (M, K) split at `axis`, and N comes from "out_channels" or W.count() / K.
//
// NOTE(review): the transpose flags below assume math::Gemm follows the BLAS
// GEMM convention -- confirm against the project's math_functions header.
template <class Context>
template <typename T>
void FullyConnectedGradientOp<Context>::DoRunWithType() {
  auto &X = Input(0), &W = Input(1), &dY = Input(2);
  auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
  // The macro introduces the canonicalized local `axis` used below.
  CANONICALIZE_AXIS_WITH_TENSOR(X);

  // Determine the number of output channels
  int64_t M = X.count(0, axis), K = X.count(axis), N;
  if (out_channels_ <= 0) {
    // "N" can only be inferred by dividing by "K"; reject K == 0 explicitly
    // instead of running into an integer division by zero.
    CHECK_GT(K, 0) << "\nFailed to infer the N from "
                   << "the weights shape: " << W.DimString()
                   << ", as the inner dimension K of X is zero.";
    // Infer the "N" from the weights shape
    N = W.count() / K;
    CHECK_GT(N, 0) << "\nFailed to infer the N from "
                   << "the weights shape: " << W.DimString();
  } else {
    // Use a fixed "N" from the argument
    N = out_channels_;
  }

  // dX (M, K) = dY (M, N) * W, with W stored as [N, K] when transW_ is set
  // and as [K, N] (hence transposed here) otherwise.
  if (dX->has_name()) {
    if (transW_) {
      math::Gemm(
          CblasNoTrans,
          CblasNoTrans,
          M,
          K,
          N,
          1.f,
          dY.template data<T, Context>(),
          W.template data<T, Context>(),
          0.f,
          dX->ReshapeLike(X)->template mutable_data<T, Context>(),
          ctx());
    } else {
      math::Gemm(
          CblasNoTrans,
          CblasTrans,
          M,
          K,
          N,
          1.f,
          dY.template data<T, Context>(),
          W.template data<T, Context>(),
          0.f,
          dX->ReshapeLike(X)->template mutable_data<T, Context>(),
          ctx());
    }
  }

  // dW has the layout of W: [N, K] = dY^T * X when transW_ is set,
  // [K, N] = X^T * dY otherwise.
  if (dW->has_name()) {
    if (transW_) {
      math::Gemm(
          CblasTrans,
          CblasNoTrans,
          N,
          K,
          M,
          1.f,
          dY.template data<T, Context>(),
          X.template data<T, Context>(),
          0.f,
          dW->ReshapeLike(W)->template mutable_data<T, Context>(),
          ctx());
    } else {
      math::Gemm(
          CblasTrans,
          CblasNoTrans,
          K,
          N,
          M,
          1.f,
          X.template data<T, Context>(),
          dY.template data<T, Context>(),
          0.f,
          dW->ReshapeLike(W)->template mutable_data<T, Context>(),
          ctx());
    }
  }

  // dB (N,) = sum of dY, viewed as (M, N), over axis 0.
  if (dB->has_name()) {
    vec32_t dims = {(int)M, (int)N}, axes = {0};
    math::ReduceSum(
        2,
        dims.data(),
        1,
        axes.data(),
        1.f,
        dY.template data<T, Context>(),
        dB->Reshape({N})->template mutable_data<T, Context>(),
        ctx());
  }
}
// Entry point of the gradient operator: dispatches on the data type of
// Input(0), restricted to FloatingTensorTypes.
// NOTE(review): DispatchHelper presumably forwards to DoRunWithType<T>() -- confirm.
template <class Context>
void FullyConnectedGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
// Deploy the forward and gradient operators for CPU, and additionally for
// CUDA when the build defines USE_CUDA.
DEPLOY_CPU_OPERATOR(FullyConnected);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(FullyConnected);
#endif
DEPLOY_CPU_OPERATOR(FullyConnectedGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(FullyConnectedGradient);
#endif
// Forward schema: inputs X, W and an optional bias B; a single output Y.
OPERATOR_SCHEMA(FullyConnected)
/* X, W, B */
.NumInputs(2, 3)
/* Y */
.NumOutputs(1);
// Gradient schema: inputs X, W, dY; outputs dX, dW, dB.
OPERATOR_SCHEMA(FullyConnectedGradient)
/* X, W, dY */
.NumInputs(3)
/* dX, dW, dB */
.NumOutputs(3);
namespace {

// Builds the definition of the gradient operator for FullyConnected.
//
// The generated operator is named "<forward type>Gradient", consumes the
// first two forward inputs plus the gradient of the first forward output,
// and produces the gradients of the three forward inputs.
class GradientMaker : public GradientMakerBase {
 public:
  GRADIENT_MAKER_CTOR(GradientMaker);

  vector<OperatorDef> MakeDef() override {
    const string grad_type = def.type() + "Gradient";
    // X, W, dY
    const vector<string> grad_inputs({I(0), I(1), GO(0)});
    // dX, dW, dB
    const vector<string> grad_outputs({GI(0), GI(1), GI(2)});
    return SingleDef(grad_type, "", grad_inputs, grad_outputs);
  }
};

} // namespace
REGISTER_GRADIENT(FullyConnected, GradientMaker);
} // namespace dragon
#include "dragon/operators/math/gemm_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/filler.h"
#include "dragon/utils/math_functions.h"
namespace dragon {
// Forward pass of the general matrix multiplication.
//
// Views A as a 2-D matrix split at `axis`; with transA_ the leading part is
// K and the trailing part is M, otherwise the leading part is M and the
// trailing part is K. B must be a 2-D [K, N] matrix, or [N, K] with transB_.
// N is taken from the "n" argument when it is positive, otherwise it is
// inferred as B.count() / K.
//
// When a third input C is given it is first broadcast-copied into Y, and the
// product is then accumulated on top of it using beta_; without C the
// product overwrites Y (beta = 0).
//
// NOTE(review): the math assumes math::Gemm follows the BLAS convention
// Y = alpha * op(A) * op(B) + beta * Y, and that casting transA_/transB_
// (0 or 1) to CBLAS_TRANSPOSE yields CblasNoTrans/CblasTrans -- confirm
// against the project's math_functions header.
template <class Context>
template <typename T>
void GemmOp<Context>::DoRunWithType() {
  auto &A = Input(0), &B = Input(1), *Y = Output(0);
  // The macro introduces the canonicalized local `axis` used below.
  CANONICALIZE_AXIS_WITH_TENSOR(A);

  // Check matrix A
  auto M = transA_ ? A.count(axis) : A.count(0, axis);
  auto K = transA_ ? A.count(0, axis) : A.count(axis);

  // Check matrix B
  auto N = n_; // Init "N" from the argument
  if (N <= 0) {
    // "N" can only be inferred by dividing by "K"; reject K == 0 explicitly
    // instead of running into an integer division by zero.
    CHECK_GT(K, 0) << "\nFailed to infer 'N' from "
                   << "the B shape: " << B.DimString()
                   << ", as the inner dimension 'K' of A is zero.";
    // Infer "N" from the B shape
    N = B.count() / K;
    CHECK_GT(N, 0) << "\nFailed to infer 'N' from "
                   << "the B shape: " << B.DimString();
  }

  // TENSOR_FILL presumably initializes B with the given shape when a filler
  // is pending (see dragon/utils/filler.h) -- TODO confirm.
  if (transB_ > 0) {
    TENSOR_FILL(B, vec64_t({N, K}));
    CHECK(B.ndim() == 2 && B.dim(1) == K)
        << "\nMatrixB's dimensions should be [N, K].\n"
        << "Got A as (" << M << ", " << K << "), "
        << "and B as " << B.DimString();
  } else {
    TENSOR_FILL(B, vec64_t({K, N}));
    CHECK(B.ndim() == 2 && B.dim(0) == K)
        << "\nMatrixB's dimensions should be [K, N].\n"
        << "Got A as (" << M << ", " << K << "), "
        << "and B as " << B.DimString();
  }

  // Copy matrix C to Y if provided
  // Y keeps the dims of A before `axis`; N is prepended with transA_ and
  // appended otherwise.
  vec64_t Y_dims(A.dims().begin(), A.dims().begin() + axis);
  Y_dims.insert(transA_ ? Y_dims.begin() : Y_dims.end(), N);
  if (InputSize() > 2) {
    auto& C = Input(2);
    if (C.ndim() == 0) {
      TENSOR_FILL(C, vec64_t({N}));
    }
    // NOTE(review): Y_dims is passed both as an input shape and as the
    // broadcast result; presumably the result equals Y_dims whenever C is
    // broadcastable *to* Y -- verify that a larger C is rejected upstream.
    if (math::utils::IsBinaryBroadcast(Y_dims, C.dims(), Y_dims)) {
      math::Set(
          C.ndim(),
          C.dims().data(),
          Y_dims.size(),
          Y_dims.data(),
          C.template data<T, Context>(),
          Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
          ctx());
    } else {
      LOG(FATAL) << "Could not broadcast together with shapes: "
                 << Tensor::DimString(Y_dims) << " " << C.DimString();
    }
  }

  // Accumulate onto the copied C with beta_, or overwrite Y when C is absent.
  math::Gemm(
      (CBLAS_TRANSPOSE)transA_,
      (CBLAS_TRANSPOSE)transB_,
      M,
      N,
      K,
      alpha_,
      A.template data<T, Context>(),
      B.template data<T, Context>(),
      InputSize() > 2 ? beta_ : 0.f,
      Y->Reshape(Y_dims)->template mutable_data<T, Context>(),
      ctx());
}
// Entry point of the operator: dispatches on the data type of Input(0),
// restricted to FloatingTensorTypes.
// NOTE(review): DispatchHelper presumably forwards to DoRunWithType<T>() -- confirm.
template <class Context>
void GemmOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
// Backward pass of the general matrix multiplication.
//
// Inputs are A, B, C and dY (dY is read from Input(3)); outputs are dA, dB
// and dC. Each gradient is only computed when the corresponding output has
// a name.
//
// M, K and N are derived exactly as in the forward operator: A is viewed as
// a 2-D matrix split at `axis` (swapped when transA_ is set), and N comes
// from the "n" argument or from B.count() / K.
//
// NOTE(review): the transpose flags below assume math::Gemm follows the BLAS
// GEMM convention -- confirm against the project's math_functions header.
template <class Context>
template <typename T>
void GemmGradientOp<Context>::DoRunWithType() {
  auto &A = Input(0), &B = Input(1), &dY = Input(3);
  auto *dA = Output(0), *dB = Output(1), *dC = Output(2);
  // The macro introduces the canonicalized local `axis` used below.
  CANONICALIZE_AXIS_WITH_TENSOR(A);

  // Check matrix A
  auto M = transA_ ? A.count(axis) : A.count(0, axis);
  auto K = transA_ ? A.count(0, axis) : A.count(axis);

  // Check matrix B
  auto N = n_; // Init "N" from the argument
  if (N <= 0) {
    // "N" can only be inferred by dividing by "K"; reject K == 0 explicitly
    // instead of running into an integer division by zero.
    CHECK_GT(K, 0) << "\nFailed to infer 'N' from "
                   << "the B shape: " << B.DimString()
                   << ", as the inner dimension 'K' of A is zero.";
    // Infer "N" from the B shape
    N = B.count() / K;
    CHECK_GT(N, 0) << "\nFailed to infer 'N' from "
                   << "the B shape: " << B.DimString();
  }

  // dA has the layout of A:
  //   transA_ : (K, M) = alpha * op(B) * dY^T
  //   else    : (M, K) = alpha * dY * op(B)^T
  if (dA->has_name()) {
    if (transA_ > 0) {
      math::Gemm(
          transB_ ? CblasTrans : CblasNoTrans,
          CblasTrans,
          K,
          M,
          N,
          alpha_,
          B.template data<T, Context>(),
          dY.template data<T, Context>(),
          0.f,
          dA->ReshapeLike(A)->template mutable_data<T, Context>(),
          ctx());
    } else {
      math::Gemm(
          CblasNoTrans,
          transB_ ? CblasNoTrans : CblasTrans,
          M,
          K,
          N,
          alpha_,
          dY.template data<T, Context>(),
          B.template data<T, Context>(),
          0.f,
          dA->ReshapeLike(A)->template mutable_data<T, Context>(),
          ctx());
    }
  }

  // dB has the layout of B:
  //   transB_ : (N, K) = alpha * dY^T * op(A)
  //   else    : (K, N) = alpha * op(A)^T * dY
  if (dB->has_name()) {
    if (transB_) {
      math::Gemm(
          CblasTrans,
          transA_ ? CblasTrans : CblasNoTrans,
          N,
          K,
          M,
          alpha_,
          dY.template data<T, Context>(),
          A.template data<T, Context>(),
          0.f,
          dB->ReshapeLike(B)->template mutable_data<T, Context>(),
          ctx());
    } else {
      math::Gemm(
          transA_ ? CblasNoTrans : CblasTrans,
          CblasNoTrans,
          K,
          N,
          M,
          alpha_,
          A.template data<T, Context>(),
          dY.template data<T, Context>(),
          0.f,
          dB->ReshapeLike(B)->template mutable_data<T, Context>(),
          ctx());
    }
  }

  // dC = beta * dY, reduced over the broadcast axes when C is smaller than dY.
  // NOTE(review): Input(2) is read unconditionally here; presumably dC is
  // unnamed when the forward op had no C input -- verify against the
  // gradient maker.
  if (dC->has_name()) {
    auto& C = Input(2);
    if (C.count() == dY.count()) {
      // Same number of elements: a plain scaled copy.
      math::Scale(
          dY.count(),
          beta_,
          dY.template data<T, Context>(),
          dC->ReshapeLike(C)->template mutable_data<T, Context>(),
          ctx());
    } else {
      // C_axes presumably holds the axes of dY along which C was broadcast.
      vec32_t Y_axes, C_axes;
      math::utils::ComputeBinaryBroadcastAxes(
          dY.dims(), C.dims(), dY.dims(), Y_axes, C_axes);
      math::ReduceSum(
          dY.ndim(),
          vec32_t{dY.dims().begin(), dY.dims().end()}.data(),
          C_axes.size(),
          C_axes.data(),
          beta_,
          dY.template data<T, Context>(),
          dC->ReshapeLike(C)->template mutable_data<T, Context>(),
          ctx());
    }
  }
}
// Entry point of the gradient operator: dispatches on the data type of
// Input(0), restricted to FloatingTensorTypes.
// NOTE(review): DispatchHelper presumably forwards to DoRunWithType<T>() -- confirm.
template <class Context>
void GemmGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
// Deploy the forward and gradient operators for CPU, and additionally for
// CUDA when the build defines USE_CUDA.
DEPLOY_CPU_OPERATOR(Gemm);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Gemm);
#endif
DEPLOY_CPU_OPERATOR(GemmGradient);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(GemmGradient);
#endif
// Forward schema: inputs A, B and an optional C; a single output Y.
OPERATOR_SCHEMA(Gemm)
/* A, B, C */
.NumInputs(2, 3)
/* Y */
.NumOutputs(1);
// Gradient schema: inputs A, B, C, dY; outputs dA, dB, dC.
OPERATOR_SCHEMA(GemmGradient)
/* A, B, C, dY */
.NumInputs(4)
/* dA, dB, dC */
.NumOutputs(3);
namespace {

// Builds the definition of the gradient operator for Gemm.
//
// The generated operator is named "<forward type>Gradient", consumes the
// three forward inputs plus the gradient of the first forward output, and
// produces the gradients of the three forward inputs.
class GradientMaker : public GradientMakerBase {
 public:
  GRADIENT_MAKER_CTOR(GradientMaker);

  vector<OperatorDef> MakeDef() override {
    const string grad_type = def.type() + "Gradient";
    // A, B, C, dY
    const vector<string> grad_inputs({I(0), I(1), I(2), GO(0)});
    // dA, dB, dC
    const vector<string> grad_outputs({GI(0), GI(1), GI(2)});
    return SingleDef(grad_type, "", grad_inputs, grad_outputs);
  }
};

} // namespace
REGISTER_GRADIENT(Gemm, GradientMaker);
} // namespace dragon
...@@ -10,20 +10,23 @@ ...@@ -10,20 +10,23 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_OPERATORS_MATH_FULLY_CONNECTED_OP_H_ #ifndef DRAGON_OPERATORS_MATH_GEMM_OP_H_
#define DRAGON_OPERATORS_MATH_FULLY_CONNECTED_OP_H_ #define DRAGON_OPERATORS_MATH_GEMM_OP_H_
#include "dragon/core/operator.h" #include "dragon/core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class FullyConnectedOp final : public Operator<Context> { class GemmOp final : public Operator<Context> {
public: public:
FullyConnectedOp(const OperatorDef& def, Workspace* ws) GemmOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
out_channels_(OP_SINGLE_ARG(int64_t, "out_channels", 0)), n_(OP_SINGLE_ARG(int64_t, "n", 0)),
transW_(OP_SINGLE_ARG(int64_t, "transW", 1)) {} alpha_(OP_SINGLE_ARG(float, "alpha", 1.f)),
beta_(OP_SINGLE_ARG(float, "beta", 1.f)),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -32,16 +35,20 @@ class FullyConnectedOp final : public Operator<Context> { ...@@ -32,16 +35,20 @@ class FullyConnectedOp final : public Operator<Context> {
void DoRunWithType(); void DoRunWithType();
protected: protected:
int64_t out_channels_, transW_; float alpha_, beta_;
int64_t n_, transA_, transB_;
}; };
template <class Context> template <class Context>
class FullyConnectedGradientOp final : public Operator<Context> { class GemmGradientOp final : public Operator<Context> {
public: public:
FullyConnectedGradientOp(const OperatorDef& def, Workspace* ws) GemmGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
out_channels_(OP_SINGLE_ARG(int64_t, "out_channels", 0)), n_(OP_SINGLE_ARG(int64_t, "n", 0)),
transW_(OP_SINGLE_ARG(int64_t, "transW", 1)) {} alpha_(OP_SINGLE_ARG(float, "alpha", 1.f)),
beta_(OP_SINGLE_ARG(float, "beta", 1.f)),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -50,9 +57,10 @@ class FullyConnectedGradientOp final : public Operator<Context> { ...@@ -50,9 +57,10 @@ class FullyConnectedGradientOp final : public Operator<Context> {
void DoRunWithType(); void DoRunWithType();
protected: protected:
int64_t out_channels_, transW_; float alpha_, beta_;
int64_t n_, transA_, transB_;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_MATH_FULLY_CONNECTED_OP_H_ #endif // DRAGON_OPERATORS_MATH_GEMM_OP_H_
...@@ -20,37 +20,25 @@ namespace dragon { ...@@ -20,37 +20,25 @@ namespace dragon {
template <class Context> template <class Context>
class MatMulOp final : public Operator<Context> { class MatMulOp final : public Operator<Context> {
public: public:
MatMulOp(const OperatorDef& def, Workspace* ws) SIMPLE_CTOR_DTOR(MatMulOp);
: Operator<Context>(def, ws),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected:
int64_t transA_, transB_;
}; };
template <class Context> template <class Context>
class MatMulGradientOp final : public Operator<Context> { class MatMulGradientOp final : public Operator<Context> {
public: public:
MatMulGradientOp(const OperatorDef& def, Workspace* ws) SIMPLE_CTOR_DTOR(MatMulGradientOp);
: Operator<Context>(def, ws),
transA_(OP_SINGLE_ARG(int64_t, "transA", 0)),
transB_(OP_SINGLE_ARG(int64_t, "transB", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> template <typename T>
void DoRunWithType(); void DoRunWithType();
protected:
int64_t transA_, transB_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -35,6 +35,7 @@ from dragon.core.ops.math_ops import dot ...@@ -35,6 +35,7 @@ from dragon.core.ops.math_ops import dot
from dragon.core.ops.math_ops import equal from dragon.core.ops.math_ops import equal
from dragon.core.ops.math_ops import exp from dragon.core.ops.math_ops import exp
from dragon.core.ops.math_ops import floor from dragon.core.ops.math_ops import floor
from dragon.core.ops.math_ops import gemm
from dragon.core.ops.math_ops import greater from dragon.core.ops.math_ops import greater
from dragon.core.ops.math_ops import greater_equal from dragon.core.ops.math_ops import greater_equal
from dragon.core.ops.math_ops import is_inf from dragon.core.ops.math_ops import is_inf
......
...@@ -33,7 +33,6 @@ from dragon.core.ops.activation_ops import relu6 ...@@ -33,7 +33,6 @@ from dragon.core.ops.activation_ops import relu6
from dragon.core.ops.activation_ops import selu from dragon.core.ops.activation_ops import selu
from dragon.core.ops.activation_ops import softmax from dragon.core.ops.activation_ops import softmax
from dragon.core.ops.activation_ops import swish from dragon.core.ops.activation_ops import swish
from dragon.core.ops.math_ops import fully_connected
from dragon.core.ops.normalization_ops import batch_norm from dragon.core.ops.normalization_ops import batch_norm
from dragon.core.ops.normalization_ops import group_norm from dragon.core.ops.normalization_ops import group_norm
from dragon.core.ops.normalization_ops import instance_norm from dragon.core.ops.normalization_ops import instance_norm
......
...@@ -414,30 +414,34 @@ def flatten_spec(args, inputs, outputs): ...@@ -414,30 +414,34 @@ def flatten_spec(args, inputs, outputs):
return outputs return outputs
@register('FullyConnected') @register('Gemm')
def fully_connected_spec(args, inputs, outputs): def gemm_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
axis, out_channels = args['axis'], args.get('out_channels', None) axis, n = args['axis'], args.get('n', None)
while axis < 0: while axis < 0:
try: try:
axis += len(inputs[0].shape) axis += len(inputs[0].shape)
except TypeError: except TypeError:
return outputs break
out_shape = [None] * (axis + 1) out_shape = [None] * axis if axis >= 0 else None
if out_channels is None: if n is None:
try: try:
if args['transW']: if args['transB']:
out_channels = inputs[1].shape[0] n = inputs[1].shape[0]
else: else:
out_channels = inputs[1].shape[1] n = inputs[1].shape[1]
except (TypeError, IndexError): except (TypeError, IndexError):
out_channels = None n = None
try: try:
out_shape[axis] = out_channels if out_shape is None or inputs[0].shape is not None:
out_shape[:axis] = inputs[0].shape[:axis] out_shape = list(inputs[0].shape[:axis])
if args['transA']:
out_shape.insert(0, n)
else:
out_shape.append(n)
outputs[0].shape = out_shape
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
outputs[0].shape = out_shape
return outputs return outputs
...@@ -510,12 +514,25 @@ def masked_select_spec(args, inputs, outputs): ...@@ -510,12 +514,25 @@ def masked_select_spec(args, inputs, outputs):
@register('MatMul') @register('MatMul')
def matmul_spec(args, inputs, outputs): def matmul_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
ta, tb = args['transA'], args['transB']
try: try:
a_shape = list(inputs[0].shape[:])
b_shape = list(inputs[1].shape[:]) b_shape = list(inputs[1].shape[:])
a_shape = out_shape = list(inputs[0].shape[:]) if len(a_shape) >= 2 and len(b_shape) >= 2:
out_shape[-2] = a_shape[-1] if ta else a_shape[-2] out_shape = [1] * max(len(a_shape), len(b_shape))
out_shape[-1] = b_shape[-2] if tb else b_shape[-1] a_shape = [1] * (len(out_shape) - len(a_shape)) + a_shape
b_shape = [1] * (len(out_shape) - len(b_shape)) + b_shape
for i in range(len(out_shape)):
try:
out_shape[i] = max(a_shape[i], b_shape[i])
except TypeError:
out_shape[i] = None
out_shape[-2] = a_shape[-2]
out_shape[-1] = b_shape[-1]
elif len(a_shape) == 1 and len(b_shape) == 1:
out_shape = []
else:
out_shape = a_shape if len(b_shape) == 1 else b_shape
out_shape.pop(-1 if len(b_shape) == 1 else -2)
except (TypeError, IndexError): except (TypeError, IndexError):
out_shape = None out_shape = None
outputs[0].shape = out_shape outputs[0].shape = out_shape
......
...@@ -498,23 +498,30 @@ def floor(inputs, **kwargs): ...@@ -498,23 +498,30 @@ def floor(inputs, **kwargs):
@OpSchema.num_inputs(2, 3) @OpSchema.num_inputs(2, 3)
def fully_connected(inputs, axis=1, transpose_w=True, **kwargs): def gemm(
r"""Compute the dense matrix multiplication along the given axes. inputs,
alpha=1.0,
.. math:: y = Wx + b beta=1.0,
transpose_a=False,
The column of input matrix is determined by: transpose_b=False,
**kwargs
.. math:: \text{Col} = \text{DimSince}(\text{Input}, \text{Axis}) ):
r"""Compute the general matrix multiplication.
.. math:: \text{out} = \alpha AB + \beta C
Parameters Parameters
---------- ----------
inputs : Sequence[dragon.Tensor] inputs : Sequence[dragon.Tensor]
The tensor :math:`x`, :math:`W` and :math:`b`. The matrix :math:`A`, :math:`B` and optional :math:`C`.
axis : int, optional, default=1 alpha : float, optional, default=1.0
The start axis to compute, can be negative. The value of :math:`\alpha`.
transpose_w : bool, optional, default=True beta : float, optional, default=1.0
**True** to transpose :math:`W` before computation. The value of :math:`\beta`.
transpose_a : bool, optional, default=False
**True** to transpose :math:`A` before computation.
transpose_b : bool, optional, default=False
**True** to transpose :math:`B` before computation.
Returns Returns
------- -------
...@@ -523,15 +530,22 @@ def fully_connected(inputs, axis=1, transpose_w=True, **kwargs): ...@@ -523,15 +530,22 @@ def fully_connected(inputs, axis=1, transpose_w=True, **kwargs):
""" """
args = ArgHelper.parse(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.FullyConnected args['axis'] = kwargs.get('axis', -1)
args['alpha'], args['beta'] = float(alpha), float(beta)
op_lib = math_ops_lib.Gemm
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate(axis=axis, transpose_w=transpose_w) \ .instantiate(
.apply(inputs) axis=args['axis'],
alpha=args['alpha'],
beta=args['beta'],
transpose_a=transpose_a,
transpose_b=transpose_b,
).apply(inputs)
else: else:
args.pop('transpose_w') args['transA'] = args.pop('transpose_a')
args['transW'] = transpose_w args['transB'] = args.pop('transpose_b')
return op_lib.blend('FullyConnected', **args) return op_lib.blend(**args)
@OpSchema.num_inputs(2) @OpSchema.num_inputs(2)
...@@ -812,42 +826,44 @@ def less_equal(inputs, **kwargs): ...@@ -812,42 +826,44 @@ def less_equal(inputs, **kwargs):
@OpSchema.num_inputs(2) @OpSchema.num_inputs(2)
def matmul(inputs, transpose_a=False, transpose_b=False, **kwargs): def matmul(inputs, **kwargs):
r"""Compute the matrix multiplication. r"""Compute the matrix multiplication.
.. math:: y = a \times b .. math:: \text{out} = \text{input1} \times \text{input2}
The rank of ``a`` and ``b`` should be equal and >= 2: The behavior depends on the shape of input tensors:
```python * If both tensors are 1d, computes the vector dot product.
# Ok, a typical matrix multiplication * If tensors are 1d and >=2d, computes the vector-matrix multiplication.
a = dragon.ones((2, 3), 'float32') * If tensors are >=2d and 1d, computes the matrix-vector multiplication.
b = dragon.ones((3, 3), 'float32') * If both tensors are >= 2d, computes the matrix-matrix multiplication.
print(dragon.math.matmul([a, b])) * If one tensor is >= 3d, applies batching and broadcasting to the computation.
# Compute a batch matrix multiplication if rank > 2 Examples:
aa = dragon.ones((4, 2, 3), 'float32')
bb = dragon.ones((4, 3, 3), 'float32')
print(dragon.math.matmul([aa, bb]))
```
If inputs are transposed, remember to transpose them back:
```python ```python
# Vector x Vector
a = dragon.ones((2,), 'float32')
b = dragon.ones((2,), 'float32')
print(dragon.math.matmul([a, b]))
# Vector x Matrix
a = dragon.ones((2,), 'float32')
b = dragon.ones((2, 3), 'float32')
print(dragon.math.matmul([a, b]))
# Matrix x Vector
a = dragon.ones((3, 2), 'float32') a = dragon.ones((3, 2), 'float32')
b = dragon.ones((3, 3), 'float32') b = dragon.ones((2,), 'float32')
print(dragon.math.matmul([a, b])) # ``a`` takes the wrong dimensions print(dragon.math.matmul([a, b]))
print(dragon.math.matmul([a, b], transpose_a=True)) # Ok # Matrix x Matrix
a = dragon.ones((2, 3), 'float32')
b = dragon.ones((3, 2), 'float32')
print(dragon.math.matmul([a, b]))
``` ```
Parameters Parameters
---------- ----------
inputs : Sequence[dragon.Tensor] inputs : Sequence[dragon.Tensor]
The matrix :math:`a` and :math:`b`. The input tensors.
transpose_a : bool, optional, default=False
**True** to transpose :math:`a` before computation.
transpose_b : bool, optional, default=False
**True** to transpose :math:`b` before computation.
Returns Returns
------- -------
...@@ -858,15 +874,9 @@ def matmul(inputs, transpose_a=False, transpose_b=False, **kwargs): ...@@ -858,15 +874,9 @@ def matmul(inputs, transpose_a=False, transpose_b=False, **kwargs):
args = ArgHelper.parse(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.MatMul op_lib = math_ops_lib.MatMul
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib.instantiate().apply(inputs)
.instantiate(
transpose_a=transpose_a,
transpose_b=transpose_b,
).apply(inputs)
else: else:
args.pop('transpose_a') return op_lib.blend(**args)
args.pop('transpose_b')
return op_lib.blend(transA=transpose_a, transB=transpose_b, **args)
@OpSchema.num_inputs(2) @OpSchema.num_inputs(2)
......
...@@ -80,20 +80,26 @@ class Clip(Operator): ...@@ -80,20 +80,26 @@ class Clip(Operator):
return self.dispatch(inputs, [self.alloc()]) return self.dispatch(inputs, [self.alloc()])
class FullyConnected(Operator): class Gemm(Operator):
"""FullyConnected operator.""" """Gemm operator."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(FullyConnected, self).__init__(key, dev, **kwargs) super(Gemm, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 1) self.axis = kwargs.get('axis', -1)
self.transpose_w = kwargs.get('transpose_w', True) self.alpha = kwargs.get('alpha', 1.0)
self.beta = kwargs.get('beta', 1.0)
self.transpose_a = kwargs.get('transpose_a', False)
self.transpose_b = kwargs.get('transpose_b', False)
def attributes(self): def attributes(self):
return { return {
'op_type': 'FullyConnected', 'op_type': 'Gemm',
'arguments': { 'arguments': {
'axis': self.axis, 'axis': self.axis,
'transW': self.transpose_w, 'alpha': self.alpha,
'beta': self.beta,
'transA': self.transpose_a,
'transB': self.transpose_b,
} }
} }
...@@ -104,18 +110,10 @@ class FullyConnected(Operator): ...@@ -104,18 +110,10 @@ class FullyConnected(Operator):
class MatMul(Operator): class MatMul(Operator):
"""MatMul operator.""" """MatMul operator."""
def __init__(self, key, dev, **kwargs):
super(MatMul, self).__init__(key, dev, **kwargs)
self.transpose_a = kwargs.get('transpose_a', False)
self.transpose_b = kwargs.get('transpose_b', False)
def attributes(self): def attributes(self):
return { return {
'op_type': 'MatMul', 'op_type': 'MatMul',
'arguments': { 'arguments': {},
'transA': self.transpose_a,
'transB': self.transpose_b,
}
} }
def forward(self, inputs): def forward(self, inputs):
......
...@@ -136,6 +136,7 @@ def conv( ...@@ -136,6 +136,7 @@ def conv(
data_format=data_format, data_format=data_format,
bias=len(inputs) > 2, bias=len(inputs) > 2,
dtype=inputs[1].dtype, dtype=inputs[1].dtype,
input_shape=inputs[0].shape,
).apply(inputs) ).apply(inputs)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
...@@ -465,6 +466,7 @@ def conv_transpose( ...@@ -465,6 +466,7 @@ def conv_transpose(
data_format=data_format, data_format=data_format,
bias=len(inputs) > 2, bias=len(inputs) > 2,
dtype=inputs[1].dtype, dtype=inputs[1].dtype,
input_shape=inputs[0].shape,
).apply(inputs) ).apply(inputs)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
......
...@@ -44,13 +44,6 @@ def hardsigmoid_exporter(op_def, context): ...@@ -44,13 +44,6 @@ def hardsigmoid_exporter(op_def, context):
return node, const_tensors return node, const_tensors
@export_util.register('PRelu')
def prelu_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
const_tensors = [helper.from_tensor(op_def.input[1], context.ws)]
return node, const_tensors
@export_util.register('Relu') @export_util.register('Relu')
def relu_exporter(op_def, context): def relu_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals()) node, const_tensors = export_util.translate(**locals())
......
...@@ -81,24 +81,14 @@ def clip_exporter_v11(op_def, context): ...@@ -81,24 +81,14 @@ def clip_exporter_v11(op_def, context):
return node, const_tensors return node, const_tensors
@export_util.register('FullyConnected-7') @export_util.register('Gemm-7')
def fully_connected_exporter_v7(op_def, context): def gemm_exporter_v7(op_def, context):
node, const_tensors = export_util.translate(**locals()) return export_util.translate(**locals())
node.op_type = 'Gemm'
helper.add_attribute(node, 'alpha', 1.)
helper.add_attribute(node, 'beta', 1.)
for arg in op_def.arg:
if arg.name == 'transW':
helper.add_attribute(node, 'transB', arg.i)
# Weights and biases
const_tensors = [helper.from_tensor(name, context.ws)
for name in op_def.input[1:]]
return node, const_tensors
@export_util.register('FullyConnected') @export_util.register('Gemm')
def fully_connected_exporter(op_def, context): def gemm_exporter(op_def, context):
node, const_tensors = fully_connected_exporter_v7(op_def, context) node, const_tensors = gemm_exporter_v7(op_def, context)
helper.add_attribute(node, 'broadcast', 1) # Removed since opset 7 helper.add_attribute(node, 'broadcast', 1) # Removed since opset 7
return node, const_tensors return node, const_tensors
......
...@@ -29,8 +29,6 @@ def batch_norm_exporter(op_def, context): ...@@ -29,8 +29,6 @@ def batch_norm_exporter(op_def, context):
elif arg.name == 'momentum_desc': elif arg.name == 'momentum_desc':
momentum = helper.fetch_argument(op_def, arg, context.ws) momentum = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'momentum', float(momentum)) helper.add_attribute(node, 'momentum', float(momentum))
# Weight, bias, running mean and running variance
const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
...@@ -48,8 +46,6 @@ def group_norm_exporter(op_def, context): ...@@ -48,8 +46,6 @@ def group_norm_exporter(op_def, context):
else: else:
helper.add_attribute(node, 'op_type', 'GroupNorm') helper.add_attribute(node, 'op_type', 'GroupNorm')
helper.add_attribute(node, 'group', arg.i) helper.add_attribute(node, 'group', arg.i)
# Weight and bias
const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
......
...@@ -25,7 +25,7 @@ from dragon.vm.onnx.core.exporters import utils as export_util ...@@ -25,7 +25,7 @@ from dragon.vm.onnx.core.exporters import utils as export_util
'ConvTranspose', 'ConvTranspose',
'DepthwiseConv', 'DepthwiseConv',
]) ])
def convolution(op_def, context): def conv_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'ConvTranspose' if 'Transpose' in op_def.type else 'Conv' node.op_type = 'ConvTranspose' if 'Transpose' in op_def.type else 'Conv'
if 'Depthwise' in op_def.type: if 'Depthwise' in op_def.type:
...@@ -58,8 +58,6 @@ def convolution(op_def, context): ...@@ -58,8 +58,6 @@ def convolution(op_def, context):
helper.add_attribute(node, 'output_shape', arg.ints) helper.add_attribute(node, 'output_shape', arg.ints)
elif arg.name == 'output_padding': elif arg.name == 'output_padding':
helper.add_attribute(node, 'output_padding', arg.ints) helper.add_attribute(node, 'output_padding', arg.ints)
# Weights and biases
const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
......
...@@ -203,8 +203,7 @@ DRAGON_API void Gemv<float16, CPUContext>( ...@@ -203,8 +203,7 @@ DRAGON_API void Gemv<float16, CPUContext>(
const float16* x, const float16* x,
const float beta, const float beta,
float16* y, float16* y,
CPUContext* ctx, CPUContext* ctx) {
const std::string math_type) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -219,8 +218,7 @@ DRAGON_API void Gemv<float16, CPUContext>( ...@@ -219,8 +218,7 @@ DRAGON_API void Gemv<float16, CPUContext>(
const T* x, \ const T* x, \
const float beta, \ const float beta, \
T* y, \ T* y, \
CPUContext* ctx, \ CPUContext* ctx) { \
const string math_type) { \
T _alpha_ = alpha, _beta_ = beta; \ T _alpha_ = alpha, _beta_ = beta; \
EigenVectorMap<T> y_vec(y, TransA == CblasNoTrans ? M : N); \ EigenVectorMap<T> y_vec(y, TransA == CblasNoTrans ? M : N); \
if (beta == 0.f) \ if (beta == 0.f) \
...@@ -260,8 +258,7 @@ DRAGON_API void Gemm<float16, CPUContext>( ...@@ -260,8 +258,7 @@ DRAGON_API void Gemm<float16, CPUContext>(
const float16* B, const float16* B,
const float beta, const float beta,
float16* C, float16* C,
CPUContext* ctx, CPUContext* ctx) {
const string math_type) {
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
...@@ -278,8 +275,7 @@ DRAGON_API void Gemm<float16, CPUContext>( ...@@ -278,8 +275,7 @@ DRAGON_API void Gemm<float16, CPUContext>(
const T* B, \ const T* B, \
const float beta, \ const float beta, \
T* C, \ T* C, \
CPUContext* ctx, \ CPUContext* ctx) { \
const string math_type) { \
T _alpha_ = alpha, _beta_ = beta; \ T _alpha_ = alpha, _beta_ = beta; \
auto C_mat = EigenMatrixMap<T>(C, N, M); \ auto C_mat = EigenMatrixMap<T>(C, N, M); \
if (beta == 0.f) \ if (beta == 0.f) \
...@@ -328,6 +324,105 @@ DEFINE_GEMM_FUNC(float); ...@@ -328,6 +324,105 @@ DEFINE_GEMM_FUNC(float);
DEFINE_GEMM_FUNC(double); DEFINE_GEMM_FUNC(double);
#undef DEFINE_GEMM_FUNC #undef DEFINE_GEMM_FUNC
template <>
DRAGON_API void GemmBatched<float16, CPUContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const float alpha,
const float16** A,
const float16** B,
const float beta,
float16** C,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_BATCHED_GEMM_FUNC(T) \
template <> \
DRAGON_API void GemmBatched<T, CPUContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int batch_size, \
const int M, \
const int N, \
const int K, \
const float alpha, \
const T** A, \
const T** B, \
const float beta, \
T** C, \
CPUContext* ctx) { \
for (int i = 0; i < batch_size; ++i) { \
Gemm(TransA, TransB, M, N, K, alpha, A[i], B[i], beta, C[i], ctx); \
} \
}
DEFINE_BATCHED_GEMM_FUNC(float);
DEFINE_BATCHED_GEMM_FUNC(double);
#undef DEFINE_BATCHED_GEMM_FUNC
template <>
DRAGON_API void GemmStridedBatched<float16, CPUContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const int A_stride,
const int B_stride,
const int C_stride,
const float alpha,
const float16* A,
const float16* B,
const float beta,
float16* C,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_STRIDED_BATCHED_GEMM_FUNC(T) \
template <> \
DRAGON_API void GemmStridedBatched<T, CPUContext>( \
const CBLAS_TRANSPOSE TransA, \
const CBLAS_TRANSPOSE TransB, \
const int batch_size, \
const int M, \
const int N, \
const int K, \
const int A_stride, \
const int B_stride, \
const int C_stride, \
const float alpha, \
const T* A, \
const T* B, \
const float beta, \
T* C, \
CPUContext* ctx) { \
for (int i = 0; i < batch_size; ++i) { \
Gemm( \
TransA, \
TransB, \
M, \
N, \
K, \
alpha, \
A + i * A_stride, \
B + i * B_stride, \
beta, \
C + i * C_stride, \
ctx); \
} \
}
DEFINE_STRIDED_BATCHED_GEMM_FUNC(float);
DEFINE_STRIDED_BATCHED_GEMM_FUNC(double);
#undef DEFINE_STRIDED_BATCHED_GEMM_FUNC
} // namespace math } // namespace math
} // namespace dragon } // namespace dragon
...@@ -85,8 +85,7 @@ DRAGON_API void Gemv( ...@@ -85,8 +85,7 @@ DRAGON_API void Gemv(
const T* x, const T* x,
const float beta, const float beta,
T* y, T* y,
Context* ctx, Context* ctx);
const string math_type = "float32");
template <typename T, class Context> template <typename T, class Context>
DRAGON_API void Gemm( DRAGON_API void Gemm(
...@@ -100,8 +99,40 @@ DRAGON_API void Gemm( ...@@ -100,8 +99,40 @@ DRAGON_API void Gemm(
const T* B, const T* B,
const float beta, const float beta,
T* C, T* C,
Context* ctx, Context* ctx);
const string math_type = "float32");
template <typename T, class Context>
DRAGON_API void GemmBatched(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const float alpha,
const T** A,
const T** B,
const float beta,
T** C,
Context* ctx);
template <typename T, class Context>
DRAGON_API void GemmStridedBatched(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int batch_size,
const int M,
const int N,
const int K,
const int A_stride,
const int B_stride,
const int C_stride,
const float alpha,
const T* A,
const T* B,
const float beta,
T* C,
Context* ctx);
} // namespace math } // namespace math
......
...@@ -24,21 +24,21 @@ void _Cast(const int n, const InputT* x, OutputT* y) { ...@@ -24,21 +24,21 @@ void _Cast(const int n, const InputT* x, OutputT* y) {
#define DEFINE_CAST_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_CAST_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void Cast<InputT, OutputT, CPUContext>( \ DRAGON_API void Cast<InputT, OutputT, CPUContext>( \
const int n, const InputT* x, OutputT* y, CPUContext* ctx) { \ const int n, const InputT* x, OutputT* y, CPUContext* ctx) { \
_Cast(n, x, y); \ _Cast(n, x, y); \
} }
#define DEFINE_UNSUPPORTED_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_UNSUPPORTED_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void Cast<InputT, OutputT, CPUContext>( \ DRAGON_API void Cast<InputT, OutputT, CPUContext>( \
const int n, const InputT* x, OutputT* y, CPUContext* ctx) { \ const int n, const InputT* x, OutputT* y, CPUContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \ LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<InputT>()) << " -> " \ << types::to_string(TypeMeta::Make<InputT>()) << " -> " \
<< types::to_string(TypeMeta::Make<OutputT>()); \ << types::to_string(TypeMeta::Make<OutputT>()); \
} \ } \
template <> \ template <> \
void Cast<OutputT, InputT, CPUContext>( \ DRAGON_API void Cast<OutputT, InputT, CPUContext>( \
const int n, const OutputT* x, InputT* y, CPUContext* ctx) { \ const int n, const OutputT* x, InputT* y, CPUContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \ LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<OutputT>()) << " -> " \ << types::to_string(TypeMeta::Make<OutputT>()) << " -> " \
......
...@@ -23,7 +23,7 @@ __global__ void _Cast(const int nthreads, const InputT* x, OutputT* y) { ...@@ -23,7 +23,7 @@ __global__ void _Cast(const int nthreads, const InputT* x, OutputT* y) {
#define DEFINE_CAST_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_CAST_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void Cast<InputT, OutputT, CUDAContext>( \ DRAGON_API void Cast<InputT, OutputT, CUDAContext>( \
const int n, const InputT* x, OutputT* y, CUDAContext* ctx) { \ const int n, const InputT* x, OutputT* y, CUDAContext* ctx) { \
_Cast<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _Cast<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, \ n, \
...@@ -33,14 +33,14 @@ __global__ void _Cast(const int nthreads, const InputT* x, OutputT* y) { ...@@ -33,14 +33,14 @@ __global__ void _Cast(const int nthreads, const InputT* x, OutputT* y) {
#define DEFINE_UNSUPPORTED_KERNEL_LAUNCHER(InputT, OutputT) \ #define DEFINE_UNSUPPORTED_KERNEL_LAUNCHER(InputT, OutputT) \
template <> \ template <> \
void Cast<InputT, OutputT, CUDAContext>( \ DRAGON_API void Cast<InputT, OutputT, CUDAContext>( \
const int n, const InputT* x, OutputT* y, CUDAContext* ctx) { \ const int n, const InputT* x, OutputT* y, CUDAContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \ LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<InputT>()) << " -> " \ << types::to_string(TypeMeta::Make<InputT>()) << " -> " \
<< types::to_string(TypeMeta::Make<OutputT>()); \ << types::to_string(TypeMeta::Make<OutputT>()); \
} \ } \
template <> \ template <> \
void Cast<OutputT, InputT, CUDAContext>( \ DRAGON_API void Cast<OutputT, InputT, CUDAContext>( \
const int n, const OutputT* x, InputT* y, CUDAContext* ctx) { \ const int n, const OutputT* x, InputT* y, CUDAContext* ctx) { \
LOG(FATAL) << "Unsupported conversion: " \ LOG(FATAL) << "Unsupported conversion: " \
<< types::to_string(TypeMeta::Make<OutputT>()) << " -> " \ << types::to_string(TypeMeta::Make<OutputT>()) << " -> " \
......
...@@ -219,7 +219,7 @@ DEFINE_REDUCE_FUNC(Sum); ...@@ -219,7 +219,7 @@ DEFINE_REDUCE_FUNC(Sum);
#define DEFINE_KERNEL_LAUNCHER(name) \ #define DEFINE_KERNEL_LAUNCHER(name) \
template <> \ template <> \
void Reduce##name<float16, CPUContext>( \ DRAGON_API void Reduce##name<float16, CPUContext>( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
...@@ -258,7 +258,7 @@ DRAGON_API float16 Sum<float16, CPUContext>( ...@@ -258,7 +258,7 @@ DRAGON_API float16 Sum<float16, CPUContext>(
#define DEFINE_KERNEL_LAUNCHER(name, T) \ #define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \ template <> \
void Reduce##name<T, CPUContext>( \ DRAGON_API void Reduce##name<T, CPUContext>( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
...@@ -298,7 +298,7 @@ DEFINE_KERNEL_LAUNCHER(Sum, double); ...@@ -298,7 +298,7 @@ DEFINE_KERNEL_LAUNCHER(Sum, double);
*y = val * T(scale); \ *y = val * T(scale); \
} \ } \
template <> \ template <> \
T Sum<T, CPUContext>( \ DRAGON_API T Sum<T, CPUContext>( \
const int n, const float scale, const T* x, CPUContext* ctx) { \ const int n, const float scale, const T* x, CPUContext* ctx) { \
T val = ConstEigenVectorArrayMap<T>(x, n).sum(); \ T val = ConstEigenVectorArrayMap<T>(x, n).sum(); \
return val * T(scale); \ return val * T(scale); \
......
...@@ -174,7 +174,7 @@ DEFINE_REDUCE_DISPATCHER(Sum); ...@@ -174,7 +174,7 @@ DEFINE_REDUCE_DISPATCHER(Sum);
// We found that FP16 accumulator drops too many small values in // We found that FP16 accumulator drops too many small values in
// empirical experiments. // empirical experiments.
template <> template <>
void ReduceSum<float16, CUDAContext>( DRAGON_API void ReduceSum<float16, CUDAContext>(
const int num_dims, const int num_dims,
const int* dims, const int* dims,
const int num_axes, const int num_axes,
...@@ -199,7 +199,7 @@ void ReduceSum<float16, CUDAContext>( ...@@ -199,7 +199,7 @@ void ReduceSum<float16, CUDAContext>(
#define DEFINE_KERNEL_LAUNCHER(name, T, AccT, Reducer, kInit) \ #define DEFINE_KERNEL_LAUNCHER(name, T, AccT, Reducer, kInit) \
template <> \ template <> \
void Reduce##name<T, CUDAContext>( \ DRAGON_API void Reduce##name<T, CUDAContext>( \
const int num_dims, \ const int num_dims, \
const int* dims, \ const int* dims, \
const int num_axes, \ const int num_axes, \
......
...@@ -174,7 +174,8 @@ void ComputeBinaryBroadcastDims( ...@@ -174,7 +174,8 @@ void ComputeBinaryBroadcastDims(
const vec64_t& A_dims, const vec64_t& A_dims,
const vec64_t& B_dims, const vec64_t& B_dims,
vec64_t& A_broadcast_dims, vec64_t& A_broadcast_dims,
vec64_t& B_broadcast_dims) { vec64_t& B_broadcast_dims,
int64_t* C_broadcast_dims) {
auto num_dims = std::max(A_dims.size(), B_dims.size()); auto num_dims = std::max(A_dims.size(), B_dims.size());
A_broadcast_dims.resize(num_dims); A_broadcast_dims.resize(num_dims);
B_broadcast_dims.resize(num_dims); B_broadcast_dims.resize(num_dims);
...@@ -194,6 +195,16 @@ void ComputeBinaryBroadcastDims( ...@@ -194,6 +195,16 @@ void ComputeBinaryBroadcastDims(
B_dims.begin(), B_dims.begin(),
B_dims.end(), B_dims.end(),
B_broadcast_dims.begin() + num_dims - B_dims.size()); B_broadcast_dims.begin() + num_dims - B_dims.size());
if (C_broadcast_dims != nullptr) {
for (int i = 0; i < num_dims; ++i) {
if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) {
C_broadcast_dims[i] = 0;
} else {
C_broadcast_dims[i] =
std::max(A_broadcast_dims[i], B_broadcast_dims[i]);
}
}
}
} }
void ComputeBinaryBroadcastStrides( void ComputeBinaryBroadcastStrides(
......
...@@ -304,7 +304,8 @@ DRAGON_API void ComputeBinaryBroadcastDims( ...@@ -304,7 +304,8 @@ DRAGON_API void ComputeBinaryBroadcastDims(
const vec64_t& A_dims, const vec64_t& A_dims,
const vec64_t& B_dims, const vec64_t& B_dims,
vec64_t& A_broadcast_dims, vec64_t& A_broadcast_dims,
vec64_t& B_broadcast_dims); vec64_t& B_broadcast_dims,
int64_t* C_broadcast_dims = nullptr);
DRAGON_API void ComputeBinaryBroadcastStrides( DRAGON_API void ComputeBinaryBroadcastStrides(
const vec64_t& A_dims, const vec64_t& A_dims,
...@@ -326,22 +327,22 @@ DRAGON_API void TransposeAxesForReduce( ...@@ -326,22 +327,22 @@ DRAGON_API void TransposeAxesForReduce(
const int* reduce_axes, const int* reduce_axes,
int* transpose_axes); int* transpose_axes);
template <typename dim_t, typename stride_t> template <typename DimT, typename StrideT>
inline void inline void
ComputeStrides(const int num_dims, const dim_t* dims, stride_t* strides) { ComputeStrides(const int num_dims, const DimT* dims, StrideT* strides) {
int64_t cur_stride = 1; int64_t cur_stride = 1;
for (int i = num_dims - 1; i >= 0; --i) { for (int i = num_dims - 1; i >= 0; --i) {
strides[i] = stride_t(cur_stride); strides[i] = StrideT(cur_stride);
cur_stride *= int64_t(dims[i]); cur_stride *= int64_t(dims[i]);
} }
} }
template <typename dim_t, typename axis_t, typename stride_t> template <typename DimT, typename AxisT, typename StrideT>
inline void ComputeTransposeStrides( inline void ComputeTransposeStrides(
const int num_dims, const int num_dims,
const dim_t* dims, const DimT* dims,
const axis_t* axes, const AxisT* axes,
stride_t* strides) { StrideT* strides) {
vec64_t buf(num_dims); vec64_t buf(num_dims);
int64_t cur_stride = 1; int64_t cur_stride = 1;
for (int i = num_dims - 1; i >= 0; --i) { for (int i = num_dims - 1; i >= 0; --i) {
...@@ -349,13 +350,25 @@ inline void ComputeTransposeStrides( ...@@ -349,13 +350,25 @@ inline void ComputeTransposeStrides(
cur_stride *= int64_t(dims[i]); cur_stride *= int64_t(dims[i]);
} }
for (int i = 0; i < num_dims; ++i) { for (int i = 0; i < num_dims; ++i) {
strides[i] = stride_t(buf[axes[i]]); strides[i] = StrideT(buf[axes[i]]);
} }
} }
// Return the flattened offset of the multi-dimensional ``index``
// with respect to ``dims``, visiting the dimensions from first to last.
// Dimensions whose size is not greater than 1 are skipped entirely,
// so they contribute neither a factor nor an offset to the result.
template <typename dim_t, typename index_t> template <typename DimT, typename IndexT>
inline IndexT
GetIndexFromDims(const int num_dims, const DimT* dims, IndexT* index) {
IndexT ret = 0;
for (int i = 0; i < num_dims; ++i) {
if (dims[i] > 1) {
// Scale the accumulated offset by this dimension, then add its coordinate.
ret = ret * dims[i] + index[i];
}
}
return ret;
}
template <typename DimT, typename IndexT>
inline void inline void
IncreaseIndexInDims(const int num_dims, const dim_t* dims, index_t* index) { IncreaseIndexInDims(const int num_dims, const DimT* dims, IndexT* index) {
for (int i = num_dims - 1; i >= 0; --i) { for (int i = num_dims - 1; i >= 0; --i) {
++index[i]; ++index[i];
if (index[i] >= dims[i]) { if (index[i] >= dims[i]) {
......
...@@ -116,11 +116,9 @@ class Dense(Layer): ...@@ -116,11 +116,9 @@ class Dense(Layer):
self.built = True self.built = True
def call(self, inputs): def call(self, inputs):
outputs = math_ops.fully_connected( outputs = math_ops.gemm(
[inputs, self.kernel] + [self.bias] [inputs, self.kernel] +
if self.use_bias else [], ([self.bias] if self.use_bias else []),
axis=-1,
transW=False,
) )
if self.activation is not None: if self.activation is not None:
return self.activation(outputs) return self.activation(outputs)
......
...@@ -703,38 +703,38 @@ def log(x, name=None): ...@@ -703,38 +703,38 @@ def log(x, name=None):
return math_ops.log(x, name=name) return math_ops.log(x, name=name)
def matmul( def matmul(a, b, name=None):
a,
b,
transpose_a=False,
transpose_b=False,
name=None,
):
r"""Compute the matrix multiplication. r"""Compute the matrix multiplication.
.. math:: y = a \times b .. math:: \text{out} = a \times b
The rank of ``a`` and ``b`` should be equal and >= 2: The behavior depends on the shape of input tensors:
```python * If both tensors are 1d, computes the vector product.
# Ok, a typical matrix multiplication * If tensors are 1d and >=2d, computes the vector-matrix multiplication.
a = tf.ones((2, 3), 'float32') * If tensors are >=2d and 1d, computes the matrix-vector multiplication.
b = tf.ones((3, 3), 'float32') * If both tensors are >= 2d, computes the matrix-matrix multiplication.
print(tf.linalg.matmul(a, b)) * If one tensor is >= 3d, applies batching and broadcasting to the computation.
# Compute a batch matrix multiplication if rank > 2 Examples:
aa = tf.ones((4, 2, 3), 'float32')
bb = tf.ones((4, 3, 3), 'float32')
print(tf.linalg.matmul(aa, bb))
```
If inputs are transposed, remember to transpose them back:
```python ```python
# Vector x Vector
a = tf.ones((2,), 'float32')
b = tf.ones((2,), 'float32')
print(tf.linalg.matmul(a, b))
# Vector x Matrix
a = tf.ones((2,), 'float32')
b = tf.ones((2, 3), 'float32')
print(tf.linalg.matmul(a, b))
# Matrix x Vector
a = tf.ones((3, 2), 'float32') a = tf.ones((3, 2), 'float32')
b = tf.ones((3, 3), 'float32') b = tf.ones((2,), 'float32')
print(tf.linalg.matmul(a, b)) # ``a`` takes the wrong dimensions print(tf.linalg.matmul(a, b))
print(tf.linalg.matmul(a, b, transpose_a=True)) # Ok # Matrix x Matrix
a = tf.ones((2, 3), 'float32')
b = tf.ones((3, 2), 'float32')
print(tf.linalg.matmul(a, b))
``` ```
Parameters Parameters
...@@ -743,10 +743,6 @@ def matmul( ...@@ -743,10 +743,6 @@ def matmul(
The matrix :math:`a`. The matrix :math:`a`.
b : dragon.Tensor b : dragon.Tensor
The matrix :math:`b`. The matrix :math:`b`.
transpose_a : bool, optional, default=False
**True** to transpose :math:`a` before computing.
transpose_b : bool, optional, default=False
**True** to transpose :math:`b` before computing.
name : str, optional name : str, optional
The operation name. The operation name.
...@@ -756,12 +752,7 @@ def matmul( ...@@ -756,12 +752,7 @@ def matmul(
The output tensor. The output tensor.
""" """
return math_ops.matmul( return math_ops.matmul([a, b], name=name)
[a, b],
transpose_a=transpose_a,
transpose_b=transpose_b,
name=name,
)
def multiply(x, y, name=None): def multiply(x, y, name=None):
......
...@@ -85,25 +85,25 @@ class Dense(layer.Layer): ...@@ -85,25 +85,25 @@ class Dense(layer.Layer):
raise AssertionError('The input dimension must be rank 2.' raise AssertionError('The input dimension must be rank 2.'
'Please reshape or flatten it.') 'Please reshape or flatten it.')
if self.in_channels: if self.in_channels:
shape = [self.n_units, self.in_channels] shape = [self.in_channels, self.n_units]
else: else:
self.in_channels = inputs_shape[1] self.in_channels = inputs_shape[1]
shape = [self.n_units, inputs_shape[1]] shape = [inputs_shape[1], self.n_units]
self.W = self.add_weight( self.W = self.add_weight(
name="weights", name='weights',
shape=shape, shape=shape,
init=self.W_init, init=self.W_init,
) )
if self.b_init: if self.b_init:
self.b = self.add_weight( self.b = self.add_weight(
name="biases", name='biases',
shape=[self.n_units], shape=[self.n_units],
init=self.b_init, init=self.b_init,
) )
def forward(self, inputs): def forward(self, inputs):
outputs = math_ops.fully_connected( outputs = math_ops.gemm(
[inputs, self.W] + ([self.b] if self.b_init else []), axis=1) [inputs, self.W] + ([self.b] if self.b_init else []))
if self.act: if self.act:
outputs = self.act(outputs) outputs = self.act(outputs)
return outputs return outputs
...@@ -281,17 +281,15 @@ class TestOpSpec(unittest.TestCase): ...@@ -281,17 +281,15 @@ class TestOpSpec(unittest.TestCase):
self.assertEqual(dragon.flatten( self.assertEqual(dragon.flatten(
self.sym4, axis=1, num_axes=-1).shape, (1, None)) self.sym4, axis=1, num_axes=-1).shape, (1, None))
def test_fully_connected(self): def test_gemm(self):
w = dragon.Tensor((3, 2)) w = dragon.Tensor((3, 2))
with dragon.graph_mode(): with dragon.graph_mode():
self.assertEqual(dragon.nn.fully_connected( self.assertEqual(dragon.math.gemm(
[self.sym1, w]).shape, (None, 3)) [self.sym1, w]).shape, None)
self.assertEqual(dragon.nn.fully_connected( self.assertEqual(dragon.math.gemm(
[self.sym1, w], transpose_w=False).shape, (None, 2)) [self.sym1, w], axis=1).shape, (None, 2))
self.assertEqual(dragon.nn.fully_connected( self.assertEqual(dragon.math.gemm(
[self.sym1, w], axis=-1).shape, None) [self.sym1, self.sym1]).shape, None)
self.assertEqual(dragon.nn.fully_connected(
[self.sym1, self.sym1]).shape, (None, None))
def test_index_select(self): def test_index_select(self):
with dragon.graph_mode(): with dragon.graph_mode():
...@@ -325,7 +323,9 @@ class TestOpSpec(unittest.TestCase): ...@@ -325,7 +323,9 @@ class TestOpSpec(unittest.TestCase):
self.assertEqual(dragon.math.matmul( self.assertEqual(dragon.math.matmul(
[self.sym1, self.sym3]).shape, None) [self.sym1, self.sym3]).shape, None)
self.assertEqual(dragon.math.matmul( self.assertEqual(dragon.math.matmul(
[self.sym2, self.sym3]).shape, None) [self.sym2, self.sym3]).shape, (None,))
self.assertEqual(dragon.math.matmul(
[self.sym3, self.sym2]).shape, (1,))
self.assertEqual(dragon.math.matmul( self.assertEqual(dragon.math.matmul(
[self.sym3, self.sym3]).shape, (1, None)) [self.sym3, self.sym3]).shape, (1, None))
self.assertEqual(dragon.math.matmul( self.assertEqual(dragon.math.matmul(
......
...@@ -1868,22 +1868,22 @@ class TestMathOps(OpTestCase): ...@@ -1868,22 +1868,22 @@ class TestMathOps(OpTestCase):
with dragon.device('cuda'): with dragon.device('cuda'):
self.test_floor() self.test_floor()
def test_fully_connected(self): def test_gemm(self):
entries = [((2, 3), (3, 4), (4,), False), entries = [((2, 3), (3, 4), (4,), False),
((2, 3), (4, 3), (4,), True)] ((2, 3), (4, 3), (4,), True)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution): with execution_context().mode(execution):
for x_shape, w_shape, b_shape, trans_w in entries: for x_shape, w_shape, b_shape, trans_b in entries:
data1, data2, data3 = arange(x_shape), arange(w_shape), arange(b_shape) data1, data2, data3 = arange(x_shape), arange(w_shape), arange(b_shape)
x, w, b = new_tensor(data1), new_tensor(data2), new_tensor(data3) x, w, b = new_tensor(data1), new_tensor(data2), new_tensor(data3)
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch([x, w, b]) tape.watch([x, w, b])
y = dragon.nn.fully_connected([x, w, b], transpose_w=trans_w) y = dragon.math.gemm([x, w, b], transpose_b=trans_b)
data4 = arange(y.shape) data4 = arange(y.shape)
dy = new_tensor(data4) dy = new_tensor(data4)
dx, dw, db = tape.gradient(y, [x, w, b], output_gradients=[dy]) dx, dw, db = tape.gradient(y, [x, w, b], output_gradients=[dy])
result = np.matmul(data1, data2.T if trans_w else data2) + data3 result = np.matmul(data1, data2.T if trans_b else data2) + data3
if trans_w: if trans_b:
grad1 = np.matmul(data4, data2) grad1 = np.matmul(data4, data2)
grad2 = np.matmul(data4.T, data1) grad2 = np.matmul(data4.T, data1)
else: else:
...@@ -1894,9 +1894,9 @@ class TestMathOps(OpTestCase): ...@@ -1894,9 +1894,9 @@ class TestMathOps(OpTestCase):
[result, grad1, grad2, reduce_like(data4, data3)]) [result, grad1, grad2, reduce_like(data4, data3)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable') @unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_fully_connected_cuda(self): def test_gemm_cuda(self):
with dragon.device('cuda'): with dragon.device('cuda'):
self.test_fully_connected() self.test_gemm()
def test_greater(self): def test_greater(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
...@@ -1997,40 +1997,62 @@ class TestMathOps(OpTestCase): ...@@ -1997,40 +1997,62 @@ class TestMathOps(OpTestCase):
self.test_log() self.test_log()
def test_matmul(self): def test_matmul(self):
entries = [ entries = [((2, 3), (3, 4)),
((2, 3), (3, 4), False, False), ((1, 2, 3), (2, 3, 4)),
((2, 3), (4, 3), False, True), ((2, 2, 3), (1, 3, 4)),
((3, 2), (3, 4), True, False), ((2, 2, 3), (2, 3, 4)),
((3, 2), (4, 3), True, True)] ((2, 1, 2, 3), (2, 3, 4)),
((1, 2, 3), (2, 2, 3, 4)),
((2, 1, 2, 3), (1, 2, 3, 4))]
for execution in ('EAGER_MODE', 'GRAPH_MODE',):
with execution_context().mode(execution):
for a_shape, b_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape)
a, b = new_tensor(data1), new_tensor(data2)
with dragon.GradientTape() as tape:
tape.watch([a, b])
y = dragon.math.matmul([a, b])
data3 = arange(y.shape)
dy = new_tensor(data3)
da, db = tape.gradient(y, [a, b], output_gradients=[dy])
grad1 = np.matmul(data3, transpose_last(data2, 2))
grad2 = np.matmul(transpose_last(data1, 2), data3)
self.assertEqual(
[y, da, db],
[np.matmul(data1, data2),
reduce_like(grad1, data1),
reduce_like(grad2, data2)])
entries = [((2,), (2,), (2, 1), (2, 1), (1, 1)),
((2,), (2, 3), (2, 1), (2, 3), (1, 3)),
((2, 3), (3,), (2, 3), (1, 3), (2, 1)),
((2,), (4, 2, 3), (1, 2, 1), (4, 2, 3), (4, 1, 3)),
((4, 2, 3), (3,), (4, 2, 3), (1, 1, 3), (4, 2, 1))]
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution): with execution_context().mode(execution):
for a_shape, b_shape, trans_a, trans_b in entries: for a_shape, b_shape, da_shape, db_shape, dy_shape in entries:
data1, data2 = arange(a_shape), arange(b_shape) data1, data2 = arange(a_shape), arange(b_shape)
data4 = data1 if len(a_shape) > len(b_shape) else data2
a, b = new_tensor(data1), new_tensor(data2) a, b = new_tensor(data1), new_tensor(data2)
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch([a, b]) tape.watch([a, b])
y = dragon.math.matmul([a, b], trans_a, trans_b) y = dragon.math.matmul([a, b])
data3 = arange(y.shape) data3 = arange(y.shape)
dy = new_tensor(data3) dy = new_tensor(data3)
da, db = tape.gradient(y, [a, b], output_gradients=[dy]) da, db = tape.gradient(y, [a, b], output_gradients=[dy])
if trans_a: grad1 = data3.reshape(dy_shape) * data2.reshape(db_shape)
if trans_b: grad2 = data1.reshape(da_shape) * data3.reshape(dy_shape)
grad1 = np.matmul(data2.T, data3.T) grad1_axes, grad2_axes = [], []
grad2 = np.matmul(data3.T, data1.T) for i in range(len(dy_shape)):
else: if da_shape[i] != db_shape[i]:
grad1 = np.matmul(data2, data3.T) if da_shape[i] == 1:
grad2 = np.matmul(data1, data3) grad1_axes.append(i)
else: if db_shape[i] == 1:
if trans_b: grad2_axes.append(i)
grad1 = np.matmul(data3, data2)
grad2 = np.matmul(data3.T, data1)
else:
grad1 = np.matmul(data3, data2.T)
grad2 = np.matmul(data1.T, data3)
self.assertEqual( self.assertEqual(
[y, da, db], [y, da, db],
[np.matmul(data1.T if trans_a else data1, [np.matmul(data1, data2),
data2.T if trans_b else data2), grad1, grad2]) reduce(grad1, tuple(grad1_axes)).reshape(data1.shape),
reduce(grad2, tuple(grad2_axes)).reshape(data2.shape)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable') @unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_matmul_cuda(self): def test_matmul_cuda(self):
...@@ -4145,6 +4167,16 @@ def reduce_like(data, other, reduction='sum'): ...@@ -4145,6 +4167,16 @@ def reduce_like(data, other, reduction='sum'):
return data return data
def transpose_last(data, num_axes=None, axes=None):
    """Transpose the last axes of data.

    Parameters
    ----------
    data : numpy.ndarray
        The array to transpose.
    num_axes : int, optional
        The number of trailing axes to reverse.
        Ignored if ``axes`` is specified.
    axes : Sequence[int], optional
        The permutation of the trailing ``len(axes)`` axes,
        where ``0`` refers to the first of those trailing axes.

    Returns
    -------
    numpy.ndarray
        The array with the trailing axes permuted.

    Raises
    ------
    TypeError
        If neither ``num_axes`` nor ``axes`` is specified.

    """
    if axes is None:
        if num_axes is None:
            # Fail with a clear message instead of an opaque ``len(None)`` error.
            raise TypeError('Either <num_axes> or <axes> should be specified.')
        # Default permutation: reverse the order of the trailing axes.
        axes = list(range(num_axes))[::-1]
    perm = list(range(len(data.shape)))
    # Leading axes keep their position; only the tail is rewritten.
    start_axis = len(perm) - len(axes)
    perm[start_axis:] = [v + start_axis for v in axes]
    return np.transpose(data, perm)
def uniform(shape, dtype='float32'): def uniform(shape, dtype='float32'):
"""Return the uniform data with given shape.""" """Return the uniform data with given shape."""
return np.random.uniform(-1., 1., size=shape).astype(dtype) return np.random.uniform(-1., 1., size=shape).astype(dtype)
......
...@@ -619,39 +619,54 @@ class TestModules(OpTestCase): ...@@ -619,39 +619,54 @@ class TestModules(OpTestCase):
self.assertEqual(m4(x), np.pad(data, pads, 'constant')) self.assertEqual(m4(x), np.pad(data, pads, 'constant'))
def test_pool1d(self): def test_pool1d(self):
entries = [((2, 2, 2,), (2,), 2, 1, 'MAX'), entries = [((2, 2, 2,), (2,), 2, 1, 'MaxPool1d'),
((2, 2, 2,), (2,), 2, 1, 'AVG')] ((2, 2, 2,), (2,), 2, 1, 'AvgPool1d'),
((2, 2, 2,), (1,), 1, 0, 'AdaptiveMaxPool1d'),
((2, 2, 2,), (1,), 1, 0, 'AdaptiveAvgPool1d')]
for x_shape, kernel_shape, strides, pads, mode in entries: for x_shape, kernel_shape, strides, pads, mode in entries:
data = arange(x_shape) * .1 data = arange(x_shape) * .1
module_cls = torch.nn.AvgPool1d if mode == 'AVG' else torch.nn.MaxPool1d module_cls = getattr(torch.nn, mode)
x = new_tensor(data) x = new_tensor(data)
if 'Adaptive' in mode:
m = module_cls(x_shape[-1])
else:
m = module_cls(kernel_shape, strides, pads) m = module_cls(kernel_shape, strides, pads)
y, _ = m(x), repr(m) y, _ = m(x), repr(m)
result = data / (np.prod(kernel_shape) if mode == 'AVG' else 1.) result = data / (np.prod(kernel_shape) if 'Avg' in mode else 1.)
self.assertEqual(y, result) self.assertEqual(y, result)
def test_pool2d(self): def test_pool2d(self):
entries = [((2, 2, 2, 2), (2, 2), 2, 1, 'MAX'), entries = [((2, 2, 2, 2), (2, 2), 2, 1, 'MaxPool2d'),
((2, 2, 2, 2), (2, 2), 2, 1, 'AVG')] ((2, 2, 2, 2), (2, 2), 2, 1, 'AvgPool2d'),
((2, 2, 2, 2), (1, 1), 1, 0, 'AdaptiveMaxPool2d'),
((2, 2, 2, 2), (1, 1), 1, 0, 'AdaptiveAvgPool2d')]
for x_shape, kernel_shape, strides, pads, mode in entries: for x_shape, kernel_shape, strides, pads, mode in entries:
data = arange(x_shape) * .1 data = arange(x_shape) * .1
module_cls = torch.nn.AvgPool2d if mode == 'AVG' else torch.nn.MaxPool2d module_cls = getattr(torch.nn, mode)
x = new_tensor(data) x = new_tensor(data)
if 'Adaptive' in mode:
m = module_cls(x_shape[-1])
else:
m = module_cls(kernel_shape, strides, pads) m = module_cls(kernel_shape, strides, pads)
y, _ = m(x), repr(m) y, _ = m(x), repr(m)
result = data / (np.prod(kernel_shape) if mode == 'AVG' else 1.) result = data / (np.prod(kernel_shape) if 'Avg' in mode else 1.)
self.assertEqual(y, result) self.assertEqual(y, result)
def test_pool3d(self): def test_pool3d(self):
entries = [((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'MAX'), entries = [((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'MaxPool3d'),
((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'AVG')] ((2, 2, 2, 2, 2), (2, 2, 2), 2, 1, 'AvgPool3d'),
((2, 2, 2, 2, 2), (1, 1, 1), 1, 0, 'AdaptiveMaxPool3d'),
((2, 2, 2, 2, 2), (1, 1, 1), 1, 0, 'AdaptiveAvgPool3d')]
for x_shape, kernel_shape, strides, pads, mode in entries: for x_shape, kernel_shape, strides, pads, mode in entries:
data = arange(x_shape) * .1 data = arange(x_shape) * .1
module_cls = torch.nn.AvgPool3d if mode == 'AVG' else torch.nn.MaxPool3d module_cls = getattr(torch.nn, mode)
x = new_tensor(data) x = new_tensor(data)
if 'Adaptive' in mode:
m = module_cls(x_shape[-1])
else:
m = module_cls(kernel_shape, strides, pads) m = module_cls(kernel_shape, strides, pads)
y, _ = m(x), repr(m) y, _ = m(x), repr(m)
result = data / (np.prod(kernel_shape) if mode == 'AVG' else 1.) result = data / (np.prod(kernel_shape) if 'Avg' in mode else 1.)
self.assertEqual(y, result) self.assertEqual(y, result)
def test_prelu(self): def test_prelu(self):
......
...@@ -95,6 +95,16 @@ class TestTensorOps(OpTestCase): ...@@ -95,6 +95,16 @@ class TestTensorOps(OpTestCase):
a += b a += b
self.assertEqual(a, data1 + data2) self.assertEqual(a, data1 + data2)
def test_addmm(self):
    """Check ``addmm`` against ``np.matmul`` plus the added matrix."""
    for a_shape, b_shape, c_shape in [((2, 3), (3, 4), (2, 4))]:
        lhs_data, rhs_data = arange(a_shape), arange(b_shape)
        bias_data = arange(c_shape)
        lhs, rhs = new_tensor(lhs_data), new_tensor(rhs_data)
        bias = new_tensor(bias_data)
        # ``bias`` is the tensor that the product is added to.
        result = bias.addmm(lhs, rhs)
        self.assertEqual(result, np.matmul(lhs_data, rhs_data) + bias_data)
def test_argmax(self): def test_argmax(self):
entries = [(0, True), (0, False), (1, True), (1, False), (None, False)] entries = [(0, True), (0, False), (1, True), (1, False), (None, False)]
for axis, keepdims in entries: for axis, keepdims in entries:
...@@ -115,6 +125,18 @@ class TestTensorOps(OpTestCase): ...@@ -115,6 +125,18 @@ class TestTensorOps(OpTestCase):
result = np.expand_dims(result, axis) result = np.expand_dims(result, axis)
self.assertEqual(x.argmin(axis, keepdims), result) self.assertEqual(x.argmin(axis, keepdims), result)
def test_baddbmm(self):
    """Check ``baddbmm`` and ``baddbmm_`` against ``np.matmul`` plus the added batch."""
    for a_shape, b_shape, c_shape in [((2, 2, 3), (2, 3, 4), (2, 2, 4))]:
        lhs_data, rhs_data = arange(a_shape), arange(b_shape)
        bias_data = arange(c_shape)
        lhs, rhs = new_tensor(lhs_data), new_tensor(rhs_data)
        bias = new_tensor(bias_data)
        # Out-of-place variant: the result is returned.
        result = bias.baddbmm(lhs, rhs)
        self.assertEqual(result, np.matmul(lhs_data, rhs_data) + bias_data)
        # Underscore variant: ``bias`` itself is compared afterwards.
        bias.baddbmm_(lhs, rhs)
        self.assertEqual(bias, np.matmul(lhs_data, rhs_data) + bias_data)
def test_bitwise_not(self): def test_bitwise_not(self):
for shape in self.unary_test_shapes: for shape in self.unary_test_shapes:
data = np.random.binomial(1, 0.5, shape).astype('bool') data = np.random.binomial(1, 0.5, shape).astype('bool')
...@@ -132,6 +154,18 @@ class TestTensorOps(OpTestCase): ...@@ -132,6 +154,18 @@ class TestTensorOps(OpTestCase):
a.bitwise_xor_(b) a.bitwise_xor_(b)
self.assertEqual(a, np.bitwise_xor(data1, data2)) self.assertEqual(a, np.bitwise_xor(data1, data2))
def test_bmm(self):
    """Check ``bmm`` against ``np.matmul`` over batched shapes."""
    shape_pairs = (
        ((1, 2, 3), (2, 3, 4)),
        ((2, 2, 3), (1, 3, 4)),
        ((2, 2, 3), (2, 3, 4)),
        ((2, 1, 2, 3), (2, 3, 4)),
        ((1, 2, 3), (2, 2, 3, 4)),
        ((2, 1, 2, 3), (1, 2, 3, 4)),
    )
    for lhs_shape, rhs_shape in shape_pairs:
        lhs_data = arange(lhs_shape)
        rhs_data = arange(rhs_shape, 1)
        lhs = new_tensor(lhs_data, False)
        rhs = new_tensor(rhs_data, False)
        product = lhs.bmm(rhs)
        self.assertEqual(product, np.matmul(lhs_data, rhs_data))
def test_ceil(self): def test_ceil(self):
data = np.array([1.4, 1.7, 2.0]) data = np.array([1.4, 1.7, 2.0])
x = new_tensor(data) x = new_tensor(data)
...@@ -334,6 +368,24 @@ class TestTensorOps(OpTestCase): ...@@ -334,6 +368,24 @@ class TestTensorOps(OpTestCase):
data[data > 2] = 0 data[data > 2] = 0
self.assertEqual(x, data) self.assertEqual(x, data)
def test_matmul(self):
    """Check ``matmul`` against ``np.matmul`` for vector, matrix and batched shapes."""
    shape_pairs = (
        # Vector and matrix operands.
        ((2,), (2,)),
        ((2,), (2, 3)),
        ((2, 3), (3,)),
        ((2, 3), (3, 4)),
        # Vector with a 3d operand.
        ((2,), (4, 2, 3)),
        ((4, 2, 3), (3,)),
        # Operands of rank 3 or higher.
        ((1, 2, 3), (2, 3, 4)),
        ((2, 2, 3), (1, 3, 4)),
        ((2, 2, 3), (2, 3, 4)),
        ((2, 1, 2, 3), (2, 3, 4)),
        ((1, 2, 3), (2, 2, 3, 4)),
        ((2, 1, 2, 3), (1, 2, 3, 4)),
    )
    for lhs_shape, rhs_shape in shape_pairs:
        lhs_data = arange(lhs_shape)
        rhs_data = arange(rhs_shape, 1)
        lhs = new_tensor(lhs_data, False)
        rhs = new_tensor(rhs_data, False)
        product = lhs.matmul(rhs)
        self.assertEqual(product, np.matmul(lhs_data, rhs_data))
def test_max(self): def test_max(self):
entries = [(0, True), (0, False), entries = [(0, True), (0, False),
(1, True), (1, False), (1, True), (1, False),
...@@ -382,20 +434,12 @@ class TestTensorOps(OpTestCase): ...@@ -382,20 +434,12 @@ class TestTensorOps(OpTestCase):
self.assertEqual(y, np.minimum(data1, data2)) self.assertEqual(y, np.minimum(data1, data2))
def test_mm(self): def test_mm(self):
entries = [ entries = [((2, 3), (3, 4))]
((2, 3), (3, 4), False, False), for a_shape, b_shape in entries:
((2, 3), (4, 3), False, True),
((3, 2), (3, 4), True, False),
((3, 2), (4, 3), True, True)]
for a_shape, b_shape, trans_a, trans_b in entries:
data1, data2 = arange(a_shape), arange(b_shape) data1, data2 = arange(a_shape), arange(b_shape)
a, b = new_tensor(data1), new_tensor(data2) a, b = new_tensor(data1), new_tensor(data2)
if trans_a or trans_b:
y = torch.mm(a, b, trans_a, trans_b)
else:
y = a.mm(b) y = a.mm(b)
self.assertEqual(y, np.matmul(data1.T if trans_a else data1, self.assertEqual(y, np.matmul(data1, data2))
data2.T if trans_b else data2))
def test_mul(self): def test_mul(self):
for a_shape, b_shape in self.binary_test_shapes: for a_shape, b_shape in self.binary_test_shapes:
......
...@@ -94,9 +94,12 @@ from dragon.vm.torch.core.ops.init.functional import zeros ...@@ -94,9 +94,12 @@ from dragon.vm.torch.core.ops.init.functional import zeros
from dragon.vm.torch.core.ops.init.functional import zeros_like from dragon.vm.torch.core.ops.init.functional import zeros_like
from dragon.vm.torch.core.ops.math.functional import abs from dragon.vm.torch.core.ops.math.functional import abs
from dragon.vm.torch.core.ops.math.functional import add from dragon.vm.torch.core.ops.math.functional import add
from dragon.vm.torch.core.ops.math.functional import addmm
from dragon.vm.torch.core.ops.math.functional import axpby from dragon.vm.torch.core.ops.math.functional import axpby
from dragon.vm.torch.core.ops.math.functional import baddbmm
from dragon.vm.torch.core.ops.math.functional import bitwise_not from dragon.vm.torch.core.ops.math.functional import bitwise_not
from dragon.vm.torch.core.ops.math.functional import bitwise_xor from dragon.vm.torch.core.ops.math.functional import bitwise_xor
from dragon.vm.torch.core.ops.math.functional import bmm
from dragon.vm.torch.core.ops.math.functional import ceil from dragon.vm.torch.core.ops.math.functional import ceil
from dragon.vm.torch.core.ops.math.functional import clamp from dragon.vm.torch.core.ops.math.functional import clamp
from dragon.vm.torch.core.ops.math.functional import cos from dragon.vm.torch.core.ops.math.functional import cos
...@@ -112,6 +115,7 @@ from dragon.vm.torch.core.ops.math.functional import le ...@@ -112,6 +115,7 @@ from dragon.vm.torch.core.ops.math.functional import le
from dragon.vm.torch.core.ops.math.functional import log from dragon.vm.torch.core.ops.math.functional import log
from dragon.vm.torch.core.ops.math.functional import logsumexp from dragon.vm.torch.core.ops.math.functional import logsumexp
from dragon.vm.torch.core.ops.math.functional import lt from dragon.vm.torch.core.ops.math.functional import lt
from dragon.vm.torch.core.ops.math.functional import matmul
from dragon.vm.torch.core.ops.math.functional import maximum from dragon.vm.torch.core.ops.math.functional import maximum
from dragon.vm.torch.core.ops.math.functional import minimum from dragon.vm.torch.core.ops.math.functional import minimum
from dragon.vm.torch.core.ops.math.functional import mm from dragon.vm.torch.core.ops.math.functional import mm
......
...@@ -76,6 +76,12 @@ from dragon.vm.torch.core.nn.modules.padding import ReplicationPad1d ...@@ -76,6 +76,12 @@ from dragon.vm.torch.core.nn.modules.padding import ReplicationPad1d
from dragon.vm.torch.core.nn.modules.padding import ReplicationPad2d from dragon.vm.torch.core.nn.modules.padding import ReplicationPad2d
from dragon.vm.torch.core.nn.modules.padding import ReplicationPad3d from dragon.vm.torch.core.nn.modules.padding import ReplicationPad3d
from dragon.vm.torch.core.nn.modules.padding import ZeroPad2d from dragon.vm.torch.core.nn.modules.padding import ZeroPad2d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool1d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool2d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveAvgPool3d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveMaxPool1d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveMaxPool2d
from dragon.vm.torch.core.nn.modules.pooling import AdaptiveMaxPool3d
from dragon.vm.torch.core.nn.modules.pooling import AvgPool1d from dragon.vm.torch.core.nn.modules.pooling import AvgPool1d
from dragon.vm.torch.core.nn.modules.pooling import AvgPool2d from dragon.vm.torch.core.nn.modules.pooling import AvgPool2d
from dragon.vm.torch.core.nn.modules.pooling import AvgPool3d from dragon.vm.torch.core.nn.modules.pooling import AvgPool3d
......
...@@ -14,6 +14,12 @@ from __future__ import absolute_import as _absolute_import ...@@ -14,6 +14,12 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division from __future__ import division as _division
from __future__ import print_function as _print_function from __future__ import print_function as _print_function
from dragon.vm.torch.core.nn.functional import adaptive_avg_pool1d
from dragon.vm.torch.core.nn.functional import adaptive_avg_pool2d
from dragon.vm.torch.core.nn.functional import adaptive_avg_pool3d
from dragon.vm.torch.core.nn.functional import adaptive_max_pool1d
from dragon.vm.torch.core.nn.functional import adaptive_max_pool2d
from dragon.vm.torch.core.nn.functional import adaptive_max_pool3d
from dragon.vm.torch.core.nn.functional import avg_pool1d from dragon.vm.torch.core.nn.functional import avg_pool1d
from dragon.vm.torch.core.nn.functional import avg_pool2d from dragon.vm.torch.core.nn.functional import avg_pool2d
from dragon.vm.torch.core.nn.functional import avg_pool3d from dragon.vm.torch.core.nn.functional import avg_pool3d
......
...@@ -76,7 +76,7 @@ class Function(object): ...@@ -76,7 +76,7 @@ class Function(object):
Parameters Parameters
---------- ----------
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
......
...@@ -86,7 +86,6 @@ class Pool(function.Function): ...@@ -86,7 +86,6 @@ class Pool(function.Function):
self.pads = kwargs.get('pads', 0) self.pads = kwargs.get('pads', 0)
self.ceil_mode = kwargs.get('ceil_mode', False) self.ceil_mode = kwargs.get('ceil_mode', False)
self.mode = kwargs.get('mode', 'MAX') self.mode = kwargs.get('mode', 'MAX')
self.global_pool = kwargs.get('global_pool', False)
def attributes(self): def attributes(self):
return { return {
...@@ -98,7 +97,6 @@ class Pool(function.Function): ...@@ -98,7 +97,6 @@ class Pool(function.Function):
'ceil_mode': self.ceil_mode, 'ceil_mode': self.ceil_mode,
'mode': self.mode, 'mode': self.mode,
'data_format': 'NCHW', 'data_format': 'NCHW',
'global_pool': self.global_pool,
} }
} }
...@@ -316,24 +314,6 @@ class L2Loss(Loss): ...@@ -316,24 +314,6 @@ class L2Loss(Loss):
} }
class Linear(function.Function):
"""Linear function."""
def attributes(self):
return {
'op_type': 'FullyConnected',
'arguments': {
'axis': -1,
'transW': True,
},
}
def forward(self, input, weight, bias=None, out=None):
inputs = [input, weight] + ([bias] if bias else [])
outputs = [out] if out else [self.alloc()]
return self.dispatch(inputs, outputs)
class LocalResponseNorm(function.Function): class LocalResponseNorm(function.Function):
"""LocalResponseNorm function.""" """LocalResponseNorm function."""
......
...@@ -48,7 +48,7 @@ class Identity(Module): ...@@ -48,7 +48,7 @@ class Identity(Module):
class Linear(Module): class Linear(Module):
r"""Apply the linear transformation. r"""Apply the linear transformation.
.. math:: y = Wx + b .. math:: \text{out} = \text{input} \times \text{weight}^{T} + \text{bias}
Examples: Examples:
......
...@@ -22,6 +22,18 @@ import itertools ...@@ -22,6 +22,18 @@ import itertools
from dragon.core.util import six from dragon.core.util import six
def _get_adaptive_pool_kwargs(input_sizes, output_sizes):
stride, kernel_size = [], []
for input_size, output_size in zip(input_sizes, output_sizes):
if output_size == 1:
stride.append(1)
kernel_size.append(input_size)
else:
stride.append(input_size // output_size)
kernel_size.append(input_size - (output_size - 1) * stride[-1])
return {'kernel_size': kernel_size, 'stride': stride}
def _ntuple(n): def _ntuple(n):
def parse(x): def parse(x):
if isinstance(x, six.collections_abc.Sequence): if isinstance(x, six.collections_abc.Sequence):
......
...@@ -315,12 +315,16 @@ class OneHot(function.Function): ...@@ -315,12 +315,16 @@ class OneHot(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(OneHot, self).__init__(key, dev, **kwargs) super(OneHot, self).__init__(key, dev, **kwargs)
self.depth = kwargs.get('depth', 1) self.depth = kwargs.get('depth', 1)
self.on_value = kwargs.get('on_value', 1)
self.off_value = kwargs.get('off_value', 0)
def attributes(self): def attributes(self):
return { return {
'op_type': 'OneHot', 'op_type': 'OneHot',
'arguments': { 'arguments': {
'depth': self.depth, 'depth': self.depth,
'on_value': self.on_value,
'off_value': self.off_value,
}, },
} }
......
...@@ -46,7 +46,7 @@ def argmax(input, dim=None, keepdim=False, out=None): ...@@ -46,7 +46,7 @@ def argmax(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimension or not. Keep the reduced dimension or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -81,7 +81,7 @@ def argmin(input, dim=None, keepdim=False, out=None): ...@@ -81,7 +81,7 @@ def argmin(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimension or not. Keep the reduced dimension or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -174,7 +174,7 @@ def cat(seq, dim=0, out=None): ...@@ -174,7 +174,7 @@ def cat(seq, dim=0, out=None):
dim : int, optional dim : int, optional
The dim to concatenate. The dim to concatenate.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -197,11 +197,11 @@ def channel_affine(input, weight, bias=None, dim=0, out=None): ...@@ -197,11 +197,11 @@ def channel_affine(input, weight, bias=None, dim=0, out=None):
weight : dragon.vm.torch.Tensor weight : dragon.vm.torch.Tensor
The weight tensor. The weight tensor.
bias : dragon.vm.torch.Tensor, optional bias : dragon.vm.torch.Tensor, optional
The optional bias. The bias tensor.
dim : int, optional, default=0 dim : int, optional, default=0
The start dimension to transform. The start dimension to transform.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -369,7 +369,7 @@ def cumsum(input, dim, out=None): ...@@ -369,7 +369,7 @@ def cumsum(input, dim, out=None):
dim : int dim : int
The cumulative dimension. The cumulative dimension.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -429,7 +429,7 @@ def flatten(input, start_dim=0, end_dim=-1, out=None): ...@@ -429,7 +429,7 @@ def flatten(input, start_dim=0, end_dim=-1, out=None):
end_dim : int, optional, default=-1 end_dim : int, optional, default=-1
The end dimension to flatten. The end dimension to flatten.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -465,7 +465,7 @@ def index_select(input, dim, index, out=None): ...@@ -465,7 +465,7 @@ def index_select(input, dim, index, out=None):
index : dragon.vm.torch.Tensor index : dragon.vm.torch.Tensor
The index tensor. The index tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -523,7 +523,7 @@ def masked_select(input, mask, out=None): ...@@ -523,7 +523,7 @@ def masked_select(input, mask, out=None):
mask : dragon.vm.torch.Tensor mask : dragon.vm.torch.Tensor
The mask for selecting. The mask for selecting.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -566,7 +566,7 @@ def max(input, dim=None, keepdim=False, out=None): ...@@ -566,7 +566,7 @@ def max(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimensions or not. Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -606,7 +606,7 @@ def mean(input, dim=None, keepdim=False, out=None): ...@@ -606,7 +606,7 @@ def mean(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimensions or not. Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -646,7 +646,7 @@ def min(input, dim=None, keepdim=False, out=None): ...@@ -646,7 +646,7 @@ def min(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimensions or not. Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -721,7 +721,7 @@ def nonzero(input, out=None): ...@@ -721,7 +721,7 @@ def nonzero(input, out=None):
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -732,7 +732,7 @@ def nonzero(input, out=None): ...@@ -732,7 +732,7 @@ def nonzero(input, out=None):
return _functions.NonZero.instantiate(input.device).apply(input, out) return _functions.NonZero.instantiate(input.device).apply(input, out)
def one_hot(input, depth): def one_hot(input, depth, on_value=1, off_value=0):
r"""Return the one-hot representation for input. r"""Return the one-hot representation for input.
.. math:: .. math::
...@@ -748,6 +748,10 @@ def one_hot(input, depth): ...@@ -748,6 +748,10 @@ def one_hot(input, depth):
The input tensor. The input tensor.
depth : int depth : int
The depth of channels. The depth of channels.
on_value : int, optional, default=1
The value for equal branch.
off_value : int, optional, default=0
The value for not-equal branch.
Returns Returns
------- -------
...@@ -755,7 +759,12 @@ def one_hot(input, depth): ...@@ -755,7 +759,12 @@ def one_hot(input, depth):
The output tensor. The output tensor.
""" """
return _functions.OneHot.instantiate(input.device, depth=depth).apply(input) return _functions.OneHot.instantiate(
input.device,
depth=depth,
on_value=on_value,
off_value=off_value,
).apply(input)
def permute(input, dims): def permute(input, dims):
...@@ -812,7 +821,7 @@ def reshape(input, shape, out=None): ...@@ -812,7 +821,7 @@ def reshape(input, shape, out=None):
shape : Sequence[int] shape : Sequence[int]
The new shape. The new shape.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -986,7 +995,7 @@ def stack(seq, dim=0, out=None): ...@@ -986,7 +995,7 @@ def stack(seq, dim=0, out=None):
dim : int, optional, default=0 dim : int, optional, default=0
The dim to stack. The dim to stack.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
...@@ -1030,7 +1039,7 @@ def sum(input, dim=None, keepdim=False, out=None): ...@@ -1030,7 +1039,7 @@ def sum(input, dim=None, keepdim=False, out=None):
keepdim : bool, optional, default=False keepdim : bool, optional, default=False
Keep the reduced dimensions or not. Keep the reduced dimensions or not.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
Returns Returns
------- -------
......
...@@ -59,7 +59,7 @@ def arange( ...@@ -59,7 +59,7 @@ def arange(
step : number, optional, default=1 step : number, optional, default=1
The spacing between two elements. The spacing between two elements.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -113,7 +113,7 @@ def eye( ...@@ -113,7 +113,7 @@ def eye(
m : int, optional m : int, optional
The number of output cols. The number of output cols.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -175,7 +175,7 @@ def full( ...@@ -175,7 +175,7 @@ def full(
fill_value : number fill_value : number
The scalar to fill. The scalar to fill.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -216,7 +216,7 @@ def full_like( ...@@ -216,7 +216,7 @@ def full_like(
fill_value : number fill_value : number
The scalar to fill. The scalar to fill.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -268,7 +268,7 @@ def linspace( ...@@ -268,7 +268,7 @@ def linspace(
steps : int, optional, default=100 steps : int, optional, default=100
The number of values to generate. The number of values to generate.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
dim : int, optional, default=0 dim : int, optional, default=0
...@@ -326,7 +326,7 @@ def ones(*size, **kwargs): ...@@ -326,7 +326,7 @@ def ones(*size, **kwargs):
size : int... size : int...
The size(s) indicating the out shape. The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -378,7 +378,7 @@ def rand(*size, **kwargs): ...@@ -378,7 +378,7 @@ def rand(*size, **kwargs):
size : int... size : int...
The size(s) indicating the out shape. The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -404,7 +404,7 @@ def randn(*size, **kwargs): ...@@ -404,7 +404,7 @@ def randn(*size, **kwargs):
size : int... size : int...
The size(s) indicating the out shape. The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -436,7 +436,7 @@ def randperm(n, out=None, dtype='int64', device=None, requires_grad=False): ...@@ -436,7 +436,7 @@ def randperm(n, out=None, dtype='int64', device=None, requires_grad=False):
n: number n: number
The end of interval. The end of interval.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='int64' dtype : str, optional, default='int64'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
...@@ -479,7 +479,7 @@ def zeros(*size, **kwargs): ...@@ -479,7 +479,7 @@ def zeros(*size, **kwargs):
size : int... size : int...
The size(s) indicating the out shape. The size(s) indicating the out shape.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The output tensor.
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
device : dragon.vm.torch.device, optional device : dragon.vm.torch.device, optional
......
...@@ -77,36 +77,42 @@ class Clip(function.Function): ...@@ -77,36 +77,42 @@ class Clip(function.Function):
return self.dispatch([input], [self.alloc(out)]) return self.dispatch([input], [self.alloc(out)])
class UnaryFunc(function.Function): class Gemm(function.Function):
"""Unary function.""" """Gemm function."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(UnaryFunc, self).__init__(key, dev, **kwargs) super(Gemm, self).__init__(key, dev, **kwargs)
self.op_type = kwargs.get('op_type', '') self.alpha = kwargs.get('alpha', 1.0)
self.beta = kwargs.get('beta', 1.0)
self.transA = kwargs.get('transA', False)
self.transB = kwargs.get('transB', False)
def attributes(self): def attributes(self):
return {'op_type': self.op_type, 'arguments': {}} return {
'op_type': 'Gemm',
'arguments': {
'axis': -1,
'alpha': self.alpha,
'beta': self.beta,
'transA': self.transA,
'transB': self.transB,
},
}
def forward(self, input, out=None): def forward(self, mat1, mat2, mat3=None, out=None):
return self.dispatch([input], [self.alloc(out)]) inputs = [mat1, mat2] + ([mat3] if mat3 else [])
return self.dispatch(inputs, [self.alloc(out)])
class MatMul(function.Function): class UnaryFunc(function.Function):
"""MatMul function.""" """Unary function."""
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(MatMul, self).__init__(key, dev, **kwargs) super(UnaryFunc, self).__init__(key, dev, **kwargs)
self.transpose_a = kwargs.get('transpose_a', False) self.op_type = kwargs.get('op_type', '')
self.transpose_b = kwargs.get('transpose_b', False)
def attributes(self): def attributes(self):
return { return {'op_type': self.op_type, 'arguments': {}}
'op_type': 'MatMul',
'arguments': {
'transA': self.transpose_a,
'transB': self.transpose_b,
},
}
def forward(self, mat1, mat2, out=None): def forward(self, input, out=None):
return self.dispatch([mat1, mat2], [self.alloc(out)]) return self.dispatch([input], [self.alloc(out)])
...@@ -85,6 +85,35 @@ def add_(self, other): ...@@ -85,6 +85,35 @@ def add_(self, other):
return math_funcs.add(self, other, self) return math_funcs.add(self, other, self)
def addmm(self, mat1, mat2, beta=1, alpha=1):
    r"""Add the result of matrix-matrix multiplication.

    .. math:: \text{out} = \alpha (\text{mat1} \times \text{mat2}) + \beta \text{self}

    Parameters
    ----------
    mat1 : dragon.vm.torch.Tensor
        The first matrix.
    mat2 : dragon.vm.torch.Tensor
        The second matrix.
    beta : float, optional, default=1
        The value of :math:`\beta`.
    alpha : float, optional, default=1
        The value of :math:`\alpha`.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.addmm(...)`_

    """
    return math_funcs.addmm(self, mat1, mat2, beta=beta, alpha=alpha)
def argmax(self, dim=None, keepdim=False): def argmax(self, dim=None, keepdim=False):
"""Return the index of maximum elements. """Return the index of maximum elements.
...@@ -154,6 +183,71 @@ def argsort(self, dim=-1, descending=False): ...@@ -154,6 +183,71 @@ def argsort(self, dim=-1, descending=False):
return array_funcs.argsort(self, dim, descending) return array_funcs.argsort(self, dim, descending)
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
    r"""Add the result of batched matrix-matrix multiplication.

    .. math::
        \text{out}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
                         \beta \text{self}_{i}

    Parameters
    ----------
    batch1 : dragon.vm.torch.Tensor
        The first batch of matrices.
    batch2 : dragon.vm.torch.Tensor
        The second batch of matrices.
    beta : float, optional, default=1
        The value of :math:`\beta`.
    alpha : float, optional, default=1
        The value of :math:`\alpha`.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.baddbmm(...)`_

    """
    return math_funcs.baddbmm(self, batch1, batch2, beta=beta, alpha=alpha)
def baddbmm_(self, batch1, batch2, beta=1, alpha=1):
    r"""Add the result of batched matrix-matrix multiplication into ``self``.

    .. math::
        \text{self}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
                          \beta \text{self}_{i}

    Parameters
    ----------
    batch1 : dragon.vm.torch.Tensor
        The first batch of matrices.
    batch2 : dragon.vm.torch.Tensor
        The second batch of matrices.
    beta : float, optional, default=1
        The value of :math:`\beta`.
    alpha : float, optional, default=1
        The value of :math:`\alpha`.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.baddbmm(...)`_

    """
    # ``self`` is passed both as the added operand and as ``out``.
    return math_funcs.baddbmm(
        self, batch1, batch2,
        beta=beta, alpha=alpha, out=self,
    )
def backward(self, gradient=None, retain_graph=False): def backward(self, gradient=None, retain_graph=False):
"""Compute the derivatives of this tensor w.r.t. graph leaves. """Compute the derivatives of this tensor w.r.t. graph leaves.
...@@ -254,6 +348,29 @@ def bitwise_xor_(self, other): ...@@ -254,6 +348,29 @@ def bitwise_xor_(self, other):
return math_funcs.bitwise_xor(self, other, self) return math_funcs.bitwise_xor(self, other, self)
def bmm(self, batch2):
    r"""Compute the batched matrix multiplication.

    .. math:: \text{out}_{i} = \text{self}_{i} \times \text{batch2}_{i}

    Parameters
    ----------
    batch2 : dragon.vm.torch.Tensor
        The second batch of matrices.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.bmm(...)`_

    """
    return math_funcs.bmm(self, batch2)
def bool(self): def bool(self):
"""Return a bool tensor with the same data. """Return a bool tensor with the same data.
...@@ -719,50 +836,6 @@ def floor_(self): ...@@ -719,50 +836,6 @@ def floor_(self):
return math_funcs.floor(self, self) return math_funcs.floor(self, self)
def new_full(
    self,
    size,
    fill_value,
    dtype=None,
    device=None,
    requires_grad=False,
):
    """Return a tensor filled with a scalar.

    The ``dtype`` and ``device`` of this tensor are used
    for whichever of the two is not provided.

    Parameters
    ----------
    size : Sequence[int]
        The size of output tensor.
    fill_value : number
        The scalar to fill.
    dtype : str, optional
        The optional data type.
    device : dragon.vm.torch.device, optional
        The optional device of returned tensor.
    requires_grad : bool, optional, default=False
        **True** to record gradient for returned tensor.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.full(...)`_

    """
    # Fall back to the properties of this tensor.
    if dtype is None:
        dtype = self.dtype
    if device is None:
        device = self.device
    return init_funcs.full(
        size,
        fill_value,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
def ge(self, other): def ge(self, other):
r"""Compute the element-wise greater-equal comparison. r"""Compute the element-wise greater-equal comparison.
...@@ -1104,6 +1177,29 @@ def masked_select(self, mask): ...@@ -1104,6 +1177,29 @@ def masked_select(self, mask):
return array_funcs.masked_select(self, mask) return array_funcs.masked_select(self, mask)
def matmul(self, tensor2):
    r"""Compute the matrix multiplication.

    .. math:: \text{out} = \text{self} \times \text{tensor2}

    Parameters
    ----------
    tensor2 : dragon.vm.torch.Tensor
        The tensor to multiply.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.matmul(...)`_

    """
    return math_funcs.matmul(self, tensor2)
def max(self, dim=None, keepdim=False): def max(self, dim=None, keepdim=False):
"""Compute the max value of elements along the given dimension. """Compute the max value of elements along the given dimension.
...@@ -1383,6 +1479,50 @@ def neg_(self): ...@@ -1383,6 +1479,50 @@ def neg_(self):
return math_funcs.neg(self, self) return math_funcs.neg(self, self)
def new_full(
    self,
    size,
    fill_value,
    dtype=None,
    device=None,
    requires_grad=False,
):
    """Return a tensor filled with a scalar.

    The ``dtype`` and ``device`` of this tensor are used
    for whichever of the two is not provided.

    Parameters
    ----------
    size : Sequence[int]
        The size of output tensor.
    fill_value : number
        The scalar to fill.
    dtype : str, optional
        The optional data type.
    device : dragon.vm.torch.device, optional
        The optional device of returned tensor.
    requires_grad : bool, optional, default=False
        **True** to record gradient for returned tensor.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.full(...)`_

    """
    # Fall back to the properties of this tensor.
    if dtype is None:
        dtype = self.dtype
    if device is None:
        device = self.device
    return init_funcs.full(
        size,
        fill_value,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
def nonzero(self): def nonzero(self):
r"""Return the index of non-zero elements. r"""Return the index of non-zero elements.
...@@ -1735,7 +1875,7 @@ def sort(self, dim=-1, descending=False): ...@@ -1735,7 +1875,7 @@ def sort(self, dim=-1, descending=False):
def split(self, split_size_or_sections, dim=0): def split(self, split_size_or_sections, dim=0):
"""Return the splited chunks along the given dimension. """Return the split chunks along the given dimension.
Parameters Parameters
---------- ----------
...@@ -2132,14 +2272,18 @@ def _process_index(item): ...@@ -2132,14 +2272,18 @@ def _process_index(item):
Tensor.abs = abs Tensor.abs = abs
Tensor.add = add Tensor.add = add
Tensor.add_ = add_ Tensor.add_ = add_
Tensor.addmm = addmm
Tensor.argmax = argmax Tensor.argmax = argmax
Tensor.argmin = argmin Tensor.argmin = argmin
Tensor.argsort = argsort Tensor.argsort = argsort
Tensor.backward = backward Tensor.backward = backward
Tensor.baddbmm = baddbmm
Tensor.baddbmm_ = baddbmm_
Tensor.bitwise_not = bitwise_not Tensor.bitwise_not = bitwise_not
Tensor.bitwise_not_ = bitwise_not_ Tensor.bitwise_not_ = bitwise_not_
Tensor.bitwise_xor = bitwise_xor Tensor.bitwise_xor = bitwise_xor
Tensor.bitwise_xor_ = bitwise_xor_ Tensor.bitwise_xor_ = bitwise_xor_
Tensor.bmm = bmm
Tensor.bool = bool Tensor.bool = bool
Tensor.bool_ = bool_ Tensor.bool_ = bool_
Tensor.byte = byte Tensor.byte = byte
...@@ -2184,6 +2328,7 @@ Tensor.logsumexp = logsumexp ...@@ -2184,6 +2328,7 @@ Tensor.logsumexp = logsumexp
Tensor.lt = lt Tensor.lt = lt
Tensor.masked_fill_ = masked_fill_ Tensor.masked_fill_ = masked_fill_
Tensor.masked_select = masked_select Tensor.masked_select = masked_select
Tensor.matmul = matmul
Tensor.max = max Tensor.max = max
Tensor.maximum = maximum Tensor.maximum = maximum
Tensor.mean = mean Tensor.mean = mean
......
...@@ -270,6 +270,33 @@ class Tensor(object): ...@@ -270,6 +270,33 @@ class Tensor(object):
""" """
def addmm(self, mat1, mat2, beta=1, alpha=1):
    r"""Add the result of matrix-matrix multiplication.

    .. math:: \text{out} = \alpha (\text{mat1} \times \text{mat2}) + \beta \text{self}

    Parameters
    ----------
    mat1 : dragon.vm.torch.Tensor
        The first matrix.
    mat2 : dragon.vm.torch.Tensor
        The second matrix.
    beta : float, optional, default=1
        The value of :math:`\beta`.
    alpha : float, optional, default=1
        The value of :math:`\alpha`.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.addmm(...)`_

    """
def argmax(self, dim=None, keepdim=False): def argmax(self, dim=None, keepdim=False):
"""Return the index of maximum elements. """Return the index of maximum elements.
...@@ -345,6 +372,64 @@ class Tensor(object): ...@@ -345,6 +372,64 @@ class Tensor(object):
""" """
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
    r"""Add the result of batched matrix-matrix multiplication.

    .. math::
        \text{out}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
                         \beta \text{self}_{i}

    Parameters
    ----------
    batch1 : dragon.vm.torch.Tensor
        The first batch of matrices.
    batch2 : dragon.vm.torch.Tensor
        The second batch of matrices.
    beta : float, optional, default=1
        The value of :math:`\beta`.
    alpha : float, optional, default=1
        The value of :math:`\alpha`.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.baddbmm(...)`_

    """
def baddbmm_(self, batch1, batch2, beta=1, alpha=1):
    r"""Add the result of batched matrix-matrix multiplication into ``self``.

    .. math::
        \text{self}_{i} = \alpha (\text{batch1}_{i} \times \text{batch2}_{i}) +
                          \beta \text{self}_{i}

    Parameters
    ----------
    batch1 : dragon.vm.torch.Tensor
        The first batch of matrices.
    batch2 : dragon.vm.torch.Tensor
        The second batch of matrices.
    beta : float, optional, default=1
        The value of :math:`\beta`.
    alpha : float, optional, default=1
        The value of :math:`\alpha`.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.baddbmm(...)`_

    """
def bitwise_not(self): def bitwise_not(self):
r"""Compute the element-wise NOT bitwise operation. r"""Compute the element-wise NOT bitwise operation.
...@@ -419,6 +504,27 @@ class Tensor(object): ...@@ -419,6 +504,27 @@ class Tensor(object):
""" """
def bmm(self, batch2):
    r"""Compute the batched matrix multiplication.

    .. math:: \text{out}_{i} = \text{self}_{i} \times \text{batch2}_{i}

    Parameters
    ----------
    batch2 : dragon.vm.torch.Tensor
        The second batch of matrices.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.bmm(...)`_

    """
def bool(self): def bool(self):
"""Return a bool tensor with the same data. """Return a bool tensor with the same data.
...@@ -1192,6 +1298,27 @@ class Tensor(object): ...@@ -1192,6 +1298,27 @@ class Tensor(object):
""" """
def matmul(self, tensor2):
    r"""Compute the matrix multiplication.

    .. math:: \text{out} = \text{self} \times \text{tensor2}

    Parameters
    ----------
    tensor2 : dragon.vm.torch.Tensor
        The tensor to multiply.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    See Also
    --------
    `torch.matmul(...)`_

    """
def max(self, dim=None, keepdim=False): def max(self, dim=None, keepdim=False):
"""Compute the max value of elements along the given dimension. """Compute the max value of elements along the given dimension.
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!