Commit 094c8c32 by Ting PAN

Add Reverse operator

Summary:
This commit adds reverse or flip operator.
1 parent bdf4e10f
......@@ -136,6 +136,9 @@ dragon
`reshape(...) <dragon/reshape.html>`_
: Change the dimensions of input.
`reverse(...) <dragon/reverse.html>`_
: Reverse elements along the given axis.
`roll(...) <dragon/roll.html>`_
: Roll elements along the given axis.
......@@ -237,6 +240,7 @@ dragon
dragon/repeat
dragon/reset_workspace
dragon/reshape
dragon/reverse
dragon/roll
dragon/scatter_add
dragon/scatter_elements
......
reverse
=======
.. autofunction:: dragon.reverse
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
......@@ -87,6 +87,9 @@ vm.tensorflow
`reshape(...) <tensorflow/reshape.html>`_
: Change the dimensions of input.
`reverse(...) <tensorflow/reverse.html>`_
: Reverse elements along the given axis.
`roll(...) <tensorflow/roll.html>`_
: Roll elements along the given axis.
......@@ -152,6 +155,7 @@ vm.tensorflow
tensorflow/pad
tensorflow/range
tensorflow/reshape
tensorflow/reverse
tensorflow/roll
tensorflow/shape
tensorflow/slice
......
reverse
=======
.. autofunction:: dragon.vm.tensorflow.reverse
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
......@@ -111,6 +111,15 @@ vm.torch
`flatten(...) <torch/flatten.html>`_
: Return a tensor with dimensions flattened.
`flip(...) <torch/flip.html>`_
: Reverse elements along the given dimension.
`fliplr(...) <torch/fliplr.html>`_
: Reverse elements along the second dimension.
`flipud(...) <torch/flipud.html>`_
: Reverse elements along the first dimension.
`floor(...) <torch/floor.html>`_
: Compute the largest integer not greater than input.
......@@ -350,6 +359,9 @@ vm.torch
torch/exp
torch/eye
torch/flatten
torch/flip
torch/fliplr
torch/flipud
torch/floor
torch/from_numpy
torch/full
......
......@@ -233,6 +233,18 @@ flatten\_
#########
.. automethod:: dragon.vm.torch.Tensor.flatten_
flip
####
.. automethod:: dragon.vm.torch.Tensor.flip
fliplr
######
.. automethod:: dragon.vm.torch.Tensor.fliplr
flipud
######
.. automethod:: dragon.vm.torch.Tensor.flipud
float
#####
.. automethod:: dragon.vm.torch.Tensor.float
......@@ -650,6 +662,9 @@ zero\_
.. _torch.eq(...): eq.html
.. _torch.exp(...): exp.html
.. _torch.flatten(...): flatten.html
.. _torch.flip(...): flip.html
.. _torch.fliplr(...): fliplr.html
.. _torch.flipud(...): flipud.html
.. _torch.floor(...): floor.html
.. _torch.full(...): full.html
.. _torch.gather(...): gather.html
......
flip
====
.. autofunction:: dragon.vm.torch.flip
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
fliplr
======
.. autofunction:: dragon.vm.torch.fliplr
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
flipud
======
.. autofunction:: dragon.vm.torch.flipud
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T>
void _Reverse(
const int num_dims,
const uint8_t* x_flips,
const int64_t* x_strides,
const int64_t* y_dims,
const T* x,
T* y) {
const auto N =
std::accumulate(y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>());
vec64_t index(num_dims, 0);
int64_t xi;
for (int yi = 0; yi < N; ++yi) {
xi = 0;
for (int d = num_dims - 1; d >= 0; --d) {
xi += (x_flips[d] ? y_dims[d] - index[d] - 1 : index[d]) * x_strides[d];
}
y[yi] = x[xi];
math::utils::IncreaseIndexInDims(num_dims, y_dims, index.data());
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Reverse<T, CPUContext>( \
const int num_dims, \
const uint8_t* x_flips, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_Reverse(num_dims, x_flips, x_strides, y_dims, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernels {
namespace {
template <typename T, int D>
__global__ void _Reverse(
const int N,
const int num_dims,
const SimpleArray<uint8_t, D> X_flips,
const SimpleArray<int, D> X_strides,
const SimpleArray<int, D> Y_dims,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, N) {
int xi = 0, tmp = yi;
for (int d = num_dims - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], tmp, &tmp, &r);
xi += (X_flips.data[d] ? Y_dims.data[d] - r - 1 : r) * X_strides.data[d];
}
y[yi] = x[xi];
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Reverse<T, CUDAContext>( \
const int num_dims, \
const uint8_t* x_flips, \
const int64_t* x_strides, \
const int64_t* y_dims, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK(num_dims); \
SimpleArray<uint8_t, CUDA_TENSOR_MAX_DIMS> X_flips; \
SimpleArray<int, CUDA_TENSOR_MAX_DIMS> X_strides, Y_dims; \
const auto N = std::accumulate( \
y_dims, y_dims + num_dims, 1, std::multiplies<int64_t>()); \
for (int i = 0; i < num_dims; ++i) { \
X_flips.data[i] = x_flips[i]; \
X_strides.data[i] = x_strides[i]; \
Y_dims.data[i] = y_dims[i]; \
} \
_Reverse<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, num_dims, X_flips, X_strides, Y_dims, x, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
#endif // USE_CUDA
#include "dragon/operators/array/reverse_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void ReverseOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
int num_dims = X.ndim();
vector<uint8_t> X_flips(num_dims, 0);
for (int i = 0; i < axes_.size(); ++i) {
int axis = axes_[i];
axis = axis < 0 ? axis + num_dims : axis;
CHECK(axis >= 0 && axis < num_dims)
<< "\nExcepted the <axis> in [-" << num_dims << ", " << num_dims
<< "), got " << axes_[i] << ".";
X_flips[axis] = 1;
}
kernels::Reverse(
num_dims,
X_flips.data(),
X.strides().data(),
X.dims().data(),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
}
DEPLOY_CPU_OPERATOR(Reverse);
REGISTER_CPU_OPERATOR(ReverseGradient, ReverseOp<CPUContext>);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Reverse);
REGISTER_CUDA_OPERATOR(ReverseGradient, ReverseOp<CUDAContext>);
#endif
OPERATOR_SCHEMA(Reverse)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(ReverseGradient)
/* dY */
.NumInputs(1)
/* dX */
.NumOutputs(1);
REGISTER_GRADIENT(Reverse, SimpleGradientMaker);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_REVERSE_OP_H_
#define DRAGON_OPERATORS_ARRAY_REVERSE_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class ReverseOp final : public Operator<Context> {
public:
ReverseOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), axes_(OP_REPEATED_ARG(int64_t, "axes")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
vec64_t axes_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_REVERSE_OP_H_
......@@ -11,11 +11,8 @@ void GroupNormOp<Context>::DoRunWithType() {
using ParamT = typename math::AccmulatorType<T>::type;
INITIALIZE_TENSOR_VIA_SPEC(Input(1), vec64_t({C_}), ParamT);
INITIALIZE_TENSOR_VIA_SPEC(Input(2), vec64_t({C_}), ParamT);
auto* X_mu = Buffer("X_mu")->Reshape({N_, G_});
auto* X_rsig = Buffer("X_rsig")->Reshape({N_, G_});
auto* X_scale = Buffer("X_scale")->Reshape({N_, C_});
auto* X_bias = Buffer("X_bias")->Reshape({N_, C_});
auto* x = Input(0).template data<T, Context>();
auto* mu = X_mu->template mutable_data<ParamT, Context>();
......@@ -36,6 +33,8 @@ void GroupNormOp<Context>::DoRunWithType() {
math::InvStd(N_ * G_, epsilon_, rsig, rsig, ctx());
// Fuse parameters to compute affine transformation
auto* scratch =
ctx()->workspace()->template data<ParamT, Context>({2 * N_ * C_})[0];
kernels::GroupNorm(
N_,
G_,
......@@ -47,8 +46,8 @@ void GroupNormOp<Context>::DoRunWithType() {
rsig,
Input(1).template data<ParamT, Context>(), // gamma
Input(2).template data<ParamT, Context>(), // beta
X_scale->template mutable_data<ParamT, Context>(),
X_bias->template mutable_data<ParamT, Context>(),
scratch,
scratch + N_ * C_,
Output(0)->template mutable_data<T, Context>(),
ctx());
}
......@@ -65,12 +64,11 @@ template <typename T>
void GroupNormGradientOp<Context>::DoRunWithType() {
using ParamT = typename math::AccmulatorType<T>::type;
auto *dX = Output(0), *dW = Output(1), *dB = Output(2);
auto *X_mu = Buffer("X_mu"), *X_rsig = Buffer("X_rsig");
auto* X_scale = Buffer("X_scale")->Reshape({N_, G_});
auto* X_bias = Buffer("X_bias")->Reshape({N_, G_});
// Gradient w.r.t. gamma, beta and input
auto* scratch =
ctx()->workspace()->template data<ParamT, Context>({2 * N_ * G_})[0];
kernels::GroupNormGrad(
N_,
G_,
......@@ -82,8 +80,8 @@ void GroupNormGradientOp<Context>::DoRunWithType() {
X_rsig->template data<ParamT, Context>(),
Input(1).template data<ParamT, Context>(), // gamma
Input(2).template data<T, Context>(), // dy
X_scale->template mutable_data<ParamT, Context>(),
X_bias->template mutable_data<ParamT, Context>(),
scratch,
scratch + N_ * G_,
dW->Reshape({C_})->template mutable_data<ParamT, Context>(),
dB->Reshape({C_})->template mutable_data<ParamT, Context>(),
dX->template mutable_data<T, Context>(),
......@@ -120,7 +118,6 @@ OPERATOR_SCHEMA(GroupNormGradient)
.NumOutputs(3);
namespace {
class GradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
......
......@@ -76,6 +76,7 @@ from dragon.core.ops.array_ops import pad
from dragon.core.ops.array_ops import range
from dragon.core.ops.array_ops import repeat
from dragon.core.ops.array_ops import reshape
from dragon.core.ops.array_ops import reverse
from dragon.core.ops.array_ops import roll
from dragon.core.ops.array_ops import scatter_add
from dragon.core.ops.array_ops import scatter_elements
......
......@@ -456,6 +456,11 @@ def resize_args(**kwargs):
}
@register('Reverse')
def reverse_args(**kwargs):
return {'axes': kwargs.get('axes', None)}
@register('Recurrent')
def rnn_args(**kwargs):
return {
......
......@@ -1229,6 +1229,42 @@ def reshape(inputs, shape, copy=True, **kwargs):
@OpSchema.num_inputs(1)
def reverse(inputs, axis, **kwargs):
"""Reverse elements along the given axis.
:attr:`axis` could be negative:
```python
x = dragon.constant([[1, 2, 3], [4, 5, 6]])
# A negative axis is the last-k axis
print(dragon.reverse(x, axis=1)) # [[3, 2, 1], [6, 5, 4]]
print(dragon.reverse(x, axis=-1)) # Equivalent
# Also, axis could be a sequence of integers
print(dragon.reverse(x, axis=(0, 1))) # [[6, 5, 4], [3, 2, 1]]
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
axis : Union[int, Sequence[int]]
The axis to reverse.
Returns
-------
dragon.Tensor
The output tensor.
"""
axes = nest.flatten(axis) if axis is not None else axis
if context.executing_eagerly():
return OpLib.execute('Reverse', inputs, axes=axes)
return OpLib.add('Reverse', inputs, axes=axes, **kwargs)
@OpSchema.num_inputs(1)
@OpSchema.convert_arg('shift', name_v2='shifts')
def roll(inputs, shift, axis=None, **kwargs):
"""Roll elements along the given axis.
......
......@@ -502,6 +502,16 @@ void RepeatGrad(
Context* ctx);
template <typename T, class Context>
void Reverse(
const int num_dims,
const uint8_t* x_flips,
const int64_t* x_strides,
const int64_t* y_dims,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
void Roll(
const int num_dims,
const int64_t* x_shifts,
......
......@@ -80,6 +80,7 @@ from dragon.vm.tensorflow.core.ops.array_ops import one_hot
from dragon.vm.tensorflow.core.ops.array_ops import pad
from dragon.vm.tensorflow.core.ops.array_ops import placeholder
from dragon.vm.tensorflow.core.ops.array_ops import reshape
from dragon.vm.tensorflow.core.ops.array_ops import reverse
from dragon.vm.tensorflow.core.ops.array_ops import roll
from dragon.vm.tensorflow.core.ops.array_ops import shape
from dragon.vm.tensorflow.core.ops.array_ops import slice
......
......@@ -503,6 +503,40 @@ def reshape(tensor, shape, name=None):
return array_ops.reshape(tensor, shape=shape, name=name)
def reverse(tensor, axis, name=None):
"""Reverse elements along the given axis.
:attr:`axis` could be negative:
```python
x = tf.constant([[1, 2, 3], [4, 5, 6]])
# A negative axis is the last-k axis
print(tf.reverse(x, axis=1)) # [[3, 2, 1], [6, 5, 4]]
print(tf.reverse(x, axis=-1)) # Equivalent
# Also, axis could be a sequence of integers
print(tf.reverse(x, axis=(0, 1))) # [[6, 5, 4], [3, 2, 1]]
```
Parameters
----------
tensor : dragon.Tensor
The input tensor.
axis : Union[int, Sequence[int]]
The axis to reverse.
name : str, optional
The operation name.
Returns
-------
dragon.Tensor
The output tensor.
"""
return array_ops.reverse(tensor, axis=axis, name=name)
def roll(input, shift, axis, name=None):
"""Roll elements along the given axis.
......
......@@ -938,6 +938,24 @@ class TestArrayOps(OpTestCase):
with dragon.device('cuda'):
self.test_reshape()
def test_reverse(self):
entries = [0, 1, (1, 2)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for axis in entries:
data = arange((2, 3, 4))
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.reverse(x, axis)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
self.assertEqual([y, dx], [np.flip(data, axis), np.flip(data, axis)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_reverse_cuda(self):
with dragon.device('cuda'):
self.test_reverse()
def test_shape(self):
entries = [(2, 3), (2, 3, 3)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
......
......@@ -288,6 +288,13 @@ class TestTensorOps(OpTestCase):
x.flatten_(-3, -2)
self.assertEqual(x, data.reshape((2, 3)))
def test_flip(self):
data = arange((2, 3, 4))
x = new_tensor(data)
self.assertEqual(x.flip((1, 2)), np.flip(data, (1, 2)))
self.assertEqual(x.fliplr(), np.fliplr(data))
self.assertEqual(x.flipud(), np.flipud(data))
def test_floor(self):
data = np.array([0.9, 1.4, 1.9])
x = new_tensor(data)
......
......@@ -55,6 +55,9 @@ from dragon.vm.torch.core.ops.array_ops import channel_normalize
from dragon.vm.torch.core.ops.array_ops import chunk
from dragon.vm.torch.core.ops.array_ops import cumsum
from dragon.vm.torch.core.ops.array_ops import flatten
from dragon.vm.torch.core.ops.array_ops import flip
from dragon.vm.torch.core.ops.array_ops import fliplr
from dragon.vm.torch.core.ops.array_ops import flipud
from dragon.vm.torch.core.ops.array_ops import gather
from dragon.vm.torch.core.ops.array_ops import index_select
from dragon.vm.torch.core.ops.array_ops import masked_select
......
......@@ -1557,22 +1557,29 @@ def multi_head_attention_forward(
assert embed_dim == embed_dim_to_check
assert src_len == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
scaling = float(head_dim) ** -0.5
def to_qkv(input, weight, bias, num_proj=1):
"""Compute input projections via a single matmul."""
qkv_size = (tgt_len, bsz, num_proj * num_heads, head_dim)
outputs = linear(input, weight, bias).reshape_(qkv_size)
outputs = outputs.permute(1, 2, 0, 3)
return outputs if num_proj == 1 else outputs.chunk(num_proj, 1)
q, k, v = None, None, None
if not use_separate_proj_weight:
if (query is key) and (key is value):
# Parallelism for self attention
q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
# Parallelism for self attention.
q, k, v = to_qkv(query, in_proj_weight, in_proj_bias, 3)
elif key is value:
# Parallelism for encode-decoder attention
# Parallelism for encode-decoder attention.
q_proj_weight = in_proj_weight[:embed_dim, :]
kv_proj_weight = in_proj_weight[embed_dim:, :]
q_proj_bias = kv_proj_bias = in_proj_bias
if in_proj_bias is not None:
q_proj_bias = in_proj_bias[:embed_dim]
kv_proj_bias = in_proj_bias[embed_dim:]
q = linear(query, q_proj_weight, q_proj_bias)
k, v = linear(key, kv_proj_weight, kv_proj_bias).chunk(2, dim=-1)
q = to_qkv(query, q_proj_weight, q_proj_bias)
k, v = to_qkv(key, kv_proj_weight, kv_proj_bias, 2)
if q is None:
q_proj_bias = k_proj_bias = v_proj_bias = in_proj_bias
if use_separate_proj_weight and q_proj_weight is None:
......@@ -1583,37 +1590,28 @@ def multi_head_attention_forward(
q_proj_bias = in_proj_bias[:embed_dim]
k_proj_bias = in_proj_bias[embed_dim:embed_dim * 2]
v_proj_bias = in_proj_bias[embed_dim * 2:]
q = linear(query, q_proj_weight, q_proj_bias)
k = linear(key, k_proj_weight, k_proj_bias)
v = linear(value, v_proj_weight, v_proj_bias)
q *= scaling
q = q.reshape_((-1, bsz * num_heads, head_dim)).transpose(0, 1)
k = k.reshape_((-1, bsz * num_heads, head_dim)).transpose(0, 1)
v = v.reshape_((-1, bsz * num_heads, head_dim)).transpose(0, 1)
attn_weights = q.bmm(k.transpose(1, 2))
assert attn_weights.size() == (bsz * num_heads, tgt_len, src_len)
q = to_qkv(query, q_proj_weight, q_proj_bias)
k = to_qkv(key, k_proj_weight, k_proj_bias)
v = to_qkv(value, v_proj_weight, v_proj_bias)
q *= float(head_dim) ** -0.5
attn = q.bmm(k.transpose(-2, -1))
assert attn.size() == (bsz, num_heads, tgt_len, src_len)
if attn_mask is not None:
if attn_mask.dtype == 'bool' or attn_mask.dtype == 'uint8':
attn_weights.masked_fill_(attn_mask, float('-inf'))
attn.masked_fill_(attn_mask, float('-inf'))
else:
attn_weights += attn_mask
attn += attn_mask
if key_padding_mask is not None:
attn_weights.reshape_((bsz, num_heads, tgt_len, src_len))
if key_padding_mask.size() != attn_weights.size():
if key_padding_mask.size() != attn.size():
key_padding_mask.reshape_((bsz, 1, 1, src_len))
attn_weights.masked_fill_(key_padding_mask, float('-inf'))
attn_weights.reshape_((bsz * num_heads, tgt_len, src_len))
attn_weights = softmax(attn_weights, dim=-1, inplace=True)
attn_weights = dropout(attn_weights, p=dropout_p, training=training)
attn_output = attn_weights.bmm(v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = attn_output.transpose(0, 1).reshape_((tgt_len, bsz, embed_dim))
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
weights = attn_weights.reshape((bsz, num_heads, tgt_len, src_len))
return attn_output, weights.mean(dim=1)
else:
return attn_output, None
attn.masked_fill_(key_padding_mask, float('-inf'))
attn = softmax(attn, dim=-1, inplace=True)
attn = dropout(attn, p=dropout_p, training=training)
output = attn.bmm(v).permute(2, 0, 1, 3)
output = output.reshape_((tgt_len, bsz, embed_dim))
output = linear(output, out_proj_weight, out_proj_bias)
weights = attn.mean(dim=1) if need_weights else None
return output, weights
def nll_loss(
......
......@@ -161,21 +161,26 @@ class TransformerDecoderLayer(Module):
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
):
tgt2 = self.self_attn(tgt, tgt, tgt,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask,
need_weights=False)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(tgt, memory, memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
need_weights=False)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.self_attn(
tgt, tgt, tgt,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask,
need_weights=False)[0]
tgt2 = self.dropout1(tgt2)
tgt2 += tgt
tgt = self.norm1(tgt2)
tgt2 = self.multihead_attn(
tgt, memory, memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
need_weights=False)[0]
tgt2 = self.dropout2(tgt2)
tgt2 += tgt
tgt = self.norm2(tgt2)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
tgt2 = self.dropout3(tgt2)
tgt2 += tgt
tgt = self.norm3(tgt2)
return tgt
......@@ -292,15 +297,18 @@ class TransformerEncoderLayer(Module):
self.activation = _get_activation_fn(activation)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
src2 = self.self_attn(src, src, src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
need_weights=False)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.self_attn(
src, src, src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
need_weights=False)[0]
src2 = self.dropout1(src2)
src2 += src
src = self.norm1(src2)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
src2 = self.dropout2(src2)
src2 += src
src = self.norm2(src2)
return src
......
......@@ -392,6 +392,88 @@ def flatten(input, start_dim=0, end_dim=-1, out=None):
axis=start_dim, end_axis=end_dim)
def flip(input, dims):
"""Reverse elements along the given dimension.
:attr:`dims` could be negative:
```python
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# A negative dimension is the last-k dimension
print(torch.flip(x, dims=1)) # [[3, 2, 1], [6, 5, 4]]
print(torch.flip(x, dims=-1)) # Equivalent
# Also, dimension could be a sequence of integers
print(torch.flip(x, dims=(0, 1))) # [[6, 5, 4], [3, 2, 1]]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
dims : Union[int, Sequence[int]]
The dimension to reverse.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return FunctionLib.apply(
'Reverse', input.device, [input],
axes=nest.flatten(dims) if dims is not None else dims)
def fliplr(input):
"""Reverse elements along the second dimension.
Examples:
```python
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(torch.fliplr(x)) # [[3, 2, 1], [6, 5, 4]]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return flip(input, 1)
def flipud(input):
"""Reverse elements along the first dimension.
Examples:
```python
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(torch.flipud(x)) # [4, 5, 6], [1, 2, 3]]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return flip(input, 0)
def gather(input, dim, index, out=None):
"""Gather elements along the given dimension of index.
......@@ -559,10 +641,8 @@ def max(input, dim=None, keepdim=False, out=None):
The output tensor.
"""
if dim is None:
keepdim = False
else:
dim = nest.flatten(dim)
keepdim = keepdim if dim is not None else False
dim = nest.flatten(dim) if dim is not None else dim
return FunctionLib.apply(
'ReduceMax', input.device, [input], outputs=[out],
axes=dim, keepdims=keepdim)
......@@ -605,10 +685,8 @@ def mean(input, dim=None, keepdim=False, out=None):
The output tensor.
"""
if dim is None:
keepdim = False
else:
dim = nest.flatten(dim)
keepdim = keepdim if dim is not None else False
dim = nest.flatten(dim) if dim is not None else dim
return FunctionLib.apply(
'ReduceMean', input.device, [input], outputs=[out],
axes=dim, keepdims=keepdim)
......@@ -651,10 +729,8 @@ def min(input, dim=None, keepdim=False, out=None):
The output tensor.
"""
if dim is None:
keepdim = False
else:
dim = nest.flatten(dim)
keepdim = keepdim if dim is not None else False
dim = nest.flatten(dim) if dim is not None else dim
return FunctionLib.apply(
'ReduceMin', input.device, [input], outputs=[out],
axes=dim, keepdims=keepdim)
......@@ -1208,10 +1284,8 @@ def sum(input, dim=None, keepdim=False, out=None):
The output tensor.
"""
if dim is None:
keepdim = False
else:
dim = nest.flatten(dim)
keepdim = keepdim if dim is not None else False
dim = nest.flatten(dim) if dim is not None else dim
return FunctionLib.apply(
'ReduceSum', input.device, [input], outputs=[out],
axes=dim, keepdims=keepdim)
......
......@@ -828,7 +828,7 @@ def fill_(self, value):
def flatten(self, start_dim=0, end_dim=-1):
"""Return a new tensor with dimensions flattened.
"""Return a tensor with dimensions flattened.
Parameters
----------
......@@ -873,6 +873,59 @@ def flatten_(self, start_dim=0, end_dim=-1):
return array_ops.flatten(self, start_dim, end_dim, self)
def flip(self, dims):
"""Return a tensor with elements reversed along the given dimension.
Parameters
----------
dims : Union[int, Sequence[int]]
The dimension to reverse.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.flip(...)`_
"""
return array_ops.flip(self, dims)
def fliplr(self):
"""Return a tensor with elements reversed along the second dimension.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.fliplr(...)`_
"""
return array_ops.fliplr(self)
def flipud(self):
"""Return a tensor with elements reversed along the first dimension.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.flipud(...)`_
"""
return array_ops.flipud(self)
def _float(self):
"""Return a float32 tensor with the same data.
......@@ -1630,7 +1683,7 @@ def multinomial(self, num_samples):
def narrow(self, dimension, start, length):
"""Return a new tensor that is a narrowed version of input tensor.
"""Return a narrowed tensor.
Parameters
----------
......@@ -2512,7 +2565,7 @@ def topk(self, k, dim=-1, largest=True, sorted=True):
def transpose(self, dim0, dim1):
"""Return a new tensor with two dimensions swapped.
"""Return a tensor with two dimensions swapped.
Parameters
----------
......@@ -2867,6 +2920,9 @@ Tensor.expand = expand
Tensor.fill_ = fill_
Tensor.flatten = flatten
Tensor.flatten_ = flatten_
Tensor.flip = flip
Tensor.fliplr = fliplr
Tensor.flipud = flipud
Tensor.float = _float
Tensor.float_ = _float_
Tensor.floor = floor
......
......@@ -1021,7 +1021,7 @@ class Tensor(object):
"""
def flatten(self, start_dim=0, end_dim=-1):
"""Return a new tensor with dimensions flattened.
"""Return a tensor with dimensions flattened.
Parameters
----------
......@@ -1062,6 +1062,53 @@ class Tensor(object):
"""
def flip(self, dims):
"""Return a tensor with elements reversed along the given dimension.
Parameters
----------
dims : Union[int, Sequence[int]]
The dimension to reverse.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.flip(...)`_
"""
def fliplr(self):
"""Return a tensor with elements reversed along the second dimension.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.fliplr(...)`_
"""
def flipud(self):
"""Return a tensor with elements reversed along the first dimension.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.flipud(...)`_
"""
def float(self):
"""Return a float32 tensor with the same data.
......@@ -1723,7 +1770,7 @@ class Tensor(object):
"""
def narrow(self, dimension, start, length):
"""Return a new tensor that is a narrowed version of input tensor.
"""Return a narrowed tensor.
Parameters
----------
......@@ -2026,7 +2073,7 @@ class Tensor(object):
return self.fill_(1)
def permute(self, *dims):
"""Return a new tensor with the specific order of dimensions.
"""Return a tensor with the specific order of dimensions.
Parameters
----------
......@@ -2655,7 +2702,7 @@ class Tensor(object):
"""
def transpose(self, dim0, dim1):
"""Return a new tensor with two dimensions swapped.
"""Return a tensor with two dimensions swapped.
Parameters
----------
......@@ -2897,7 +2944,7 @@ class Tensor(object):
return self.reshape(shape)
def view_(self, *shape):
"""Change into a new shape with the same data.
"""Change into a new size with the same data.
Parameters
----------
......@@ -2917,8 +2964,7 @@ class Tensor(object):
return self.reshape_(shape)
def view_as(self, other):
"""Return a new tensor with the same data
but a different size as the given tensor.
"""Return a tensor with the same data but a different size.
Parameters
----------
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!