Commit cca00c0d by Ting PAN

Normalize the getter of operator argument

Summary:
This commit renames the operator argument getter to ``GetArgument``
whatever an argument is single or repeated.
1 parent 58708021
Showing with 774 additions and 524 deletions
......@@ -11,14 +11,6 @@ Constructors
Public Functions
----------------
Arg
###
.. doxygenfunction:: dragon::Operator::Arg
Args
####
.. doxygenfunction:: dragon::Operator::Args
Buffer
######
.. doxygenfunction:: dragon::Operator::Buffer
......@@ -27,6 +19,14 @@ Fuse
####
.. doxygenfunction:: dragon::Operator::Fuse
GetArgument
###########
.. doxygenfunction:: dragon::Operator::GetArgument(const string &name)
GetArgument
###########
.. doxygenfunction:: dragon::Operator::GetArgument(const string &name, const T &default_value)
Input
#####
.. doxygenfunction:: dragon::Operator::Input
......
......@@ -21,9 +21,6 @@ dragon
Functions
---------
`arange(...) <dragon/arange.html>`_
: Return a tensor of evenly spaced values within a interval.
`assign(...) <dragon/assign.html>`_
: Assign the value to input.
......@@ -120,6 +117,9 @@ dragon
`python_plugin(...) <dragon/python_plugin.html>`_
: Create a plugin operator from the python class.
`range(...) <dragon/range.html>`_
: Return a tensor of evenly spaced values within a interval.
`repeat(...) <dragon/repeat.html>`_
: Repeat the elements along the given axis.
......@@ -165,7 +165,6 @@ dragon
.. toctree::
:hidden:
dragon/arange
dragon/assign
dragon/broadcast_to
dragon/cast
......@@ -200,6 +199,7 @@ dragon
dragon/one_hot
dragon/pad
dragon/python_plugin
dragon/range
dragon/repeat
dragon/reset_workspace
dragon/reshape
......
......@@ -21,6 +21,9 @@ dragon.random
`normal_like(...) <random/normal_like.html>`_
: Return a tensor initialized from the normal distribution with shape as the other.
`permutation(...) <random/permutation.html>`_
: Return a tensor with value in the permuted range.
`set_seed(...) <random/set_seed.html>`_
: Set the global random seed.
......@@ -41,6 +44,7 @@ dragon.random
random/multinomial
random/normal
random/normal_like
random/permutation
random/set_seed
random/truncated_normal
random/uniform
......
permutation
===========
.. autofunction:: dragon.random.permutation
.. raw:: html
<style>
h1:before {
content: "dragon.random.";
color: #103d3e;
}
</style>
arange
======
range
=====
.. autofunction:: dragon.arange
.. autofunction:: dragon.range
.. raw:: html
......
......@@ -94,6 +94,9 @@ vm.torch
`eye(...) <torch/eye.html>`_
: Return a tensor constructed as the identity matrix.
`flatten(...) <torch/flatten.html>`_
: Return a tensor with dimensions flattened.
`floor(...) <torch/floor.html>`_
: Compute the largest integer not greater than input.
......@@ -184,6 +187,9 @@ vm.torch
`randn(...) <torch/randn.html>`_
: Return a tensor from the normal distribution of N(0, 1).
`randperm(...) <torch/randperm.html>`_
: Return a tensor with value in the permuted range.
`reciprocal(...) <torch/reciprocal.html>`_
: Compute the reciprocal of input.
......@@ -268,6 +274,7 @@ vm.torch
torch/eq
torch/exp
torch/eye
torch/flatten
torch/floor
torch/from_numpy
torch/ge
......@@ -299,6 +306,7 @@ vm.torch
torch/pow
torch/rand
torch/randn
torch/randperm
torch/reciprocal
torch/repeat
torch/reshape
......
......@@ -189,6 +189,14 @@ fill\_
#######
.. automethod:: dragon.vm.torch.Tensor.fill_
flatten
#######
.. automethod:: dragon.vm.torch.Tensor.flatten
flatten\_
#########
.. automethod:: dragon.vm.torch.Tensor.flatten_
float
#####
.. automethod:: dragon.vm.torch.Tensor.float
......@@ -470,6 +478,7 @@ zero\_
.. _torch.div(...): div.html
.. _torch.eq(...): eq.html
.. _torch.exp(...): exp.html
.. _torch.flatten(...): flatten.html
.. _torch.floor(...): floor.html
.. _torch.ge(...): ge.html
.. _torch.gt(...): gt.html
......
flatten
=======
.. autofunction:: dragon.vm.torch.flatten
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
Flatten
=======
.. autoclass:: dragon.vm.torch.nn.Flatten
__init__
--------
.. automethod:: dragon.vm.torch.nn.Flatten.__init__
.. _torch.flatten(...): ../flatten.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
randperm
========
.. autofunction:: dragon.vm.torch.randperm
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
......@@ -255,40 +255,50 @@ DEFINE_REGISTRY(
/* Macros */
#define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname) \
template <> \
DRAGON_API T OperatorBase::Arg(const string& name, const T& default_value) { \
if (args_.count(name) == 0) { \
return default_value; \
} \
CHECK(args_[name]->has_##fieldname()); \
return static_cast<T>(args_[name]->fieldname()); \
#define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname, default) \
template <> \
DRAGON_API T OperatorBase::GetArgument( \
const string& name, const T& default_value) { \
if (args_.count(name) == 0) return default_value; \
CHECK(args_[name]->has_##fieldname()); \
return static_cast<T>(args_[name]->fieldname()); \
} \
template <> \
DRAGON_API T OperatorBase::GetArgument(const string& name) { \
return OperatorBase::GetArgument<T>(name, default); \
}
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f)
INSTANTIATE_GET_SINGLE_ARGUMENT(double, f)
INSTANTIATE_GET_SINGLE_ARGUMENT(int, i)
INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i)
INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i)
INSTANTIATE_GET_SINGLE_ARGUMENT(string, s)
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, 0.f)
INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, 0.);
INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, 0);
INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false);
INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, int64_t(0));
INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, "");
#undef INSTANTIATE_GET_SINGLE_ARGUMENT
#define INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname) \
template <> \
vector<T> DRAGON_API OperatorBase::Args<T>(const string& name) { \
if (args_.count(name) == 0) return vector<T>(); \
vector<T> values; \
for (const auto& v : args_[name]->fieldname()) \
values.push_back(static_cast<T>(v)); \
return values; \
#define INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname) \
template <> \
vector<T> DRAGON_API OperatorBase::GetArgument<vector<T>>( \
const string& name, const vector<T>& default_value) { \
if (args_.count(name) == 0) return default_value; \
vector<T> values; \
for (const auto& v : args_[name]->fieldname()) { \
values.push_back(static_cast<T>(v)); \
} \
return values; \
} \
template <> \
vector<T> DRAGON_API OperatorBase::GetArgument<vector<T>>( \
const string& name) { \
return OperatorBase::GetArgument<vector<T>>(name, vector<T>()); \
}
INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats)
INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats)
INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints)
INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints)
INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints)
INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings)
INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats);
INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats);
INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints);
INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints);
INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints);
INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings);
#undef INSTANTIATE_GET_REPEATED_ARGUMENT
template class Operator<CPUContext>;
......
......@@ -74,13 +74,13 @@ class DRAGON_API OperatorBase {
return (int)outputs_.size();
}
/*! \brief Return the value of single argument */
/*! \brief Return the value of argument */
template <typename T>
T Arg(const string& name, const T& default_value);
T GetArgument(const string& name);
/*! \brief Return the value of repeated argument */
/*! \brief Return the value of argument with default */
template <typename T>
vector<T> Args(const string& name);
T GetArgument(const string& name, const T& default_value);
/*! \brief Return the message for supported value */
string MessageForUnsupported(
......@@ -199,7 +199,7 @@ class DRAGON_API Operator : public OperatorBase {
Operator(const OperatorDef& def, Workspace* ws)
: OperatorBase(def, ws),
ctx_(def.device_option()),
do_sync_(OperatorBase::Arg<bool>("do_sync", false)) {}
do_sync_(OperatorBase::GetArgument<bool>("do_sync", false)) {}
/*! \brief Prepare the content of inputs */
virtual void Prepare();
......@@ -279,19 +279,20 @@ OperatorBase* NewOperator(const OperatorDef&, Workspace*);
/* Dispatchers */
#define XIsType(X, type) X.template IsType<type>()
template <typename... Types>
struct TensorTypes {};
using IntegralTensorTypes = TensorTypes<bool, int8_t, uint8_t, int, int64_t>;
using IntegralTensorTypes = TensorTypes<int8_t, uint8_t, int, int64_t>;
using FloatingTensorTypes = TensorTypes<float16, float, double>;
using MathTensorTypes =
using NumericalTensorTypes =
TensorTypes<int8_t, uint8_t, int, int64_t, float16, float, double>;
using AllTensorTypes =
using BooleanIntegralTensorTypes =
TensorTypes<bool, int8_t, uint8_t, int, int64_t, bool>;
using FullTensorTypes =
TensorTypes<bool, int8_t, uint8_t, int, int64_t, float16, float, double>;
template <typename Sizes, typename... Args>
......@@ -382,30 +383,33 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType);
/* Arguments */
#define OpArg OperatorBase::Arg
#define OpArgs OperatorBase::Args
#define OP_SINGLE_ARG(type, name, default) \
OperatorBase::GetArgument<type>(name, (default))
#define OP_REPEATED_ARG(type, name) \
OperatorBase::GetArgument<vector<type>>(name)
#define DECLARE_ARG_WITH_DESC(type, arg) \
type arg##_; \
string arg##_desc_; \
#define DECLARE_OP_SINGLE_ARG_WITH_DESC(type, arg) \
type arg##_; \
string arg##_desc_; \
type arg()
#define DECLARE_ARGS_WITH_DESC(type, arg) \
string arg##_desc_; \
vector<type> arg##_; \
vector<string> arg##_descs_; \
#define DECLARE_OP_REPEATED_ARG_WITH_DESC(type, arg) \
string arg##_desc_; \
vector<type> arg##_; \
vector<string> arg##_descs_; \
type arg(int i, int* num = nullptr)
#define GET_ARG_WITH_DESC(type, arg, default_value) \
arg##_ = OpArg<type>(#arg, default_value); \
arg##_desc_ = OpArg<string>(string(#arg) + "_desc", "")
#define INIT_OP_SINGLE_ARG_WITH_DESC(type, arg, default_value) \
arg##_ = OP_SINGLE_ARG(type, #arg, default_value); \
arg##_desc_ = OP_SINGLE_ARG(string, string(#arg) + "_desc", "")
#define GET_ARGS_WITH_DESC(type, arg) \
arg##_ = OpArgs<type>(#arg); \
arg##_desc_ = OpArg<string>(string(#arg) + "_desc", ""); \
arg##_descs_ = OpArgs<string>(string(#arg) + "_descs")
#define INIT_OP_REPEATED_ARG_WITH_DESC(type, arg) \
arg##_ = OP_REPEATED_ARG(type, #arg); \
arg##_desc_ = OP_SINGLE_ARG(string, string(#arg) + "_desc", ""); \
arg##_descs_ = OP_REPEATED_ARG(string, string(#arg) + "_descs")
#define DEFINE_ARG_WITH_DESC(type, classname, arg) \
#define DEFINE_OP_SINGLE_ARG_WITH_DESC(type, classname, arg) \
template <class Context> \
type classname<Context>::arg() { \
if (arg##_desc_.empty()) return arg##_; \
......@@ -419,7 +423,7 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType);
return arg##_tensor->template data<type, CPUContext>()[0]; \
}
#define DEFINE_ARGS_WITH_DESC(type, classname, arg) \
#define DEFINE_OP_REPEATED_ARG_WITH_DESC(type, classname, arg) \
template <class Context> \
type classname<Context>::arg(int i, int* num) { \
const type* data; \
......@@ -451,13 +455,13 @@ DEFINE_TENSOR_TYPES_DISPATCHER(DoRunWithType);
}
#define CANONICALIZE_AXIS_WITH_TENSOR_AND_OFFSET(tensor, offset) \
auto axis = OpArg<int64_t>("axis", INT_MAX); \
auto axis = OP_SINGLE_ARG(int64_t, "axis", INT_MAX); \
if (axis != INT_MAX) { \
axis = axis < 0 ? axis + tensor.ndim() + offset : axis; \
CHECK(axis >= 0 && axis < tensor.ndim() + offset) \
<< "\nExcepted the axis in [-" << tensor.ndim() + offset << ", " \
<< tensor.ndim() + offset << "), got " \
<< OpArg<int64_t>("axis", INT_MAX) << "."; \
<< OP_SINGLE_ARG(int64_t, "axis", INT_MAX) << "."; \
}
#define CANONICALIZE_AXIS_WITH_TENSOR(tensor) \
......@@ -509,24 +513,24 @@ DECLARE_REGISTRY(
#define REGISTER_CNML_OPERATOR(name, ...) \
REGISTER_CLASS(CNMLOperatorRegistry, name, __VA_ARGS__)
#define DEPLOY_CPU(name) \
#define DEPLOY_CPU_OPERATOR(name) \
REGISTER_CPU_OPERATOR(name, name##Op<CPUContext>); \
INSTANTIATE_OPERATOR(name, CPUContext);
#define DEPLOY_CUDA(name) \
#define DEPLOY_CUDA_OPERATOR(name) \
REGISTER_CUDA_OPERATOR(name, name##Op<CUDAContext>); \
INSTANTIATE_OPERATOR(name, CUDAContext);
#define DEPLOY_CPU_CUDA(name) \
#define DEPLOY_CPU_CUDA_OPERATOR(name) \
REGISTER_CPU_OPERATOR(name, name##Op<CPUContext>); \
REGISTER_CUDA_OPERATOR(name, name##Op<CPUContext>); \
INSTANTIATE_OPERATOR(name, CPUContext);
#define DEPLOY_CUDNN(name) \
#define DEPLOY_CUDNN_OPERATOR(name) \
REGISTER_CUDNN_OPERATOR(name, CuDNN##name##Op<CUDAContext>); \
INSTANTIATE_CUDNN_OPERATOR(name);
#define DEPLOY_CNML(name) \
#define DEPLOY_CNML_OPERATOR(name) \
REGISTER_CNML_OPERATOR(name, CnML##name##Op<CNMLContext>); \
INSTANTIATE_CNML_OPERATOR(name);
......
......@@ -15,7 +15,7 @@ void _DropBlock2dNCHW(
const int seed_h,
const int seed_w,
const int block_size,
const uint32_t* seed,
const uint32_t* r,
int* mask) {
const int HW = H * W;
const int CHW = C * HW;
......@@ -24,7 +24,7 @@ void _DropBlock2dNCHW(
std::array<int, 3> dims = {N, seed_h, seed_w};
int offset;
for (int i = 0; i < count; ++i) {
if (seed[i] > 0) {
if (r[i] > 0) {
offset = idx[0] * CHW + idx[1] * W + idx[2];
for (int c = 0; c < C; ++c) {
for (int bh = 0; bh < block_size; ++bh) {
......@@ -84,15 +84,15 @@ void DropBlock2d<CPUContext>(
const int block_size,
const float gamma,
const string& data_format,
uint32_t* seed,
uint32_t* r,
int* mask,
CPUContext* ctx) {
const int count = N * seed_h * seed_w;
math::RandomBernoulli(count, gamma, seed, ctx);
math::RandomBernoulli(count, gamma, r, ctx);
if (data_format == "NCHW") {
_DropBlock2dNCHW(N, C, H, W, seed_h, seed_w, block_size, seed, mask);
_DropBlock2dNCHW(N, C, H, W, seed_h, seed_w, block_size, r, mask);
} else if (data_format == "NHWC") {
_DropBlock2dNHWC(N, C, H, W, seed_h, seed_w, block_size, seed, mask);
_DropBlock2dNHWC(N, C, H, W, seed_h, seed_w, block_size, r, mask);
} else {
LOG(FATAL) << "Unknown DataFormat: " << data_format;
}
......
......@@ -19,10 +19,10 @@ __global__ void _DropBlock2dNCHW(
const int seed_w,
const int block_size,
const uint32_t thresh,
const uint32_t* seed,
const uint32_t* r,
int* mask) {
CUDA_1D_KERNEL_LOOP(idx, nthreads) {
if (seed[idx] < thresh) {
if (r[idx] < thresh) {
const int wstart = idx % seed_w;
const int hstart = (idx / seed_w) % seed_h;
const int n = idx / seed_w / seed_h;
......@@ -47,10 +47,10 @@ __global__ void _DropBlock2dNHWC(
const int seed_w,
const int block_size,
const uint32_t thresh,
const uint32_t* seed,
const uint32_t* r,
int* mask) {
CUDA_1D_KERNEL_LOOP(idx, nthreads) {
if (seed[idx] < thresh) {
if (r[idx] < thresh) {
const int wstart = idx % seed_w;
const int hstart = (idx / seed_w) % seed_h;
const int n = idx / seed_w / seed_h;
......@@ -81,11 +81,11 @@ void DropBlock2d<CUDAContext>(
const int block_size,
const float gamma,
const string& data_format,
uint32_t* seed,
uint32_t* r,
int* mask,
CUDAContext* ctx) {
auto nthreads = N * seed_h * seed_w;
math::RandomUniform(nthreads, 0.f, 1.f, seed, ctx);
math::Random(nthreads, r, ctx);
auto mask_thresh = (uint32_t)(UINT_MAX * gamma);
if (data_format == "NCHW") {
_DropBlock2dNCHW<<<
......@@ -93,14 +93,14 @@ void DropBlock2d<CUDAContext>(
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
nthreads, C, H, W, seed_h, seed_w, block_size, mask_thresh, seed, mask);
nthreads, C, H, W, seed_h, seed_w, block_size, mask_thresh, r, mask);
} else if (data_format == "NHWC") {
_DropBlock2dNHWC<<<
CUDA_BLOCKS(nthreads),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
nthreads, C, H, W, seed_h, seed_w, block_size, mask_thresh, seed, mask);
nthreads, C, H, W, seed_h, seed_w, block_size, mask_thresh, r, mask);
} else {
LOG(FATAL) << "Unknown DataFormat: " << data_format;
}
......
......@@ -82,7 +82,7 @@ void _Dropout<float16>(
const T* x, \
uint8_t* mask, \
T* y, \
uint32_t* scratch, \
uint32_t* r, \
CPUContext* ctx) { \
_Dropout(count, cast::to<T>(prob), cast::to<T>(scale), x, mask, y, ctx); \
}
......
......@@ -97,15 +97,15 @@ void Dropout<float16, CUDAContext>(
const float16* x,
uint8_t* mask,
float16* y,
uint32_t* scratch,
uint32_t* r,
CUDAContext* ctx) {
math::RandomUniform(count, 0.f, 1.f, scratch, ctx);
math::Random(count, r, ctx);
_Dropout<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count,
static_cast<uint32_t>(UINT_MAX * prob),
cast::to<half>(scale),
reinterpret_cast<const half*>(x),
scratch,
r,
mask,
reinterpret_cast<half*>(y));
}
......@@ -130,12 +130,12 @@ void Dropout<float16, CUDAContext>(
const T* x, \
uint8_t* mask, \
T* y, \
uint32_t* scratch, \
uint32_t* r, \
CUDAContext* ctx) { \
math::RandomUniform(count, 0.f, 1.f, scratch, ctx); \
math::Random(count, r, ctx); \
auto threshold = static_cast<uint32_t>(UINT_MAX * prob); \
_Dropout<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, threshold, cast::to<T>(scale), x, scratch, mask, y); \
count, threshold, cast::to<T>(scale), x, r, mask, y); \
}
DEFINE_KERNEL_LAUNCHER(float);
......
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _SwapByKey(const int count, const uint32_t* r, T* y) {
for (int i = 0; i < count; ++i) {
std::swap(y[i], y[i + (r[i] % (count - i))]);
}
}
} // namespace
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Permutation<T, CPUContext>( \
const int count, T* y, uint32_t* r, CPUContext* ctx) { \
math::Random(count, r, ctx); \
kernel::Range(count, 0.f, 1.f, y, ctx); \
_SwapByKey(count, r, y); \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/device/common_thrust.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
__global__ void _Sequence(const int nthreads, half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = __float2half(float(i));
}
}
} // namespace
template <>
void Permutation<float16, CUDAContext>(
const int count,
float16* y,
uint32_t* r,
CUDAContext* ctx) {
math::Random(count, r, ctx);
auto values = thrust::device_ptr<half>(reinterpret_cast<half*>(y));
auto keys = thrust::device_ptr<uint32_t>(r);
auto policy = thrust::cuda::par.on(ctx->cuda_stream());
_Sequence<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count, reinterpret_cast<half*>(y));
thrust::sort_by_key(policy, keys, keys + count, values);
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Permutation<T, CUDAContext>( \
const int count, T* y, uint32_t* r, CUDAContext* ctx) { \
math::Random(count, r, ctx); \
auto values = thrust::device_ptr<T>(y); \
auto keys = thrust::device_ptr<uint32_t>(r); \
auto policy = thrust::cuda::par.on(ctx->cuda_stream()); \
thrust::sequence(policy, values, values + count); \
thrust::sort_by_key(policy, keys, keys + count, values); \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
......@@ -9,12 +9,12 @@ namespace kernel {
namespace {
template <typename T>
void _Arange(const int count, const float start, const float step, T* y) {
void _Range(const int count, const float start, const float delta, T* y) {
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
y[i] = static_cast<T>(start + i * step);
y[i] = static_cast<T>(start + i * delta);
}
}
......@@ -23,29 +23,29 @@ void _Arange(const int count, const float start, const float step, T* y) {
/* ------------------- Launcher Separator ------------------- */
template <>
void Arange<float16, CPUContext>(
void Range<float16, CPUContext>(
const int count,
const float start,
const float step,
const float delta,
float16* y,
CPUContext* ctx) {
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
y[i] = cast::to<float16>(start + (float)i * step);
y[i] = cast::to<float16>(start + (float)i * delta);
}
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Arange<T, CPUContext>( \
void Range<T, CPUContext>( \
const int count, \
const float start, \
const float step, \
const float delta, \
T* y, \
CPUContext* ctx) { \
_Arange(count, start, step, y); \
_Range(count, start, delta, y); \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
......@@ -54,7 +54,6 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -11,20 +11,20 @@ namespace {
template <typename T>
__global__ void
_Arange(const int nthreads, const float start, const float step, T* y) {
_Range(const int nthreads, const float start, const float delta, T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = start + (float)i * step;
y[i] = T(start + (float)i * delta);
}
}
template <>
__global__ void _Arange<half>(
__global__ void _Range<half>(
const int nthreads,
const float start,
const float step,
const float delta,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = __float2half(start + (float)i * step);
y[i] = __float2half(start + (float)i * delta);
}
}
......@@ -33,26 +33,26 @@ __global__ void _Arange<half>(
/* ------------------- Launcher Separator ------------------- */
template <>
void Arange<float16, CUDAContext>(
void Range<float16, CUDAContext>(
const int count,
const float start,
const float step,
const float delta,
float16* y,
CUDAContext* ctx) {
_Arange<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count, start, step, reinterpret_cast<half*>(y));
_Range<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
count, start, delta, reinterpret_cast<half*>(y));
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Arange<T, CUDAContext>( \
const int count, \
const float start, \
const float step, \
T* y, \
CUDAContext* ctx) { \
_Arange<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, start, step, y); \
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Range<T, CUDAContext>( \
const int count, \
const float start, \
const float delta, \
T* y, \
CUDAContext* ctx) { \
_Range<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
count, start, delta, y); \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
......@@ -61,7 +61,6 @@ DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
......
......@@ -104,8 +104,9 @@ class NumpyFeeder : public TensorFeederBase {
int ndim = PyArray_NDIM(array);
vec64_t dims(ndim);
auto* npy_dims = PyArray_DIMS(array);
for (int i = 0; i < ndim; i++)
for (int i = 0; i < ndim; i++) {
dims[i] = npy_dims[i];
}
tensor->Reshape(dims);
if (option.device_type() == PROTO_CUDA) {
#ifdef USE_CUDA
......
......@@ -16,9 +16,9 @@ PythonPluginInferOp<Context>::PythonPluginInferOp(
const OperatorDef& def,
Workspace* ws)
: Operator<Context>(def, ws),
module_name_(OpArg<string>("module_name", "")),
class_name_(OpArg<string>("class_name", "")),
kwargs_str_((OpArg<string>("kwargs_str", ""))) {
module_name_(OP_SINGLE_ARG(string, "module_name", "")),
class_name_(OP_SINGLE_ARG(string, "class_name", "")),
kwargs_str_(OP_SINGLE_ARG(string, "kwargs_str", "")) {
// Optimization for all python ops
this->do_sync_ = false;
......@@ -118,21 +118,21 @@ void PythonPluginGradientOp<Context>::RunOnDevice() {
}
}
DEPLOY_CPU(PythonPluginInfer);
DEPLOY_CPU_OPERATOR(PythonPluginInfer);
#ifdef USE_CUDA
DEPLOY_CUDA(PythonPluginInfer);
DEPLOY_CUDA_OPERATOR(PythonPluginInfer);
#endif
OPERATOR_SCHEMA(PythonPluginInfer);
DEPLOY_CPU(PythonPlugin);
DEPLOY_CPU_OPERATOR(PythonPlugin);
#ifdef USE_CUDA
DEPLOY_CUDA(PythonPlugin);
DEPLOY_CUDA_OPERATOR(PythonPlugin);
#endif
OPERATOR_SCHEMA(PythonPlugin);
DEPLOY_CPU(PythonPluginGradient);
DEPLOY_CPU_OPERATOR(PythonPluginGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(PythonPluginGradient);
DEPLOY_CUDA_OPERATOR(PythonPluginGradient);
#endif
OPERATOR_SCHEMA(PythonPluginGradient);
......
......@@ -108,9 +108,9 @@ void DropBlock2dGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(DropBlock2d);
DEPLOY_CPU_OPERATOR(DropBlock2d);
#ifdef USE_CUDA
DEPLOY_CUDA(DropBlock2d);
DEPLOY_CUDA_OPERATOR(DropBlock2d);
#endif
OPERATOR_SCHEMA(DropBlock2d)
......@@ -121,9 +121,9 @@ OPERATOR_SCHEMA(DropBlock2d)
/* X => Y */
.AllowInplace({{0, 0}});
DEPLOY_CPU(DropBlock2dGradient);
DEPLOY_CPU_OPERATOR(DropBlock2dGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(DropBlock2dGradient);
DEPLOY_CUDA_OPERATOR(DropBlock2dGradient);
#endif
OPERATOR_SCHEMA(DropBlock2dGradient)
......
......@@ -22,10 +22,10 @@ class DropBlock2dOp final : public Operator<Context> {
public:
DropBlock2dOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
block_size_(OpArg<int64_t>("block_size", 7)),
alpha_(OpArg<float>("alpha", 1.f)),
decrement_(OpArg<float>("decrement", 0.f)) {
GET_ARG_WITH_DESC(float, keep_prob, 0.9f);
block_size_(OP_SINGLE_ARG(int64_t, "block_size", 7)),
alpha_(OP_SINGLE_ARG(float, "alpha", 1.f)),
decrement_(OP_SINGLE_ARG(float, "decrement", 0.f)) {
INIT_OP_SINGLE_ARG_WITH_DESC(float, keep_prob, 0.9f);
}
USE_OPERATOR_FUNCTIONS;
......@@ -37,7 +37,7 @@ class DropBlock2dOp final : public Operator<Context> {
protected:
int64_t block_size_;
float alpha_, decrement_, prob_ = 1.;
DECLARE_ARG_WITH_DESC(float, keep_prob);
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, keep_prob);
};
template <class Context>
......@@ -52,7 +52,7 @@ class DropBlock2dGradientOp final : public Operator<Context> {
void DoRunWithType();
};
DEFINE_ARG_WITH_DESC(float, DropBlock2dOp, keep_prob);
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropBlock2dOp, keep_prob);
} // namespace dragon
......
......@@ -72,14 +72,14 @@ void DropPathGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(DropPath);
DEPLOY_CPU_OPERATOR(DropPath);
#ifdef USE_CUDA
DEPLOY_CUDA(DropPath);
DEPLOY_CUDA_OPERATOR(DropPath);
#endif
DEPLOY_CPU(DropPathGradient);
DEPLOY_CPU_OPERATOR(DropPathGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(DropPathGradient);
DEPLOY_CUDA_OPERATOR(DropPathGradient);
#endif
OPERATOR_SCHEMA(DropPath)
......
......@@ -21,8 +21,9 @@ template <class Context>
class DropPathOp final : public Operator<Context> {
public:
DropPathOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), inc_(OpArg<float>("increment", 0.f)) {
GET_ARG_WITH_DESC(float, prob, 0.2f);
: Operator<Context>(def, ws),
inc_(OP_SINGLE_ARG(float, "increment", 0.f)) {
INIT_OP_SINGLE_ARG_WITH_DESC(float, prob, 0.2f);
}
USE_OPERATOR_FUNCTIONS;
......@@ -33,7 +34,7 @@ class DropPathOp final : public Operator<Context> {
protected:
float inc_, drop_prob_ = 0.f;
DECLARE_ARG_WITH_DESC(float, prob);
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, prob);
};
template <class Context>
......@@ -48,7 +49,7 @@ class DropPathGradientOp final : public Operator<Context> {
void DoRunWithType();
};
DEFINE_ARG_WITH_DESC(float, DropPathOp, prob);
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropPathOp, prob);
} // namespace dragon
......
......@@ -56,14 +56,14 @@ void DropoutGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Dropout);
DEPLOY_CPU_OPERATOR(Dropout);
#ifdef USE_CUDA
DEPLOY_CUDA(Dropout);
DEPLOY_CUDA_OPERATOR(Dropout);
#endif
DEPLOY_CPU(DropoutGradient);
DEPLOY_CPU_OPERATOR(DropoutGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(DropoutGradient);
DEPLOY_CUDA_OPERATOR(DropoutGradient);
#endif
OPERATOR_SCHEMA(Dropout)
......
......@@ -22,7 +22,7 @@ class DropoutOp : public Operator<Context> {
public:
DropoutOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARG_WITH_DESC(float, prob, 0.5f);
INIT_OP_SINGLE_ARG_WITH_DESC(float, prob, 0.5f);
}
USE_OPERATOR_FUNCTIONS;
......@@ -32,7 +32,7 @@ class DropoutOp : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARG_WITH_DESC(float, prob);
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, prob);
};
template <class Context>
......@@ -40,7 +40,7 @@ class DropoutGradientOp : public Operator<Context> {
public:
DropoutGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARG_WITH_DESC(float, prob, 0.5f);
INIT_OP_SINGLE_ARG_WITH_DESC(float, prob, 0.5f);
}
USE_OPERATOR_FUNCTIONS;
......@@ -50,11 +50,11 @@ class DropoutGradientOp : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARG_WITH_DESC(float, prob);
DECLARE_OP_SINGLE_ARG_WITH_DESC(float, prob);
};
DEFINE_ARG_WITH_DESC(float, DropoutOp, prob);
DEFINE_ARG_WITH_DESC(float, DropoutGradientOp, prob);
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropoutOp, prob);
DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropoutGradientOp, prob);
#ifdef USE_CUDNN
......
......@@ -119,8 +119,8 @@ void CuDNNDropoutGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Dropout);
DEPLOY_CUDNN(DropoutGradient);
DEPLOY_CUDNN_OPERATOR(Dropout);
DEPLOY_CUDNN_OPERATOR(DropoutGradient);
} // namespace dragon
......
......@@ -38,14 +38,14 @@ void EluGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Elu);
DEPLOY_CPU_OPERATOR(Elu);
#ifdef USE_CUDA
DEPLOY_CUDA(Elu);
DEPLOY_CUDA_OPERATOR(Elu);
#endif
DEPLOY_CPU(EluGradient);
DEPLOY_CPU_OPERATOR(EluGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(EluGradient);
DEPLOY_CUDA_OPERATOR(EluGradient);
#endif
OPERATOR_SCHEMA(Elu)
......
......@@ -21,7 +21,8 @@ template <class Context>
class EluOp : public Operator<Context> {
public:
EluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), alpha_(OpArg<float>("alpha", 1.f)) {}
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
template <typename T>
......@@ -37,7 +38,8 @@ template <class Context>
class EluGradientOp : public Operator<Context> {
public:
EluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), alpha_(OpArg<float>("alpha", 1.f)) {}
: Operator<Context>(def, ws),
alpha_(OP_SINGLE_ARG(float, "alpha", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -52,8 +52,8 @@ void CuDNNEluGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Elu);
DEPLOY_CUDNN(EluGradient);
DEPLOY_CUDNN_OPERATOR(Elu);
DEPLOY_CUDNN_OPERATOR(EluGradient);
} // namespace dragon
......
......@@ -113,14 +113,14 @@ void PReluGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1));
}
DEPLOY_CPU(PRelu);
DEPLOY_CPU_OPERATOR(PRelu);
#ifdef USE_CUDA
DEPLOY_CUDA(PRelu);
DEPLOY_CUDA_OPERATOR(PRelu);
#endif
DEPLOY_CPU(PReluGradient);
DEPLOY_CPU_OPERATOR(PReluGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(PReluGradient);
DEPLOY_CUDA_OPERATOR(PReluGradient);
#endif
OPERATOR_SCHEMA(PRelu)
......
......@@ -22,7 +22,7 @@ class PReluOp final : public Operator<Context> {
public:
PReluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
channel_shared_(OpArg<int64_t>("channel_shared", 0)) {}
channel_shared_(OP_SINGLE_ARG(int64_t, "channel_shared", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -57,14 +57,14 @@ void ReluGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Relu);
DEPLOY_CPU_OPERATOR(Relu);
#ifdef USE_CUDA
DEPLOY_CUDA(Relu);
DEPLOY_CUDA_OPERATOR(Relu);
#endif
DEPLOY_CPU(ReluGradient);
DEPLOY_CPU_OPERATOR(ReluGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(ReluGradient);
DEPLOY_CUDA_OPERATOR(ReluGradient);
#endif
OPERATOR_SCHEMA(Relu)
......
......@@ -22,8 +22,8 @@ class ReluOp : public Operator<Context> {
public:
ReluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OpArg<float>("alpha", 0.f)),
max_value_(OpArg<float>("max_value", 0.f)) {}
alpha_(OP_SINGLE_ARG(float, "alpha", 0.f)),
max_value_(OP_SINGLE_ARG(float, "max_value", 0.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -40,8 +40,8 @@ class ReluGradientOp : public Operator<Context> {
public:
ReluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OpArg<float>("alpha", 0.f)),
max_value_(OpArg<float>("max_value", 0.f)) {}
alpha_(OP_SINGLE_ARG(float, "alpha", 0.f)),
max_value_(OP_SINGLE_ARG(float, "max_value", 0.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -58,8 +58,8 @@ void CuDNNReluGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Relu);
DEPLOY_CUDNN(ReluGradient);
DEPLOY_CUDNN_OPERATOR(Relu);
DEPLOY_CUDNN_OPERATOR(ReluGradient);
} // namespace dragon
......
......@@ -40,14 +40,14 @@ void SeluGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Selu);
DEPLOY_CPU_OPERATOR(Selu);
#ifdef USE_CUDA
DEPLOY_CUDA(Selu);
DEPLOY_CUDA_OPERATOR(Selu);
#endif
DEPLOY_CPU(SeluGradient);
DEPLOY_CPU_OPERATOR(SeluGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(SeluGradient);
DEPLOY_CUDA_OPERATOR(SeluGradient);
#endif
OPERATOR_SCHEMA(Selu)
......
......@@ -22,8 +22,8 @@ class SeluOp final : public Operator<Context> {
public:
SeluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OpArg<float>("alpha", 1.67326f)),
gamma_(OpArg<float>("gamma", 1.0507f)) {}
alpha_(OP_SINGLE_ARG(float, "alpha", 1.67326f)),
gamma_(OP_SINGLE_ARG(float, "gamma", 1.0507f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -40,8 +40,8 @@ class SeluGradientOp final : public Operator<Context> {
public:
SeluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha_(OpArg<float>("alpha", 1.67326f)),
gamma_(OpArg<float>("gamma", 1.0507f)) {}
alpha_(OP_SINGLE_ARG(float, "alpha", 1.67326f)),
gamma_(OP_SINGLE_ARG(float, "gamma", 1.0507f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -36,14 +36,14 @@ void SigmoidGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Sigmoid);
DEPLOY_CPU_OPERATOR(Sigmoid);
#ifdef USE_CUDA
DEPLOY_CUDA(Sigmoid);
DEPLOY_CUDA_OPERATOR(Sigmoid);
#endif
DEPLOY_CPU(SigmoidGradient);
DEPLOY_CPU_OPERATOR(SigmoidGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(SigmoidGradient);
DEPLOY_CUDA_OPERATOR(SigmoidGradient);
#endif
OPERATOR_SCHEMA(Sigmoid)
......
......@@ -50,8 +50,8 @@ void CuDNNSigmoidGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Sigmoid);
DEPLOY_CUDNN(SigmoidGradient);
DEPLOY_CUDNN_OPERATOR(Sigmoid);
DEPLOY_CUDNN_OPERATOR(SigmoidGradient);
} // namespace dragon
......
......@@ -44,14 +44,14 @@ void SoftmaxGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Softmax);
DEPLOY_CPU_OPERATOR(Softmax);
#ifdef USE_CUDA
DEPLOY_CUDA(Softmax);
DEPLOY_CUDA_OPERATOR(Softmax);
#endif
DEPLOY_CPU(SoftmaxGradient);
DEPLOY_CPU_OPERATOR(SoftmaxGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(SoftmaxGradient);
DEPLOY_CUDA_OPERATOR(SoftmaxGradient);
#endif
OPERATOR_SCHEMA(Softmax)
......
......@@ -54,8 +54,8 @@ void CuDNNSoftmaxGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Softmax);
DEPLOY_CUDNN(SoftmaxGradient);
DEPLOY_CUDNN_OPERATOR(Softmax);
DEPLOY_CUDNN_OPERATOR(SoftmaxGradient);
} // namespace dragon
......
......@@ -36,14 +36,14 @@ void TanhGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Tanh);
DEPLOY_CPU_OPERATOR(Tanh);
#ifdef USE_CUDA
DEPLOY_CUDA(Tanh);
DEPLOY_CUDA_OPERATOR(Tanh);
#endif
DEPLOY_CPU(TanhGradient);
DEPLOY_CPU_OPERATOR(TanhGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(TanhGradient);
DEPLOY_CUDA_OPERATOR(TanhGradient);
#endif
OPERATOR_SCHEMA(Tanh)
......
......@@ -50,8 +50,8 @@ void CuDNNTanhGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CUDNN(Tanh);
DEPLOY_CUDNN(TanhGradient);
DEPLOY_CUDNN_OPERATOR(Tanh);
DEPLOY_CUDNN_OPERATOR(TanhGradient);
} // namespace dragon
......
......@@ -22,7 +22,7 @@ class ArgMaxOp final : public Operator<Context> {
public:
ArgMaxOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
keep_dims_(OpArg<int64_t>("keep_dims", 0)) {}
keep_dims_(OP_SINGLE_ARG(int64_t, "keep_dims", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -39,7 +39,7 @@ class ArgMinOp final : public Operator<Context> {
public:
ArgMinOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
keep_dims_(OpArg<int64_t>("keep_dims", 0)) {}
keep_dims_(OP_SINGLE_ARG(int64_t, "keep_dims", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -49,12 +49,12 @@ void ArgMaxOp<Context>::DoRunWithType() {
template <class Context>
void ArgMaxOp<Context>::RunOnDevice() {
DispatchHelper<MathTensorTypes>::Call(this, Input(0));
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(ArgMax);
DEPLOY_CPU_OPERATOR(ArgMax);
#ifdef USE_CUDA
DEPLOY_CUDA(ArgMax);
DEPLOY_CUDA_OPERATOR(ArgMax);
#endif
OPERATOR_SCHEMA(ArgMax)
......
......@@ -49,12 +49,12 @@ void ArgMinOp<Context>::DoRunWithType() {
template <class Context>
void ArgMinOp<Context>::RunOnDevice() {
DispatchHelper<MathTensorTypes>::Call(this, Input(0));
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(ArgMin);
DEPLOY_CPU_OPERATOR(ArgMin);
#ifdef USE_CUDA
DEPLOY_CUDA(ArgMin);
DEPLOY_CUDA_OPERATOR(ArgMin);
#endif
OPERATOR_SCHEMA(ArgMin)
......
......@@ -39,21 +39,21 @@ namespace dragon {
LOG(FATAL) << MessageForUnsupported(dtype(), ELIGIBLE_TENSOR_TYPES);
#define DISPATCH_WITH_TENSOR(X) \
if (XIsType(X, bool)) { \
if (X.template IsType<bool>()) { \
DISPATCH_TYPE_TO_ALL(bool); \
} else if (XIsType(X, int8_t)) { \
} else if (X.template IsType<int8_t>()) { \
DISPATCH_TYPE_TO_ALL(int8_t); \
} else if (XIsType(X, uint8_t)) { \
} else if (X.template IsType<uint8_t>()) { \
DISPATCH_TYPE_TO_ALL(uint8_t); \
} else if (XIsType(X, int)) { \
} else if (X.template IsType<int>()) { \
DISPATCH_TYPE_TO_ALL(int); \
} else if (XIsType(X, int64_t)) { \
} else if (X.template IsType<int64_t>()) { \
DISPATCH_TYPE_TO_ALL(int64_t); \
} else if (XIsType(X, float16)) { \
} else if (X.template IsType<float16>()) { \
DISPATCH_TYPE_TO_ALL(float16); \
} else if (XIsType(X, float)) { \
} else if (X.template IsType<float>()) { \
DISPATCH_TYPE_TO_ALL(float); \
} else if (XIsType(X, double)) { \
} else if (X.template IsType<double>()) { \
DISPATCH_TYPE_TO_ALL(double); \
} else { \
LOG(FATAL) << MessageForUnsupported( \
......@@ -78,14 +78,14 @@ void CastGradientOp<Context>::RunOnDevice() {
DISPATCH_WITH_TENSOR(Input(-1));
}
DEPLOY_CPU(Cast);
DEPLOY_CPU_OPERATOR(Cast);
#ifdef USE_CUDA
DEPLOY_CUDA(Cast);
DEPLOY_CUDA_OPERATOR(Cast);
#endif
DEPLOY_CPU(CastGradient);
DEPLOY_CPU_OPERATOR(CastGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(CastGradient);
DEPLOY_CUDA_OPERATOR(CastGradient);
#endif
OPERATOR_SCHEMA(Cast)
......
......@@ -59,12 +59,12 @@ void ChannelNormalizeOp<Context>::DoRunWithType() {
template <class Context>
void ChannelNormalizeOp<Context>::RunOnDevice() {
DispatchHelper<MathTensorTypes>::Call(this, Input(0));
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(ChannelNormalize);
DEPLOY_CPU_OPERATOR(ChannelNormalize);
#ifdef USE_CUDA
DEPLOY_CUDA(ChannelNormalize);
DEPLOY_CUDA_OPERATOR(ChannelNormalize);
#endif
OPERATOR_SCHEMA(ChannelNormalize)
......
......@@ -22,9 +22,9 @@ class ChannelNormalizeOp final : public Operator<Context> {
public:
ChannelNormalizeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, perm);
auto mean = OpArgs<float>("mean");
auto std = OpArgs<float>("std");
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, perm);
auto mean = OP_REPEATED_ARG(float, "mean");
auto std = OP_REPEATED_ARG(float, "std");
CHECK_EQ(mean.size(), std.size())
<< "\nSize of <mean> and <std> should be same.";
X_mean_.Reshape({(int64_t)mean.size()});
......@@ -47,10 +47,10 @@ class ChannelNormalizeOp final : public Operator<Context> {
protected:
Tensor X_mean_, X_std_;
DECLARE_ARGS_WITH_DESC(int64_t, perm);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, perm);
};
DEFINE_ARGS_WITH_DESC(int64_t, ChannelNormalizeOp, perm);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, ChannelNormalizeOp, perm);
} // namespace dragon
......
......@@ -25,7 +25,7 @@ void ChannelShuffleOp<Context>::DoRunWithType() {
template <class Context>
void ChannelShuffleOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -49,14 +49,14 @@ void ChannelShuffleGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(ChannelShuffle);
DEPLOY_CPU_OPERATOR(ChannelShuffle);
#ifdef USE_CUDA
DEPLOY_CUDA(ChannelShuffle);
DEPLOY_CUDA_OPERATOR(ChannelShuffle);
#endif
DEPLOY_CPU(ChannelShuffleGradient);
DEPLOY_CPU_OPERATOR(ChannelShuffleGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(ChannelShuffleGradient);
DEPLOY_CUDA_OPERATOR(ChannelShuffleGradient);
#endif
OPERATOR_SCHEMA(ChannelShuffle)
......
......@@ -21,7 +21,8 @@ template <class Context>
class ChannelShuffleOp final : public Operator<Context> {
public:
ChannelShuffleOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), group_(OpArg<int64_t>("group", 1)) {}
: Operator<Context>(def, ws),
group_(OP_SINGLE_ARG(int64_t, "group", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -37,7 +38,8 @@ template <class Context>
class ChannelShuffleGradientOp final : public Operator<Context> {
public:
ChannelShuffleGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), group_(OpArg<int64_t>("group", 1)) {}
: Operator<Context>(def, ws),
group_(OP_SINGLE_ARG(int64_t, "group", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -45,7 +45,7 @@ void ConcatOp<Context>::DoRunWithType() {
template <class Context>
void ConcatOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -77,14 +77,14 @@ void ConcatGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Concat);
DEPLOY_CPU_OPERATOR(Concat);
#ifdef USE_CUDA
DEPLOY_CUDA(Concat);
DEPLOY_CUDA_OPERATOR(Concat);
#endif
DEPLOY_CPU(ConcatGradient);
DEPLOY_CPU_OPERATOR(ConcatGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(ConcatGradient);
DEPLOY_CUDA_OPERATOR(ConcatGradient);
#endif
OPERATOR_SCHEMA(Concat)
......
......@@ -17,23 +17,23 @@
namespace dragon {
#define DECLARE_CUM_OP(name) \
template <class Context> \
class name##Op final : public Operator<Context> { \
public: \
name##Op(const OperatorDef& def, Workspace* ws) \
: Operator<Context>(def, ws), \
exclusive_(OpArg<int64_t>("exclusive", 0)), \
reverse_(OpArg<int64_t>("reverse", 0)) {} \
USE_OPERATOR_FUNCTIONS; \
\
void RunOnDevice() override; \
\
template <typename T> \
void DoRunWithType(); \
\
protected: \
int64_t exclusive_, reverse_; \
#define DECLARE_CUM_OP(name) \
template <class Context> \
class name##Op final : public Operator<Context> { \
public: \
name##Op(const OperatorDef& def, Workspace* ws) \
: Operator<Context>(def, ws), \
exclusive_(OP_SINGLE_ARG(int64_t, "exclusive", 0)), \
reverse_(OP_SINGLE_ARG(int64_t, "reverse", 0)) {} \
USE_OPERATOR_FUNCTIONS; \
\
void RunOnDevice() override; \
\
template <typename T> \
void DoRunWithType(); \
\
protected: \
int64_t exclusive_, reverse_; \
};
DECLARE_CUM_OP(CumSum);
......
......@@ -23,7 +23,7 @@ void CumSumOp<Context>::DoRunWithType() {
template <class Context>
void CumSumOp<Context>::RunOnDevice() {
DispatchHelper<MathTensorTypes>::Call(this, Input(0));
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -48,14 +48,14 @@ void CumSumGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(CumSum);
DEPLOY_CPU_OPERATOR(CumSum);
#ifdef USE_CUDA
DEPLOY_CUDA(CumSum);
DEPLOY_CUDA_OPERATOR(CumSum);
#endif
DEPLOY_CPU(CumSumGradient);
DEPLOY_CPU_OPERATOR(CumSumGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(CumSumGradient);
DEPLOY_CUDA_OPERATOR(CumSumGradient);
#endif
OPERATOR_SCHEMA(CumSum)
......
......@@ -29,14 +29,14 @@ void ExpandDimsOp<Context>::RunOnDevice() {
Y->Reshape(out_shape)->CopyFrom(X, ctx());
}
DEPLOY_CPU(ExpandDims);
DEPLOY_CPU_OPERATOR(ExpandDims);
#ifdef USE_CUDA
DEPLOY_CUDA(ExpandDims);
DEPLOY_CUDA_OPERATOR(ExpandDims);
#endif
DEPLOY_CPU(ExpandDimsGradient);
DEPLOY_CPU_OPERATOR(ExpandDimsGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(ExpandDimsGradient);
DEPLOY_CUDA_OPERATOR(ExpandDimsGradient);
#endif
OPERATOR_SCHEMA(ExpandDims)
......
......@@ -36,7 +36,7 @@ void ExpandOp<Context>::DoRunWithType() {
template <class Context>
void ExpandOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -71,14 +71,14 @@ void ExpandGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Expand);
DEPLOY_CPU_OPERATOR(Expand);
#ifdef USE_CUDA
DEPLOY_CUDA(Expand);
DEPLOY_CUDA_OPERATOR(Expand);
#endif
DEPLOY_CPU(ExpandGradient);
DEPLOY_CPU_OPERATOR(ExpandGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(ExpandGradient);
DEPLOY_CUDA_OPERATOR(ExpandGradient);
#endif
OPERATOR_SCHEMA(Expand)
......
......@@ -21,7 +21,7 @@ template <class Context>
class ExpandOp final : public Operator<Context> {
public:
ExpandOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, dims);
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, dims);
}
USE_OPERATOR_FUNCTIONS;
......@@ -31,7 +31,7 @@ class ExpandOp final : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARGS_WITH_DESC(int64_t, dims);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, dims);
};
template <class Context>
......@@ -46,7 +46,7 @@ class ExpandGradientOp final : public Operator<Context> {
void DoRunWithType();
};
DEFINE_ARGS_WITH_DESC(int64_t, ExpandOp, dims);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, ExpandOp, dims);
} // namespace dragon
......
......@@ -40,14 +40,14 @@ void FlattenOp<Context>::RunOnDevice() {
Y->Reshape(out_shape)->CopyFrom(X, ctx());
}
DEPLOY_CPU(Flatten);
DEPLOY_CPU_OPERATOR(Flatten);
#ifdef USE_CUDA
DEPLOY_CUDA(Flatten);
DEPLOY_CUDA_OPERATOR(Flatten);
#endif
DEPLOY_CPU(FlattenGradient);
DEPLOY_CPU_OPERATOR(FlattenGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(FlattenGradient);
DEPLOY_CUDA_OPERATOR(FlattenGradient);
#endif
OPERATOR_SCHEMA(Flatten)
......
......@@ -7,7 +7,7 @@ namespace dragon {
#define CANONICALIZE_AXES_WITH_TENSOR(tensor) \
CANONICALIZE_AXIS_WITH_TENSOR(tensor); \
auto num_axes = OpArg<int64_t>("num_axes", 1); \
auto num_axes = OP_SINGLE_ARG(int64_t, "num_axes", 1); \
if (num_axes < 0) { \
num_axes = tensor.ndim() - axis; \
} else if (num_axes == 0) { \
......@@ -24,7 +24,8 @@ void IndexSelectOp<Context>::DoRunWithType() {
CANONICALIZE_AXES_WITH_TENSOR(X);
CHECK_GT(X_index.count(), 0) << "\nLength of indices must > 0.";
CHECK(XIsType(X_index, int64_t)) << "\nType of index should be int64.";
CHECK(X_index.template IsType<int64_t>())
<< "\nType of index should be int64.";
vec64_t X_dims(X.dims());
vec64_t Y_dims(X_dims.begin(), X_dims.begin() + axis);
......@@ -48,7 +49,7 @@ void IndexSelectOp<Context>::DoRunWithType() {
template <class Context>
void IndexSelectOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -81,14 +82,14 @@ void IndexSelectGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(IndexSelect);
DEPLOY_CPU_OPERATOR(IndexSelect);
#ifdef USE_CUDA
DEPLOY_CUDA(IndexSelect);
DEPLOY_CUDA_OPERATOR(IndexSelect);
#endif
DEPLOY_CPU(IndexSelectGradient);
DEPLOY_CPU_OPERATOR(IndexSelectGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(IndexSelectGradient);
DEPLOY_CUDA_OPERATOR(IndexSelectGradient);
#endif
OPERATOR_SCHEMA(IndexSelect)
......
......@@ -90,50 +90,50 @@ DISPATCH_WITH_TENSOR_TYPES(RandomUniform, FloatingTensorTypes);
DISPATCH_WITH_TENSOR_TYPES(TruncatedNormal, FloatingTensorTypes);
DISPATCH_WITH_TENSOR_TYPES(GlorotNormal, FloatingTensorTypes);
DISPATCH_WITH_TENSOR_TYPES(GlorotUniform, FloatingTensorTypes);
DISPATCH_WITH_TENSOR_TYPES(Fill, AllTensorTypes);
DISPATCH_WITH_TENSOR_TYPES(Eye, AllTensorTypes);
DISPATCH_WITH_TENSOR_TYPES(Fill, FullTensorTypes);
DISPATCH_WITH_TENSOR_TYPES(Eye, FullTensorTypes);
#undef DISPATCH_WITH_TYPES
#undef DISPATCH_WITH_TENSOR_TYPES
DEPLOY_CPU(Fill);
DEPLOY_CPU_OPERATOR(Fill);
#ifdef USE_CUDA
DEPLOY_CUDA(Fill);
DEPLOY_CUDA_OPERATOR(Fill);
#endif
DEPLOY_CPU(Eye);
DEPLOY_CPU_OPERATOR(Eye);
#ifdef USE_CUDA
DEPLOY_CUDA(Eye);
DEPLOY_CUDA_OPERATOR(Eye);
#endif
DEPLOY_CPU(GivenTensorFill);
DEPLOY_CPU_OPERATOR(GivenTensorFill);
#ifdef USE_CUDA
DEPLOY_CUDA(GivenTensorFill);
DEPLOY_CUDA_OPERATOR(GivenTensorFill);
#endif
DEPLOY_CPU(RandomNormal);
DEPLOY_CPU_OPERATOR(RandomNormal);
#ifdef USE_CUDA
DEPLOY_CUDA(RandomNormal);
DEPLOY_CUDA_OPERATOR(RandomNormal);
#endif
DEPLOY_CPU(RandomUniform);
DEPLOY_CPU_OPERATOR(RandomUniform);
#ifdef USE_CUDA
DEPLOY_CUDA(RandomUniform);
DEPLOY_CUDA_OPERATOR(RandomUniform);
#endif
#ifdef USE_CUDA
DEPLOY_CPU_CUDA(TruncatedNormal);
DEPLOY_CPU_CUDA_OPERATOR(TruncatedNormal);
#else
DEPLOY_CPU(TruncatedNormal);
DEPLOY_CPU_OPERATOR(TruncatedNormal);
#endif
DEPLOY_CPU(GlorotNormal);
DEPLOY_CPU_OPERATOR(GlorotNormal);
#ifdef USE_CUDA
DEPLOY_CUDA(GlorotNormal);
DEPLOY_CUDA_OPERATOR(GlorotNormal);
#endif
DEPLOY_CPU(GlorotUniform);
DEPLOY_CPU_OPERATOR(GlorotUniform);
#ifdef USE_CUDA
DEPLOY_CUDA(GlorotUniform);
DEPLOY_CUDA_OPERATOR(GlorotUniform);
#endif
OPERATOR_SCHEMA(Fill).NumInputs(0, 1).NumOutputs(1);
......
......@@ -23,7 +23,7 @@ class InitializeOp : public Operator<Context> {
public:
InitializeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, dims);
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, dims);
}
USE_OPERATOR_FUNCTIONS;
......@@ -31,14 +31,15 @@ class InitializeOp : public Operator<Context> {
protected:
FillerInfo filler_info_;
DECLARE_ARGS_WITH_DESC(int64_t, dims);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, dims);
};
template <class Context>
class FillOp final : public InitializeOp<Context> {
public:
FillOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws), value_(OpArg<float>("value", 0.f)) {}
: InitializeOp<Context>(def, ws),
value_(OP_SINGLE_ARG(float, "value", 0.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -51,10 +52,10 @@ class FillOp final : public InitializeOp<Context> {
};
template <class Context>
class ArangeOp final : public Operator<Context> {
class RangeOp final : public Operator<Context> {
public:
ArangeOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(float, slice);
RangeOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
INIT_OP_REPEATED_ARG_WITH_DESC(float, slice);
}
USE_OPERATOR_FUNCTIONS;
......@@ -64,14 +65,32 @@ class ArangeOp final : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARGS_WITH_DESC(float, slice);
DECLARE_OP_REPEATED_ARG_WITH_DESC(float, slice);
};
template <class Context>
class PermutationOp final : public Operator<Context> {
public:
PermutationOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
INIT_OP_SINGLE_ARG_WITH_DESC(int64_t, limit, 0);
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
DECLARE_OP_SINGLE_ARG_WITH_DESC(int64_t, limit);
};
template <class Context>
class EyeOp final : public InitializeOp<Context> {
public:
EyeOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws), k_(OpArg<int64_t>("k", 0)) {}
: InitializeOp<Context>(def, ws), k_(OP_SINGLE_ARG(int64_t, "k", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -96,7 +115,7 @@ template <class Context>
class GivenTensorFillOp final : public Operator<Context> {
public:
GivenTensorFillOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), shape_(OpArgs<int64_t>("shape")) {}
: Operator<Context>(def, ws), shape_(OP_REPEATED_ARG(int64_t, "shape")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -108,7 +127,7 @@ class GivenTensorFillOp final : public Operator<Context> {
template <typename T>
void ExtractImpl(TypeIdentity<T>) {
auto raw_values = OpArgs<T>("values");
auto raw_values = OP_REPEATED_ARG(T, "values");
auto nelements = (int64_t)raw_values.size();
values_.Reshape({nelements});
memcpy(
......@@ -118,7 +137,7 @@ class GivenTensorFillOp final : public Operator<Context> {
}
void ExtractImpl(TypeIdentity<float16>) {
auto raw_values = OpArgs<float>("values");
auto raw_values = OP_REPEATED_ARG(float, "values");
auto nelements = (int64_t)raw_values.size();
values_.Reshape({nelements});
memcpy(
......@@ -140,8 +159,8 @@ class RandomNormalOp final : public InitializeOp<Context> {
public:
RandomNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
auto mu = OpArg<float>("mean", 0.f);
auto sigma = OpArg<float>("std", 1.f);
auto mu = OP_SINGLE_ARG(float, "mean", 0.f);
auto sigma = OP_SINGLE_ARG(float, "std", 1.f);
this->filler_info_.set_mean(mu);
this->filler_info_.set_std(sigma);
this->filler_info_.set_type("normal");
......@@ -159,8 +178,8 @@ class RandomUniformOp final : public InitializeOp<Context> {
public:
RandomUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
auto low = OpArg<float>("low", -1.f);
auto high = OpArg<float>("high", 1.f);
auto low = OP_SINGLE_ARG(float, "low", -1.f);
auto high = OP_SINGLE_ARG(float, "high", 1.f);
this->filler_info_.set_low(low);
this->filler_info_.set_high(high);
this->filler_info_.set_type("uniform");
......@@ -178,8 +197,8 @@ class TruncatedNormalOp final : public InitializeOp<Context> {
public:
TruncatedNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
auto mu = OpArg<float>("mean", 0.f);
auto sigma = OpArg<float>("std", 1.f);
auto mu = OP_SINGLE_ARG(float, "mean", 0.f);
auto sigma = OP_SINGLE_ARG(float, "std", 1.f);
this->filler_info_.set_mean(mu);
this->filler_info_.set_std(sigma);
this->filler_info_.set_low(mu - 2 * sigma);
......@@ -199,8 +218,8 @@ class GlorotNormalOp final : public InitializeOp<Context> {
public:
GlorotNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
auto scale = OpArg<float>("scale", 2.f);
auto mode = OpArg<string>("mode", "fan_in");
auto scale = OP_SINGLE_ARG(float, "scale", 2.f);
auto mode = OP_SINGLE_ARG(string, "mode", "fan_in");
this->filler_info_.set_type("glorot_normal");
if (mode == "fan_avg") {
this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_AVG);
......@@ -224,8 +243,8 @@ class GlorotUniformOp final : public InitializeOp<Context> {
public:
GlorotUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
auto scale = OpArg<float>("scale", 3.f);
auto mode = OpArg<string>("mode", "fan_in");
auto scale = OP_SINGLE_ARG(float, "scale", 3.f);
auto mode = OP_SINGLE_ARG(string, "mode", "fan_in");
this->filler_info_.set_type("glorot_uniform");
if (mode == "fan_avg") {
this->filler_info_.set_variance_norm(FillerInfo_VarianceNorm_FAN_AVG);
......@@ -244,8 +263,9 @@ class GlorotUniformOp final : public InitializeOp<Context> {
void DoRunWithType();
};
DEFINE_ARGS_WITH_DESC(int64_t, InitializeOp, dims);
DEFINE_ARGS_WITH_DESC(float, ArangeOp, slice);
DEFINE_OP_SINGLE_ARG_WITH_DESC(int64_t, PermutationOp, limit);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, InitializeOp, dims);
DEFINE_OP_REPEATED_ARG_WITH_DESC(float, RangeOp, slice);
} // namespace dragon
......
......@@ -11,7 +11,7 @@ void MaskedSelectOp<Context>::DoRunWithType() {
CHECK_EQ(X.count(), X_mask.count())
<< "\nSize of mask and input should be equal.";
CHECK(XIsType(X_mask, bool) || XIsType(X_mask, uint8_t))
CHECK(X_mask.template IsType<bool>() || X_mask.template IsType<uint8_t>())
<< "\nExcepted bool or uint8 mask.";
// Store for the gradient calculation
......@@ -52,7 +52,7 @@ void MaskedSelectOp<Context>::DoRunWithType() {
template <class Context>
void MaskedSelectOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -75,14 +75,14 @@ void MaskedSelectGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(MaskedSelect);
DEPLOY_CPU_OPERATOR(MaskedSelect);
#ifdef USE_CUDA
DEPLOY_CUDA(MaskedSelect);
DEPLOY_CUDA_OPERATOR(MaskedSelect);
#endif
DEPLOY_CPU(MaskedSelectGradient);
DEPLOY_CPU_OPERATOR(MaskedSelectGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(MaskedSelectGradient);
DEPLOY_CUDA_OPERATOR(MaskedSelectGradient);
#endif
OPERATOR_SCHEMA(MaskedSelect)
......
......@@ -64,9 +64,9 @@ void MultinomialOp<Context>::RunOnDevice() {
DispatchHelper<TensorTypes<float, double>>::Call(this, Input(0));
}
DEPLOY_CPU(Multinomial);
DEPLOY_CPU_OPERATOR(Multinomial);
#ifdef USE_CUDA
DEPLOY_CUDA(Multinomial);
DEPLOY_CUDA_OPERATOR(Multinomial);
#endif
OPERATOR_SCHEMA(Multinomial)
......
......@@ -22,9 +22,9 @@ class MultinomialOp final : public Operator<Context> {
public:
MultinomialOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
epsilon_(OpArg<double>("epsilon", 0.)),
normalize_(OpArg<int64_t>("normalize", 0)),
num_samples_(OpArg<int64_t>("num_samples", 1)) {}
epsilon_(OP_SINGLE_ARG(double, "epsilon", 0.)),
normalize_(OP_SINGLE_ARG(int64_t, "normalize", 0)),
num_samples_(OP_SINGLE_ARG(int64_t, "num_samples", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -55,12 +55,12 @@ void NonZeroOp<Context>::DoRunWithType() {
template <class Context>
void NonZeroOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(NonZero);
DEPLOY_CPU_OPERATOR(NonZero);
#ifdef USE_CUDA
DEPLOY_CUDA(NonZero);
DEPLOY_CUDA_OPERATOR(NonZero);
#endif
OPERATOR_SCHEMA(NonZero)
......
......@@ -34,9 +34,9 @@ void OneHotOp<Context>::RunOnDevice() {
DispatchHelper<TensorTypes<int, int64_t, float>>::Call(this, Input(0));
}
DEPLOY_CPU(OneHot);
DEPLOY_CPU_OPERATOR(OneHot);
#ifdef USE_CUDA
DEPLOY_CUDA(OneHot);
DEPLOY_CUDA_OPERATOR(OneHot);
#endif
OPERATOR_SCHEMA(OneHot)
......
......@@ -22,9 +22,9 @@ class OneHotOp final : public Operator<Context> {
public:
OneHotOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
depth_(OpArg<int64_t>("depth", -1)),
on_value_(OpArg<int64_t>("on_value", 1)),
off_value_(OpArg<int64_t>("off_value", 0)) {}
depth_(OP_SINGLE_ARG(int64_t, "depth", -1)),
on_value_(OP_SINGLE_ARG(int64_t, "on_value", 1)),
off_value_(OP_SINGLE_ARG(int64_t, "off_value", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -80,7 +80,7 @@ void PadOp<Context>::DoRunWithType() {
template <class Context>
void PadOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -124,14 +124,14 @@ void PadGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Pad);
DEPLOY_CPU_OPERATOR(Pad);
#ifdef USE_CUDA
DEPLOY_CUDA(Pad);
DEPLOY_CUDA_OPERATOR(Pad);
#endif
DEPLOY_CPU(PadGradient);
DEPLOY_CPU_OPERATOR(PadGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(PadGradient);
DEPLOY_CUDA_OPERATOR(PadGradient);
#endif
OPERATOR_SCHEMA(Pad)
......
......@@ -22,9 +22,9 @@ class PadOp final : public Operator<Context> {
public:
PadOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
value_(OpArg<float>("value", 0.f)),
mode_(OpArg<string>("mode", "CONSTANT")) {
GET_ARGS_WITH_DESC(int64_t, pads);
value_(OP_SINGLE_ARG(float, "value", 0.f)),
mode_(OP_SINGLE_ARG(string, "mode", "CONSTANT")) {
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, pads);
}
USE_OPERATOR_FUNCTIONS;
......@@ -36,7 +36,7 @@ class PadOp final : public Operator<Context> {
protected:
float value_;
string mode_;
DECLARE_ARGS_WITH_DESC(int64_t, pads);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, pads);
};
template <class Context>
......@@ -44,9 +44,9 @@ class PadGradientOp final : public Operator<Context> {
public:
PadGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
pad_l_(OpArgs<int64_t>("pad_l")),
pad_r_(OpArgs<int64_t>("pad_r")),
mode_(OpArg<string>("mode", "CONSTANT")) {
pad_l_(OP_REPEATED_ARG(int64_t, "pad_l")),
pad_r_(OP_REPEATED_ARG(int64_t, "pad_r")),
mode_(OP_SINGLE_ARG(string, "mode", "CONSTANT")) {
if (pad_r_.empty()) {
pad_r_ = pad_l_;
} else {
......@@ -66,7 +66,7 @@ class PadGradientOp final : public Operator<Context> {
vec64_t pad_l_, pad_r_;
};
DEFINE_ARGS_WITH_DESC(int64_t, PadOp, pads);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, PadOp, pads);
} // namespace dragon
......
#include "dragon/core/workspace.h"
#include "dragon/operators/array/initialize_ops.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void PermutationOp<Context>::DoRunWithType() {
auto* Y = Output(0)->Reshape({limit()});
kernel::Permutation(
Y->count(),
Y->template mutable_data<T, Context>(),
ws()->template data<uint32_t, Context>({Y->count()})[0],
ctx());
}
template <class Context>
void PermutationOp<Context>::RunOnDevice() {
DispatchHelper<NumericalTensorTypes>::Call(this);
}
DEPLOY_CPU_OPERATOR(Permutation);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Permutation);
#endif
OPERATOR_SCHEMA(Permutation).NumInputs(0).NumOutputs(1);
NO_GRADIENT(Permutation);
} // namespace dragon
......@@ -6,46 +6,46 @@ namespace dragon {
template <class Context>
template <typename T>
void ArangeOp<Context>::DoRunWithType() {
void RangeOp<Context>::DoRunWithType() {
// Determine the slice arguments
int num_args;
float start = 0.f, stop, step;
float start = 0.f, limit, delta;
slice(0, &num_args);
if (num_args == 2) {
stop = slice(0), step = slice(1);
limit = slice(0), delta = slice(1);
} else if (num_args == 3) {
start = slice(0), stop = slice(1), step = slice(2);
start = slice(0), limit = slice(1), delta = slice(2);
} else {
LOG(FATAL) << "Unexcepted number of slice arguments: " << num_args;
}
// Determine the generating range
// Values are in a half-open interval: [start, stop)
auto count = (int64_t)std::ceil((stop - start) / step);
auto count = (int64_t)std::ceil((limit - start) / delta);
CHECK_GT(count, 0) << "\nInvalid generating range: "
<< "[" << start << ", " << stop << ") with step = " << step
<< ".";
<< "[" << start << ", " << limit
<< ") with delta = " << delta << ".";
kernel::Arange(
kernel::Range(
count,
start,
step,
delta,
Output(0)->Reshape({count})->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void ArangeOp<Context>::RunOnDevice() {
DispatchHelper<MathTensorTypes>::Call(this);
void RangeOp<Context>::RunOnDevice() {
DispatchHelper<NumericalTensorTypes>::Call(this);
}
DEPLOY_CPU(Arange);
DEPLOY_CPU_OPERATOR(Range);
#ifdef USE_CUDA
DEPLOY_CUDA(Arange);
DEPLOY_CUDA_OPERATOR(Range);
#endif
OPERATOR_SCHEMA(Arange).NumInputs(0).NumOutputs(1);
OPERATOR_SCHEMA(Range).NumInputs(0).NumOutputs(1);
NO_GRADIENT(Arange);
NO_GRADIENT(Range);
} // namespace dragon
......@@ -47,12 +47,12 @@ void ReduceMaxOp<Context>::DoRunWithType() {
template <class Context>
void ReduceMaxOp<Context>::RunOnDevice() {
DispatchHelper<MathTensorTypes>::Call(this, Input(0));
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(ReduceMax);
DEPLOY_CPU_OPERATOR(ReduceMax);
#ifdef USE_CUDA
DEPLOY_CUDA(ReduceMax);
DEPLOY_CUDA_OPERATOR(ReduceMax);
#endif
OPERATOR_SCHEMA(ReduceMax)
......
......@@ -55,7 +55,7 @@ void ReduceMeanOp<Context>::DoRunWithType() {
template <class Context>
void ReduceMeanOp<Context>::RunOnDevice() {
STORE_INPUT_SPEC(0);
DispatchHelper<MathTensorTypes>::Call(this, Input(0));
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -86,14 +86,14 @@ void ReduceMeanGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(ReduceMean);
DEPLOY_CPU_OPERATOR(ReduceMean);
#ifdef USE_CUDA
DEPLOY_CUDA(ReduceMean);
DEPLOY_CUDA_OPERATOR(ReduceMean);
#endif
DEPLOY_CPU(ReduceMeanGradient);
DEPLOY_CPU_OPERATOR(ReduceMeanGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(ReduceMeanGradient);
DEPLOY_CUDA_OPERATOR(ReduceMeanGradient);
#endif
OPERATOR_SCHEMA(ReduceMean)
......
......@@ -47,12 +47,12 @@ void ReduceMinOp<Context>::DoRunWithType() {
template <class Context>
void ReduceMinOp<Context>::RunOnDevice() {
DispatchHelper<MathTensorTypes>::Call(this, Input(0));
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(ReduceMin);
DEPLOY_CPU_OPERATOR(ReduceMin);
#ifdef USE_CUDA
DEPLOY_CUDA(ReduceMin);
DEPLOY_CUDA_OPERATOR(ReduceMin);
#endif
OPERATOR_SCHEMA(ReduceMin)
......
......@@ -17,24 +17,24 @@
namespace dragon {
#define DECLARE_REDUCE_OP(name) \
template <class Context> \
class name##Op final : public Operator<Context> { \
public: \
name##Op(const OperatorDef& def, Workspace* ws) \
: Operator<Context>(def, ws), \
axes_(OpArgs<int64_t>("axes")), \
keep_dims_(OpArg<int64_t>("keep_dims", 0)) {} \
USE_OPERATOR_FUNCTIONS; \
\
void RunOnDevice() override; \
\
template <typename T> \
void DoRunWithType(); \
\
protected: \
int64_t keep_dims_; \
vec64_t axes_; \
#define DECLARE_REDUCE_OP(name) \
template <class Context> \
class name##Op final : public Operator<Context> { \
public: \
name##Op(const OperatorDef& def, Workspace* ws) \
: Operator<Context>(def, ws), \
axes_(OP_REPEATED_ARG(int64_t, "axes")), \
keep_dims_(OP_SINGLE_ARG(int64_t, "keep_dims", 0)) {} \
USE_OPERATOR_FUNCTIONS; \
\
void RunOnDevice() override; \
\
template <typename T> \
void DoRunWithType(); \
\
protected: \
int64_t keep_dims_; \
vec64_t axes_; \
};
#define DECLARE_REDUCE_GRAD_OP(name) \
......
......@@ -54,7 +54,7 @@ void ReduceSumOp<Context>::DoRunWithType() {
template <class Context>
void ReduceSumOp<Context>::RunOnDevice() {
DispatchHelper<MathTensorTypes>::Call(this, Input(0));
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -85,14 +85,14 @@ void ReduceSumGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(ReduceSum);
DEPLOY_CPU_OPERATOR(ReduceSum);
#ifdef USE_CUDA
DEPLOY_CUDA(ReduceSum);
DEPLOY_CUDA_OPERATOR(ReduceSum);
#endif
DEPLOY_CPU(ReduceSumGradient);
DEPLOY_CPU_OPERATOR(ReduceSumGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(ReduceSumGradient);
DEPLOY_CUDA_OPERATOR(ReduceSumGradient);
#endif
OPERATOR_SCHEMA(ReduceSum)
......
......@@ -43,7 +43,7 @@ void RepeatOp<Context>::DoRunWithType() {
template <class Context>
void RepeatOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -79,14 +79,14 @@ void RepeatGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Repeat);
DEPLOY_CPU_OPERATOR(Repeat);
#ifdef USE_CUDA
DEPLOY_CUDA(Repeat);
DEPLOY_CUDA_OPERATOR(Repeat);
#endif
DEPLOY_CPU(RepeatGradient);
DEPLOY_CPU_OPERATOR(RepeatGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(RepeatGradient);
DEPLOY_CUDA_OPERATOR(RepeatGradient);
#endif
OPERATOR_SCHEMA(Repeat)
......
......@@ -21,7 +21,7 @@ template <class Context>
class RepeatOp final : public Operator<Context> {
public:
RepeatOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
GET_ARG_WITH_DESC(int64_t, repeats, 1);
INIT_OP_SINGLE_ARG_WITH_DESC(int64_t, repeats, 1);
}
USE_OPERATOR_FUNCTIONS;
......@@ -31,7 +31,7 @@ class RepeatOp final : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARG_WITH_DESC(int64_t, repeats);
DECLARE_OP_SINGLE_ARG_WITH_DESC(int64_t, repeats);
};
template <class Context>
......@@ -39,7 +39,7 @@ class RepeatGradientOp final : public Operator<Context> {
public:
RepeatGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARG_WITH_DESC(int64_t, repeats, 1);
INIT_OP_SINGLE_ARG_WITH_DESC(int64_t, repeats, 1);
}
USE_OPERATOR_FUNCTIONS;
......@@ -49,11 +49,11 @@ class RepeatGradientOp final : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARG_WITH_DESC(int64_t, repeats);
DECLARE_OP_SINGLE_ARG_WITH_DESC(int64_t, repeats);
};
DEFINE_ARG_WITH_DESC(int64_t, RepeatOp, repeats);
DEFINE_ARG_WITH_DESC(int64_t, RepeatGradientOp, repeats);
DEFINE_OP_SINGLE_ARG_WITH_DESC(int64_t, RepeatOp, repeats);
DEFINE_OP_SINGLE_ARG_WITH_DESC(int64_t, RepeatGradientOp, repeats);
} // namespace dragon
......
......@@ -53,14 +53,14 @@ void ReshapeOp<Context>::RunOnDevice() {
Y->Reshape(out_shape)->CopyFrom(X, ctx());
}
DEPLOY_CPU(Reshape);
DEPLOY_CPU_OPERATOR(Reshape);
#ifdef USE_CUDA
DEPLOY_CUDA(Reshape);
DEPLOY_CUDA_OPERATOR(Reshape);
#endif
DEPLOY_CPU(ReshapeGradient);
DEPLOY_CPU_OPERATOR(ReshapeGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(ReshapeGradient);
DEPLOY_CUDA_OPERATOR(ReshapeGradient);
#endif
OPERATOR_SCHEMA(Reshape)
......
......@@ -35,14 +35,14 @@ class ReshapeOp final : public Operator<Context> {
public:
ReshapeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, dims);
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, dims);
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
protected:
DECLARE_ARGS_WITH_DESC(int64_t, dims);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, dims);
};
template <class Context>
......@@ -50,8 +50,8 @@ class FlattenOp final : public Operator<Context> {
public:
FlattenOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
num_axes_(OpArg<int64_t>("num_axes", -1)),
keep_axes_(OpArg<int64_t>("keep_axes", INT_MAX)) {}
num_axes_(OP_SINGLE_ARG(int64_t, "num_axes", -1)),
keep_axes_(OP_SINGLE_ARG(int64_t, "keep_axes", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -64,7 +64,7 @@ template <class Context>
class ExpandDimsOp final : public Operator<Context> {
public:
ExpandDimsOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), axes_(OpArgs<int64_t>("axes")) {}
: Operator<Context>(def, ws), axes_(OP_REPEATED_ARG(int64_t, "axes")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -77,7 +77,7 @@ template <class Context>
class SqueezeOp final : public Operator<Context> {
public:
SqueezeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), axes_(OpArgs<int64_t>("axes")) {}
: Operator<Context>(def, ws), axes_(OP_REPEATED_ARG(int64_t, "axes")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -100,7 +100,7 @@ DEFINE_GRADIENT_OP(ExpandDims);
DEFINE_GRADIENT_OP(Squeeze);
#undef DEFINE_GRADIENT_OP
DEFINE_ARGS_WITH_DESC(int64_t, ReshapeOp, dims);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, ReshapeOp, dims);
} // namespace dragon
......
......@@ -7,9 +7,9 @@ void ShapeOp<Context>::RunOnDevice() {
Output(0)->template CopyFrom<int64_t>(Input(0).dims());
}
DEPLOY_CPU(Shape);
DEPLOY_CPU_OPERATOR(Shape);
#ifdef USE_CUDA
DEPLOY_CUDA(Shape);
DEPLOY_CUDA_OPERATOR(Shape);
#endif
OPERATOR_SCHEMA(Shape)
......
......@@ -66,7 +66,7 @@ void SliceOp<Context>::DoRunWithType() {
template <class Context>
void SliceOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -102,17 +102,17 @@ void SliceGradientOp<Context>::DoRunWithType() {
template <class Context>
void SliceGradientOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Slice);
DEPLOY_CPU_OPERATOR(Slice);
#ifdef USE_CUDA
DEPLOY_CUDA(Slice);
DEPLOY_CUDA_OPERATOR(Slice);
#endif
DEPLOY_CPU(SliceGradient);
DEPLOY_CPU_OPERATOR(SliceGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(SliceGradient);
DEPLOY_CUDA_OPERATOR(SliceGradient);
#endif
OPERATOR_SCHEMA(Slice)
......
......@@ -21,8 +21,8 @@ template <class Context>
class SliceOp final : public Operator<Context> {
public:
SliceOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, starts);
GET_ARGS_WITH_DESC(int64_t, sizes);
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, starts);
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, sizes);
}
USE_OPERATOR_FUNCTIONS;
......@@ -32,8 +32,8 @@ class SliceOp final : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARGS_WITH_DESC(int64_t, starts);
DECLARE_ARGS_WITH_DESC(int64_t, sizes);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, starts);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, sizes);
};
template <class Context>
......@@ -48,8 +48,8 @@ class SliceGradientOp final : public Operator<Context> {
void DoRunWithType();
};
DEFINE_ARGS_WITH_DESC(int64_t, SliceOp, starts);
DEFINE_ARGS_WITH_DESC(int64_t, SliceOp, sizes);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, SliceOp, starts);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, SliceOp, sizes);
} // namespace dragon
......
......@@ -6,8 +6,8 @@
namespace dragon {
#define DETERMINE_RUNTIME_ARGS(tensor) \
auto size_splits = OpArgs<int64_t>("size_splits"); \
auto slice_points = OpArgs<int64_t>("slice_points"); \
auto size_splits = OP_REPEATED_ARG(int64_t, "size_splits"); \
auto slice_points = OP_REPEATED_ARG(int64_t, "slice_points"); \
if (!slice_points.empty()) { \
int64_t index = 0; \
size_splits = vec64_t(num_splits); \
......@@ -60,7 +60,7 @@ void SplitOp<Context>::DoRunWithType() {
template <class Context>
void SplitOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -109,14 +109,14 @@ void SplitGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, X);
}
DEPLOY_CPU(Split);
DEPLOY_CPU_OPERATOR(Split);
#ifdef USE_CUDA
DEPLOY_CUDA(Split);
DEPLOY_CUDA_OPERATOR(Split);
#endif
DEPLOY_CPU(SplitGradient);
DEPLOY_CPU_OPERATOR(SplitGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(SplitGradient);
DEPLOY_CUDA_OPERATOR(SplitGradient);
#endif
OPERATOR_SCHEMA(Split)
......
......@@ -29,14 +29,14 @@ void SqueezeOp<Context>::RunOnDevice() {
Y->Reshape(out_shape)->CopyFrom(X, ctx());
}
DEPLOY_CPU(Squeeze);
DEPLOY_CPU_OPERATOR(Squeeze);
#ifdef USE_CUDA
DEPLOY_CUDA(Squeeze);
DEPLOY_CUDA_OPERATOR(Squeeze);
#endif
DEPLOY_CPU(SqueezeGradient);
DEPLOY_CPU_OPERATOR(SqueezeGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(SqueezeGradient);
DEPLOY_CUDA_OPERATOR(SqueezeGradient);
#endif
OPERATOR_SCHEMA(Squeeze)
......
......@@ -43,7 +43,7 @@ void StackOp<Context>::DoRunWithType() {
template <class Context>
void StackOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -71,17 +71,17 @@ void StackGradientOp<Context>::DoRunWithType() {
template <class Context>
void StackGradientOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Stack);
DEPLOY_CPU_OPERATOR(Stack);
#ifdef USE_CUDA
DEPLOY_CUDA(Stack);
DEPLOY_CUDA_OPERATOR(Stack);
#endif
DEPLOY_CPU(StackGradient);
DEPLOY_CPU_OPERATOR(StackGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(StackGradient);
DEPLOY_CUDA_OPERATOR(StackGradient);
#endif
OPERATOR_SCHEMA(Stack)
......
......@@ -33,7 +33,7 @@ void TileOp<Context>::DoRunWithType() {
template <class Context>
void TileOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -101,14 +101,14 @@ void TileGradientOp<Context>::RunOnDevice() {
}
}
DEPLOY_CPU(Tile);
DEPLOY_CPU_OPERATOR(Tile);
#ifdef USE_CUDA
DEPLOY_CUDA(Tile);
DEPLOY_CUDA_OPERATOR(Tile);
#endif
DEPLOY_CPU(TileGradient);
DEPLOY_CPU_OPERATOR(TileGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(TileGradient);
DEPLOY_CUDA_OPERATOR(TileGradient);
#endif
OPERATOR_SCHEMA(Tile)
......
......@@ -21,7 +21,7 @@ template <class Context>
class TileOp final : public Operator<Context> {
public:
TileOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, repeats);
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, repeats);
}
USE_OPERATOR_FUNCTIONS;
......@@ -31,7 +31,7 @@ class TileOp final : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARGS_WITH_DESC(int64_t, repeats);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, repeats);
};
template <class Context>
......@@ -39,7 +39,7 @@ class TileGradientOp final : public Operator<Context> {
public:
TileGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, repeats);
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, repeats);
}
USE_OPERATOR_FUNCTIONS;
......@@ -51,11 +51,11 @@ class TileGradientOp final : public Operator<Context> {
protected:
Tensor *dest_, *src_, nav_;
int64_t axis_, repeat_;
DECLARE_ARGS_WITH_DESC(int64_t, repeats);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, repeats);
};
DEFINE_ARGS_WITH_DESC(int64_t, TileOp, repeats);
DEFINE_ARGS_WITH_DESC(int64_t, TileGradientOp, repeats);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, TileOp, repeats);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, TileGradientOp, repeats);
} // namespace dragon
......
......@@ -31,12 +31,12 @@ void TopKOp<Context>::DoRunWithType() {
template <class Context>
void TopKOp<Context>::RunOnDevice() {
DispatchHelper<MathTensorTypes>::Call(this, Input(0));
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(TopK);
DEPLOY_CPU_OPERATOR(TopK);
#ifdef USE_CUDA
DEPLOY_CUDA(TopK);
DEPLOY_CUDA_OPERATOR(TopK);
#endif
OPERATOR_SCHEMA(TopK)
......
......@@ -22,8 +22,8 @@ class TopKOp final : public Operator<Context> {
public:
TopKOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
k_(OpArg<int64_t>("k", 1)),
largest_(OpArg<int64_t>("largest", 1)) {}
k_(OP_SINGLE_ARG(int64_t, "k", 1)),
largest_(OP_SINGLE_ARG(int64_t, "largest", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -39,7 +39,7 @@ void TransposeOp<Context>::DoRunWithType() {
template <class Context>
void TransposeOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -66,14 +66,14 @@ void TransposeGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Transpose);
DEPLOY_CPU_OPERATOR(Transpose);
#ifdef USE_CUDA
DEPLOY_CUDA(Transpose);
DEPLOY_CUDA_OPERATOR(Transpose);
#endif
DEPLOY_CPU(TransposeGradient);
DEPLOY_CPU_OPERATOR(TransposeGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(TransposeGradient);
DEPLOY_CUDA_OPERATOR(TransposeGradient);
#endif
OPERATOR_SCHEMA(Transpose)
......
......@@ -22,7 +22,7 @@ class TransposeOp final : public Operator<Context> {
public:
TransposeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, perm);
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, perm);
}
USE_OPERATOR_FUNCTIONS;
......@@ -32,7 +32,7 @@ class TransposeOp final : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARGS_WITH_DESC(int64_t, perm);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, perm);
};
template <class Context>
......@@ -47,7 +47,7 @@ class TransposeGradientOp : public Operator<Context> {
void DoRunWithType();
};
DEFINE_ARGS_WITH_DESC(int64_t, TransposeOp, perm);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, TransposeOp, perm);
} // namespace dragon
......
......@@ -10,7 +10,7 @@ void WhereOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1);
auto &C = Input(2), *Y = Output(0);
CHECK(XIsType(C, bool) || XIsType(C, uint8_t))
CHECK(C.template IsType<bool>() || C.template IsType<uint8_t>())
<< "\nExcepted bool or uint8 condition tensor.";
vec64_t AB_dims, Y_dims;
......@@ -36,7 +36,7 @@ void WhereOp<Context>::DoRunWithType() {
template <class Context>
void WhereOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
template <class Context>
......@@ -45,7 +45,7 @@ void WhereGradientOp<Context>::DoRunWithType() {
auto &A = Input(0), &B = Input(1), &C = Input(2), &dY = Input(3);
auto *dA = Output(0), *dB = Output(1);
CHECK(XIsType(C, bool) || XIsType(C, uint8_t))
CHECK(C.template IsType<bool>() || C.template IsType<uint8_t>())
<< "\nExcepted bool or uint8 condition tensor.";
vec32_t A_broadcast_axes, B_broadcast_axes;
......@@ -155,14 +155,14 @@ void WhereGradientOp<Context>::RunOnDevice() {
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Where);
DEPLOY_CPU_OPERATOR(Where);
#ifdef USE_CUDA
DEPLOY_CUDA(Where);
DEPLOY_CUDA_OPERATOR(Where);
#endif
DEPLOY_CPU(WhereGradient);
DEPLOY_CPU_OPERATOR(WhereGradient);
#ifdef USE_CUDA
DEPLOY_CUDA(WhereGradient);
DEPLOY_CUDA_OPERATOR(WhereGradient);
#endif
OPERATOR_SCHEMA(Where)
......
......@@ -79,12 +79,12 @@ void AssignOp<Context>::DoRunWithType() {
template <class Context>
void AssignOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Assign);
DEPLOY_CPU_OPERATOR(Assign);
#ifdef USE_CUDA
DEPLOY_CUDA(Assign);
DEPLOY_CUDA_OPERATOR(Assign);
#endif
OPERATOR_SCHEMA(Assign)
......
......@@ -21,8 +21,8 @@ template <class Context>
class AssignOp final : public Operator<Context> {
public:
AssignOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
GET_ARGS_WITH_DESC(int64_t, starts);
GET_ARGS_WITH_DESC(int64_t, sizes);
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, starts);
INIT_OP_REPEATED_ARG_WITH_DESC(int64_t, sizes);
}
USE_OPERATOR_FUNCTIONS;
......@@ -32,8 +32,8 @@ class AssignOp final : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_ARGS_WITH_DESC(int64_t, starts);
DECLARE_ARGS_WITH_DESC(int64_t, sizes);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, starts);
DECLARE_OP_REPEATED_ARG_WITH_DESC(int64_t, sizes);
};
template <class Context>
......@@ -48,8 +48,8 @@ class MaskedAssignOp final : public Operator<Context> {
void DoRunWithType();
};
DEFINE_ARGS_WITH_DESC(int64_t, AssignOp, starts);
DEFINE_ARGS_WITH_DESC(int64_t, AssignOp, sizes);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, AssignOp, starts);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, AssignOp, sizes);
} // namespace dragon
......
......@@ -16,12 +16,12 @@ void CopyOp<Context>::DoRunWithType() {
template <class Context>
void CopyOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(Copy);
DEPLOY_CPU_OPERATOR(Copy);
#ifdef USE_CUDA
DEPLOY_CUDA(Copy);
DEPLOY_CUDA_OPERATOR(Copy);
#endif
OPERATOR_SCHEMA(Copy)
......
......@@ -10,7 +10,7 @@ template <typename T>
void MaskedAssignOp<Context>::DoRunWithType() {
auto &X = Input(0), &X_mask = Input(1), *Y = Output(0);
CHECK(XIsType(X_mask, bool) || XIsType(X_mask, uint8_t))
CHECK(X_mask.template IsType<bool>() || X_mask.template IsType<uint8_t>())
<< "\nExcepted bool or uint8 mask.";
vec64_t X_dims, Y_dims;
......@@ -37,12 +37,12 @@ void MaskedAssignOp<Context>::DoRunWithType() {
template <class Context>
void MaskedAssignOp<Context>::RunOnDevice() {
DispatchHelper<AllTensorTypes>::Call(this, Input(0));
DispatchHelper<FullTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU(MaskedAssign);
DEPLOY_CPU_OPERATOR(MaskedAssign);
#ifdef USE_CUDA
DEPLOY_CUDA(MaskedAssign);
DEPLOY_CUDA_OPERATOR(MaskedAssign);
#endif
OPERATOR_SCHEMA(MaskedAssign)
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!