Commit 9f583556 by Ting PAN

Rename Triangular operator

Summary:
This commit renames the Triangular operator to Trilu, following the ONNX operator of the same name.
1 parent b7e959e9
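
The user-facing API is unchanged by this rename: dragon.tril and dragon.triu simply dispatch to the 'Trilu' operator with upper=False and upper=True respectively, as the Python hunks below show. A minimal usage sketch (assuming dragon.constant builds a tensor from a NumPy array; that helper is not part of this commit):

import numpy as np
import dragon

data = np.arange(1, 10, dtype='float32').reshape(3, 3)
x = dragon.constant(data)
# Both calls now lower to the 'Trilu' operator under the hood.
y_lower = dragon.tril(x, k=0)   # matches np.tril(data, 0)
y_upper = dragon.triu(x, k=1)   # matches np.triu(data, 1)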
......@@ -183,6 +183,7 @@ Name Supported Reference
`Tile`_ |v| :func:`dragon.tile`
`TopK`_ |v| :func:`dragon.math.top_k`
`Transpose`_ |v| :func:`dragon.transpose`
`Trilu`_ |v| :func:`dragon.tril`
`Unique`_ |v| :func:`dragon.unique`
`Unsqueeze`_ |v| :func:`dragon.expand_dims`
`Upsample`_ |v| :func:`dragon.vision.resize`
......@@ -350,6 +351,7 @@ Name Supported Reference
.. _Tile: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Tile
.. _TopK: https://github.com/onnx/onnx/blob/master/docs/Operators.md#TopK
.. _Transpose: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Transpose
.. _Trilu: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Trilu
.. _Unique: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Unique
.. _Unsqueeze: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Unsqueeze
.. _Upsample: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Upsample
......
......@@ -8,7 +8,7 @@ namespace kernels {
namespace {
template <typename T>
void _SetTriangular(
void _SetTrilu(
const int batch_size,
const int M,
const int N,
......@@ -42,19 +42,19 @@ void _SetTriangular(
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void SetTriangular<T, CPUContext>( \
const int batch_size, \
const int M, \
const int N, \
const int k, \
const int upper, \
const T* x, \
T* y, \
CPUContext* ctx) { \
math::Copy(batch_size* M* N, x, y, ctx); \
_SetTriangular(batch_size, M, N, k, upper, y); \
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void SetTrilu<T, CPUContext>( \
const int batch_size, \
const int M, \
const int N, \
const int k, \
const int upper, \
const T* x, \
T* y, \
CPUContext* ctx) { \
math::Copy(batch_size* M* N, x, y, ctx); \
_SetTrilu(batch_size, M, N, k, upper, y); \
}
DEFINE_KERNEL_LAUNCHER(bool);
......
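
For reference, the masking rule that _SetTrilu applies per M x N matrix can be sketched in NumPy terms (an illustration of the element-wise condition, not the kernel itself): entry (i, j) is kept when j <= i + k for the lower variant and when j >= i + k for the upper variant, and zeroed otherwise.

import numpy as np

def set_trilu_reference(x, k=0, upper=False):
    # x: (batch_size, M, N); mirrors the per-element condition in _SetTrilu.
    i = np.arange(x.shape[-2])[:, None]
    j = np.arange(x.shape[-1])[None, :]
    mask = (j >= i + k) if upper else (j <= i + k)
    return np.where(mask, x, 0)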
......@@ -11,12 +11,8 @@ namespace kernels {
namespace {
template <typename T, bool kUpper>
__global__ void _SetTriangular(
const int nthreads,
const int M,
const int N,
const int k,
T* y) {
__global__ void
_SetTrilu(const int nthreads, const int M, const int N, const int k, T* y) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int j = index % N;
const int i = (index / N) % M;
......@@ -34,7 +30,7 @@ __global__ void _SetTriangular(
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void SetTriangular<T, CUDAContext>( \
void SetTrilu<T, CUDAContext>( \
const int batch_size, \
const int M, \
const int N, \
......@@ -46,11 +42,11 @@ __global__ void _SetTriangular(
const auto nthreads = batch_size * M * N; \
math::Copy(nthreads, x, y, ctx); \
if (upper > 0) { \
_SetTriangular<T, true> \
_SetTrilu<T, true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, M, N, k, y); \
} else { \
_SetTriangular<T, false> \
_SetTrilu<T, false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, M, N, k, y); \
} \
......
......@@ -15,11 +15,6 @@ void SigmoidOp<Context>::DoRunWithType() {
}
template <class Context>
void SigmoidOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void SigmoidGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
......@@ -31,11 +26,6 @@ void SigmoidGradientOp<Context>::DoRunWithType() {
ctx());
}
template <class Context>
void SigmoidGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Sigmoid);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Sigmoid);
......
......@@ -18,24 +18,28 @@
namespace dragon {
template <class Context>
class SigmoidOp : public Operator<Context> {
class SigmoidOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(SigmoidOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
};
template <class Context>
class SigmoidGradientOp : public Operator<Context> {
class SigmoidGradientOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(SigmoidGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -44,10 +48,10 @@ class SigmoidGradientOp : public Operator<Context> {
#ifdef USE_CUDNN
template <class Context>
class CuDNNSigmoidOp final : public SigmoidOp<Context> {
class CuDNNSigmoidOp final : public Operator<Context> {
public:
CuDNNSigmoidOp(const OperatorDef& def, Workspace* ws)
: SigmoidOp<Context>(def, ws) {
: Operator<Context>(def, ws) {
CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnSetActivationDescriptor(
......@@ -60,7 +64,9 @@ class CuDNNSigmoidOp final : public SigmoidOp<Context> {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -71,10 +77,10 @@ class CuDNNSigmoidOp final : public SigmoidOp<Context> {
};
template <class Context>
class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
class CuDNNSigmoidGradientOp final : public Operator<Context> {
public:
CuDNNSigmoidGradientOp(const OperatorDef& def, Workspace* ws)
: SigmoidGradientOp<Context>(def, ws) {
: Operator<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc_));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnSetActivationDescriptor(
......@@ -87,7 +93,9 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......
......@@ -21,11 +21,6 @@ void CuDNNSigmoidOp<Context>::DoRunWithType() {
}
template <class Context>
void CuDNNSigmoidOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void CuDNNSigmoidGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
......@@ -45,11 +40,6 @@ void CuDNNSigmoidGradientOp<Context>::DoRunWithType() {
dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
}
template <class Context>
void CuDNNSigmoidGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CUDNN_OPERATOR(Sigmoid);
DEPLOY_CUDNN_OPERATOR(SigmoidGradient);
......
......@@ -15,11 +15,6 @@ void TanhOp<Context>::DoRunWithType() {
}
template <class Context>
void TanhOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void TanhGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
......@@ -31,11 +26,6 @@ void TanhGradientOp<Context>::DoRunWithType() {
ctx());
}
template <class Context>
void TanhGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Tanh);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Tanh);
......
......@@ -18,24 +18,28 @@
namespace dragon {
template <class Context>
class TanhOp : public Operator<Context> {
class TanhOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(TanhOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
};
template <class Context>
class TanhGradientOp : public Operator<Context> {
class TanhGradientOp final : public Operator<Context> {
public:
SIMPLE_CTOR_DTOR(TanhGradientOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -44,10 +48,10 @@ class TanhGradientOp : public Operator<Context> {
#ifdef USE_CUDNN
template <class Context>
class CuDNNTanhOp final : public TanhOp<Context> {
class CuDNNTanhOp final : public Operator<Context> {
public:
CuDNNTanhOp(const OperatorDef& def, Workspace* ws)
: TanhOp<Context>(def, ws) {
: Operator<Context>(def, ws) {
CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnSetActivationDescriptor(
......@@ -60,7 +64,9 @@ class CuDNNTanhOp final : public TanhOp<Context> {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......@@ -71,10 +77,10 @@ class CuDNNTanhOp final : public TanhOp<Context> {
};
template <class Context>
class CuDNNTanhGradientOp final : public TanhGradientOp<Context> {
class CuDNNTanhGradientOp final : public Operator<Context> {
public:
CuDNNTanhGradientOp(const OperatorDef& def, Workspace* ws)
: TanhGradientOp<Context>(def, ws) {
: Operator<Context>(def, ws) {
CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnSetActivationDescriptor(
......@@ -87,7 +93,9 @@ class CuDNNTanhGradientOp final : public TanhGradientOp<Context> {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
void RunOnDevice() override;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
......
......@@ -21,11 +21,6 @@ void CuDNNTanhOp<Context>::DoRunWithType() {
}
template <class Context>
void CuDNNTanhOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void CuDNNTanhGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
......@@ -45,11 +40,6 @@ void CuDNNTanhGradientOp<Context>::DoRunWithType() {
dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
}
template <class Context>
void CuDNNTanhGradientOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CUDNN_OPERATOR(Tanh);
DEPLOY_CUDNN_OPERATOR(TanhGradient);
......
#include "dragon/operators/array/triangular_op.h"
#include "dragon/operators/array/trilu_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void TriangularOp<Context>::DoRunWithType() {
void TriluOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
kernels::SetTriangular(
kernels::SetTrilu(
X.count(0, X.ndim() - 2),
X.dim(-2),
X.dim(-1),
......@@ -18,16 +18,18 @@ void TriangularOp<Context>::DoRunWithType() {
}
template <class Context>
void TriangularOp<Context>::RunOnDevice() {
void TriluOp<Context>::RunOnDevice() {
DispatchHelper<dtypes::Generic>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Triangular);
DEPLOY_CPU_OPERATOR(Trilu);
REGISTER_CPU_OPERATOR(TriluGradient, TriluOp<CPUContext>);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Triangular);
DEPLOY_CUDA_OPERATOR(Trilu);
REGISTER_CUDA_OPERATOR(TriluGradient, TriluOp<CUDAContext>);
#endif
OPERATOR_SCHEMA(Triangular)
OPERATOR_SCHEMA(Trilu)
/* X */
.NumInputs(1)
/* Y */
......@@ -35,6 +37,14 @@ OPERATOR_SCHEMA(Triangular)
/* X -> Y */
.AllowInplace({{0, 0}});
NO_GRADIENT(Triangular);
OPERATOR_SCHEMA(TriluGradient)
/* dY */
.NumInputs(1)
/* dX */
.NumOutputs(1)
/* dY -> dX */
.AllowInplace({{0, 0}});
REGISTER_GRADIENT(Trilu, SimpleGradientMaker);
} // namespace dragon
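
Registering TriluGradient as the same TriluOp (via SimpleGradientMaker) works because tril/triu only zero entries, i.e. they act as a fixed element-wise mask, so the gradient with respect to the input is that same mask applied to the upstream gradient. A tiny NumPy illustration of the identity the new tests rely on:

import numpy as np

# y = tril(x, k) is element-wise masking, so dL/dx = tril(dL/dy, k).
dy = np.arange(9, dtype='float32').reshape(3, 3)
dx = np.tril(dy, k=0)  # what TriluOp computes when run as TriluGradient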
......@@ -10,17 +10,17 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_TRIANGULAR_OP_H_
#define DRAGON_OPERATORS_ARRAY_TRIANGULAR_OP_H_
#ifndef DRAGON_OPERATORS_ARRAY_TRILU_OP_H_
#define DRAGON_OPERATORS_ARRAY_TRILU_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class TriangularOp final : public Operator<Context> {
class TriluOp final : public Operator<Context> {
public:
TriangularOp(const OperatorDef& def, Workspace* ws)
TriluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
k_(OP_SINGLE_ARG(int64_t, "k", 0)),
upper_(OP_SINGLE_ARG(int64_t, "upper", 0)) {}
......@@ -37,4 +37,4 @@ class TriangularOp final : public Operator<Context> {
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_TRIANGULAR_OP_H_
#endif // DRAGON_OPERATORS_ARRAY_TRILU_OP_H_
......@@ -609,7 +609,7 @@ def transpose_args(**kwargs):
return {'perm_desc': 'int64' if kwargs.get('ndim', 0) else None}
@register('Triangular')
@register('Trilu')
def triangular_args(**kwargs):
return {'k': kwargs.get('k', 0), 'upper': kwargs.get('upper', False)}
......
......@@ -1825,9 +1825,9 @@ def tril(inputs, k=0, copy=True, **kwargs):
"""
if context.executing_eagerly():
return OpLib.execute(
'Triangular', inputs, outputs=[None] if copy else inputs,
'Trilu', inputs, outputs=[None] if copy else inputs,
k=k, upper=False)
return OpLib.add('Triangular', inputs, k=k, upper=False, **kwargs)
return OpLib.add('Trilu', inputs, k=k, upper=False, **kwargs)
@OpSchema.num_inputs(1)
......@@ -1865,9 +1865,9 @@ def triu(inputs, k=0, copy=True, **kwargs):
"""
if context.executing_eagerly():
return OpLib.execute(
'Triangular', inputs, outputs=[None] if copy else inputs,
'Trilu', inputs, outputs=[None] if copy else inputs,
k=k, upper=True)
return OpLib.add('Triangular', inputs, k=k, upper=True, **kwargs)
return OpLib.add('Trilu', inputs, k=k, upper=True, **kwargs)
@OpSchema.num_inputs(1)
......
......@@ -77,8 +77,17 @@ def softmax_exporter(op_def, context):
if arg.name == 'axis':
axis = arg.i + (ndim if arg.i < 0 else 0)
if axis != (ndim - 1):
raise ValueError(
'Softmax axis could only be the last one.\n'
'Use Exp(LogSoftmax) to compute the softmax instead.')
raise ValueError('Axis could only be the last if opset < 13.')
helper.add_attribute(node, 'axis', arg.i)
return node, const_tensors
@export_util.register('Softmax-13')
def softmax_exporter_v13(op_def, context):
node, const_tensors = export_util.translate(**locals())
ndim = len(context.blob_shapes[op_def.input[0]])
for arg in op_def.arg:
if arg.name == 'axis':
axis = arg.i + (ndim if arg.i < 0 else 0)
helper.add_attribute(node, 'axis', arg.i)
return node, const_tensors
......@@ -529,6 +529,23 @@ def top_k_exporter_v11(op_def, context):
return node, [k]
@export_util.register('Trilu')
def trilu_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
k = 0
for arg in op_def.arg:
if arg.name == 'upper':
helper.add_attribute(node, 'upper', arg.i)
elif arg.name == 'k':
k = arg.i
k = helper.from_array(
numpy.array(k, 'int64'),
context.unique_name(op_def.input[0] + '/trilu/k'),
)
node.input.extend([k.name])
return node, [k]
@export_util.register('Unique')
def unique_exporter(op_def, context):
node, const_tensors = export_util.translate(**locals())
......
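
For context, the trilu_exporter above targets the ONNX Trilu definition (opset 14), where k is supplied as a second int64 input tensor rather than an attribute, while upper stays an attribute. A rough sketch of the node it emits, built with onnx.helper directly (outside this codebase):

import numpy as np
from onnx import helper, numpy_helper

k_init = numpy_helper.from_array(np.array(2, dtype='int64'), name='X/trilu/k')
node = helper.make_node('Trilu', inputs=['X', k_init.name], outputs=['Y'], upper=1)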
......@@ -47,6 +47,9 @@ class DragonFrontend(object):
(9, '1.4.1'),
(10, '1.5.0'),
(11, '1.6.0'),
(12, '1.7.0'),
(13, '1.8.0'),
(14, '1.9.0'),
])
@classmethod
......
......@@ -419,7 +419,7 @@ void SetOneHot(
Context* ctx);
template <typename T, class Context>
void SetTriangular(
void SetTrilu(
const int batch_size,
const int M,
const int N,
......
......@@ -1129,8 +1129,11 @@ class TestArrayOps(OpTestCase):
data = arange(shape, 1)
x = new_tensor(data)
for k in range(-max(shape), max(shape) + 1):
y = dragon.tril(x, k=k)
self.assertEqual(y, np.tril(data, k))
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.tril(x, k=k)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
self.assertEqual([y, dx], [np.tril(data, k), np.tril(data, k)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_tril_cuda(self):
......@@ -1145,8 +1148,11 @@ class TestArrayOps(OpTestCase):
data = arange(shape, 1)
x = new_tensor(data)
for k in range(-max(shape), max(shape) + 1):
y = dragon.triu(x, k=k)
self.assertEqual(y, np.triu(data, k))
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.triu(x, k=k)
dx = tape.gradient(y, [x], output_gradients=[x])[0]
self.assertEqual([y, dx], [np.triu(data, k), np.triu(data, k)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_triu_cuda(self):
......
......@@ -1402,7 +1402,7 @@ def tril(input, diagonal=0, out=None):
"""
return FunctionLib.apply(
'Triangular', input.device, [input], outputs=[out],
'Trilu', input.device, [input], outputs=[out],
k=diagonal, upper=False)
......@@ -1439,7 +1439,7 @@ def triu(input, diagonal=0, out=None):
"""
return FunctionLib.apply(
'Triangular', input.device, [input], outputs=[out],
'Trilu', input.device, [input], outputs=[out],
k=diagonal, upper=True)
......
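
The last hunk updates the PyTorch-style frontend to the new operator name as well; usage is unchanged. A sketch under the assumption that the binding is exposed under dragon.vm.torch with the PyTorch-like signature shown in the hunk:

from dragon.vm import torch

x = torch.ones(3, 3)
y = torch.tril(x, diagonal=-1)  # still lowers to the 'Trilu' operator with upper=False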