Commit e90a8f1a by Ting PAN

Reformat the code style

1 parent a739c49b
Showing with 1392 additions and 1327 deletions
...@@ -17,6 +17,7 @@ List Brief ...@@ -17,6 +17,7 @@ List Brief
`Tensor.dtype`_ Return or Set the data type. `Tensor.dtype`_ Return or Set the data type.
`Tensor.set_value`_ Feed the values to C++ backend. `Tensor.set_value`_ Feed the values to C++ backend.
`Tensor.get_value`_ Fetch the values from C++ backend. `Tensor.get_value`_ Fetch the values from C++ backend.
`Tensor.convert_to`_ Converts the given value to a Tensor.
`Tensor.copy`_ Return a Tensor with same content. `Tensor.copy`_ Return a Tensor with same content.
`Tensor.reshape`_ Reshape the dimensions of input. `Tensor.reshape`_ Reshape the dimensions of input.
`Tensor.dimshuffle`_ Shuffle the dimensions. `Tensor.dimshuffle`_ Shuffle the dimensions.
...@@ -131,6 +132,7 @@ API Reference ...@@ -131,6 +132,7 @@ API Reference
.. _Tensor.dtype: #dragon.core.tensor.Tensor.dtype .. _Tensor.dtype: #dragon.core.tensor.Tensor.dtype
.. _Tensor.set_value: #dragon.core.tensor.Tensor.set_value .. _Tensor.set_value: #dragon.core.tensor.Tensor.set_value
.. _Tensor.get_value: #dragon.core.tensor.Tensor.get_value .. _Tensor.get_value: #dragon.core.tensor.Tensor.get_value
.. _Tensor.convert_to: #dragon.core.tensor.Tensor.convert_to
.. _Tensor.copy: #dragon.core.tensor.Tensor.copy .. _Tensor.copy: #dragon.core.tensor.Tensor.copy
.. _Tensor.reshape: #dragon.core.tensor.Tensor.reshape .. _Tensor.reshape: #dragon.core.tensor.Tensor.reshape
.. _Tensor.dimshuffle: #dragon.core.tensor.Tensor.dimshuffle .. _Tensor.dimshuffle: #dragon.core.tensor.Tensor.dimshuffle
......
...@@ -70,6 +70,7 @@ List Brief ...@@ -70,6 +70,7 @@ List Brief
`SElu`_ Scaled Exponential Linear Unit function. `[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_. `SElu`_ Scaled Exponential Linear Unit function. `[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`Softmax`_ *Softmax* function. `Softmax`_ *Softmax* function.
`Dropout`_ Randomly set a unit into zero. `[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_. `Dropout`_ Randomly set a unit into zero. `[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
`DropPath`_ Randomly set an example of the batch to zero. `[Larsson et.al, 2016] <https://arxiv.org/abs/1605.07648>`_.
=============== ====================================================================== =============== ======================================================================
Loss Loss
...@@ -134,7 +135,7 @@ Array ...@@ -134,7 +135,7 @@ Array
=============== ====================================================================== =============== ======================================================================
List Brief List Brief
=============== ====================================================================== =============== ======================================================================
`Gather`_ Gather the input according to the indices along the given axis. `IndexSelect`_ Select the elements according to the indices along the given axis.
`Reduce`_ Reduce the inputs along the axis in given axes. `Reduce`_ Reduce the inputs along the axis in given axes.
`Sum`_ Compute the sum along the given axis. `Sum`_ Compute the sum along the given axis.
`Mean`_ Compute the mean along the given axis. `Mean`_ Compute the mean along the given axis.
...@@ -241,6 +242,7 @@ List Brief ...@@ -241,6 +242,7 @@ List Brief
.. _SElu: operators/activation.html#dragon.operators.activation.SElu .. _SElu: operators/activation.html#dragon.operators.activation.SElu
.. _Softmax: operators/activation.html#dragon.operators.activation.Softmax .. _Softmax: operators/activation.html#dragon.operators.activation.Softmax
.. _Dropout: operators/activation.html#dragon.operators.activation.Dropout .. _Dropout: operators/activation.html#dragon.operators.activation.Dropout
.. _DropPath: operators/activation.html#dragon.operators.activation.DropPath
.. _NLLLoss: operators/loss.html#dragon.operators.loss.NLLLoss .. _NLLLoss: operators/loss.html#dragon.operators.loss.NLLLoss
.. _SparseSoftmaxCrossEntropy: operators/loss.html#dragon.operators.loss.SparseSoftmaxCrossEntropy .. _SparseSoftmaxCrossEntropy: operators/loss.html#dragon.operators.loss.SparseSoftmaxCrossEntropy
...@@ -281,7 +283,7 @@ List Brief ...@@ -281,7 +283,7 @@ List Brief
.. _InstanceNorm: operators/norm.html#dragon.operators.norm.InstanceNorm .. _InstanceNorm: operators/norm.html#dragon.operators.norm.InstanceNorm
.. _L2Norm: operators/norm.html#dragon.operators.norm.L2Norm .. _L2Norm: operators/norm.html#dragon.operators.norm.L2Norm
.. _Gather: operators/array.html#dragon.operators.array.Gather .. _IndexSelect: operators/array.html#dragon.operators.array.IndexSelect
.. _Crop: operators/array.html#dragon.operators.array.Crop .. _Crop: operators/array.html#dragon.operators.array.Crop
.. _Reduce: operators/array.html#dragon.operators.array.Reduce .. _Reduce: operators/array.html#dragon.operators.array.Reduce
.. _Sum: operators/array.html#dragon.operators.array.Sum .. _Sum: operators/array.html#dragon.operators.array.Sum
......
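The renamed `IndexSelect`_ entry above replaces the old `Gather`_ operator. As a rough illustration of the semantics described in the table (illustrative code only, not the Dragon kernel, and shown only for axis 0 of a 2-D array):

    #include <cstdint>
    #include <vector>

    // Pick whole rows of a (rows x cols) array by index: the axis-0 case of
    // "select the elements according to the indices along the given axis".
    std::vector<float> IndexSelectRows(
        const std::vector<float>& x, int64_t cols,
        const std::vector<int64_t>& indices) {
      std::vector<float> y(indices.size() * cols);
      for (size_t i = 0; i < indices.size(); ++i)
        for (int64_t j = 0; j < cols; ++j)
          y[i * cols + j] = x[indices[i] * cols + j];
      return y;
    }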
...@@ -74,13 +74,6 @@ which will enhance all frameworks in the VirtualBox. ...@@ -74,13 +74,6 @@ which will enhance all frameworks in the VirtualBox.
|para| We remove the mechanism of `SharedVaraible`_ due to the **memory-storage** is taken by the backend. |para| We remove the mechanism of `SharedVaraible`_ due to the **memory-storage** is taken by the backend.
Following the `Caffe2`_ and `TensorFlow`_, we attribute it to the **Feed** of data streams. Following the `Caffe2`_ and `TensorFlow`_, we attribute it to the **Feed** of data streams.
|sectitle| □ |nbsp| `Scan`_
|para| We use this primitive to create the dynamic computation graphs.
|para| By taking a template of the sub-graph, `Scan`_ unfolds it for specific loop steps,
which is very useful to model sentence-level **Recurrent Neural Networks**.
|context| For detailed Documentation, see: `Compile`_. |context| For detailed Documentation, see: `Compile`_.
|paratitle| **Tensor** |paratitle| **Tensor**
...@@ -148,7 +141,6 @@ We are sorry for removing some odd implementations supported by the original `Th ...@@ -148,7 +141,6 @@ We are sorry for removing some odd implementations supported by the original `Th
.. _Function: theano/compile.html#dragon.vm.theano.compile.function.function .. _Function: theano/compile.html#dragon.vm.theano.compile.function.function
.. _Shared: theano/compile.html#dragon.vm.theano.compile.sharedvalue.shared .. _Shared: theano/compile.html#dragon.vm.theano.compile.sharedvalue.shared
.. _Scan: theano/compile.html#dragon.vm.theano.compile.scan.scan
.. _FeedTensor: ../core/workspace.html#dragon.core.workspace.FeedTensor .. _FeedTensor: ../core/workspace.html#dragon.core.workspace.FeedTensor
.. _SharedVaraible: http://deeplearning.net/software/theano/library/compile/shared.html .. _SharedVaraible: http://deeplearning.net/software/theano/library/compile/shared.html
......
...@@ -14,7 +14,6 @@ List Brief ...@@ -14,7 +14,6 @@ List Brief
============================== ======================================================================= ============================== =======================================================================
`function`_ Return a callable function that will compute outputs. `function`_ Return a callable function that will compute outputs.
`shared`_ Construct a Tensor initialized with numerical values. `shared`_ Construct a Tensor initialized with numerical values.
`scan`_ Run a dynamic loop of the given one step function.
============================== ======================================================================= ============================== =======================================================================
...@@ -27,9 +26,6 @@ API Reference ...@@ -27,9 +26,6 @@ API Reference
.. automodule:: dragon.vm.theano.compile.sharedvalue .. automodule:: dragon.vm.theano.compile.sharedvalue
:members: :members:
.. automodule:: dragon.vm.theano.compile.scan
:members:
.. _config.SetDebugMode(*args, **kwargs): ../../config.html#dragon.config.SetDebugMode .. _config.SetDebugMode(*args, **kwargs): ../../config.html#dragon.config.SetDebugMode
.. _memonger.share_grads(*args, **kwargs): ../../memonger.html#dragon.memonger.share_grads .. _memonger.share_grads(*args, **kwargs): ../../memonger.html#dragon.memonger.share_grads
.. _config.EnableCPU(): ../../config.html#dragon.config.EnableCPU .. _config.EnableCPU(): ../../config.html#dragon.config.EnableCPU
...@@ -38,7 +34,4 @@ API Reference ...@@ -38,7 +34,4 @@ API Reference
.. _T.grad(*args, **kwargs): tensor.html#dragon.vm.theano.gradient.grad .. _T.grad(*args, **kwargs): tensor.html#dragon.vm.theano.gradient.grad
.. _function: #dragon.vm.theano.compile.function.function .. _function: #dragon.vm.theano.compile.function.function
.. _shared: #dragon.vm.theano.compile.sharedvalue.shared .. _shared: #dragon.vm.theano.compile.sharedvalue.shared
.. _scan: #dragon.vm.theano.compile.scan.scan \ No newline at end of file
...@@ -45,13 +45,9 @@ class CPUContext { ...@@ -45,13 +45,9 @@ class CPUContext {
/*! \brief Malloc the memory */ /*! \brief Malloc the memory */
static void* New(size_t nbytes) { static void* New(size_t nbytes) {
void* data; void* data = malloc(nbytes);
#ifdef WITH_CUDA_HOST_MEM CHECK(data) << "\nMalloc mem: "
CUDA_CHECK(cudaMallocHost(&data, nbytes)); << nbytes << " bytes failed.";
#else
data = malloc(nbytes);
#endif
CHECK(data) << "\nMalloc mem: " << nbytes << " bytes failed.";
return data; return data;
} }
......
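After this change CPUContext::New is a plain malloc guarded by a CHECK; the optional cudaMallocHost pinned-memory branch is gone. A minimal usage sketch (the matching deallocation path is not part of this hunk, so free() is used here only because New() now wraps malloc directly):

    void* buf = CPUContext::New(4096);  // CHECK-aborts if malloc(4096) returns null
    // ... fill and use the 4096-byte host buffer ...
    free(buf);                          // valid only because New() is a plain malloc now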
...@@ -111,7 +111,7 @@ class CNMLContext { ...@@ -111,7 +111,7 @@ class CNMLContext {
static std::mutex& mutex() { static std::mutex m; return m; } static std::mutex& mutex() { static std::mutex m; return m; }
/*! \brief Return the thread local cnrt object */ /*! \brief Return the thread local cnrt object */
static CNRTObject* cnrt_object(); static CNRTObject* obj();
private: private:
int device_id_, stream_id_ = 1, random_seed_; int device_id_, stream_id_ = 1, random_seed_;
......
...@@ -41,7 +41,7 @@ class GraphOptimizer { ...@@ -41,7 +41,7 @@ class GraphOptimizer {
/*! \brief Plan the recomputing for inputs (-O3) */ /*! \brief Plan the recomputing for inputs (-O3) */
GraphDef MirrorStage( GraphDef MirrorStage(
const GraphDef& input_def, const GraphDef& input_def,
Map< string, vector<int> >& op_indices); Map< string, vec32_t >& op_indices);
/*! \brief Allocate the buffer for outputs (-O3) */ /*! \brief Allocate the buffer for outputs (-O3) */
GraphDef SimulateGC(const GraphDef& input_def); GraphDef SimulateGC(const GraphDef& input_def);
......
...@@ -37,16 +37,16 @@ class OperatorBase { ...@@ -37,16 +37,16 @@ class OperatorBase {
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
/*! \brief Return the specified input tensor */ /*! \brief Return the specified input tensor */
Tensor& Input(int idx); Tensor& X(int i);
/*! \brief Return the specified output tensor */ /*! \brief Return the specified output tensor */
Tensor* Output(int idx); Tensor* Y(int i);
/*! \brief Return the number of inputs */ /*! \brief Return the number of inputs */
int InputSize() { return (int)inputs_.size(); } int XSize() { return (int)inputs_.size(); }
/*! \brief Return the number of outputs */ /*! \brief Return the number of outputs */
int OutputSize() { return (int)outputs_.size(); } int YSize() { return (int)outputs_.size(); }
/*! \brief Modify this operator according to the given def */ /*! \brief Modify this operator according to the given def */
void UpdateFrom(const OperatorDef& def); void UpdateFrom(const OperatorDef& def);
...@@ -72,8 +72,14 @@ class OperatorBase { ...@@ -72,8 +72,14 @@ class OperatorBase {
/*! \brief Return the anchor name of this operator */ /*! \brief Return the anchor name of this operator */
const string& anchor() const { return anchor_; } const string& anchor() const { return anchor_; }
/*! \brief Return the mount name in this operator */ /*! \brief Return the data type of this operator */
const string mount_name(const string& name) const { const string& dtype() const { return dtype_; }
/*! \brief Return the data format of this operator */
const string& data_format() const { return data_format_; }
/*! \brief Return the unique name in this operator */
const string unique_name(const string& name) const {
return "/mnt/" + anchor_ + "/" + name; return "/mnt/" + anchor_ + "/" + name;
} }
...@@ -110,23 +116,24 @@ class OperatorBase { ...@@ -110,23 +116,24 @@ class OperatorBase {
/*! \brief Return the debug string of the stored operator def */ /*! \brief Return the debug string of the stored operator def */
string DebugString() const { return def_.DebugString(); } string DebugString() const { return def_.DebugString(); }
/*! \brief Return the debug DType string on given tensor */ /*! \brief Return the dtype string according to given tensor */
string DTypeHelper( string DTypeString(
const Tensor& tensor, const Tensor& tensor,
const Set<string>& dtypes) const; const Set<string>& dtypes) const;
/* \brief Return the debug DType string on given type */ /* \brief Return the dtype string according to given type */
string DTypeHelper( string DTypeString(
const string& dtype, const string& dtype,
const Set<string>& dtypes) const; const Set<string>& dtypes) const;
protected: protected:
string phase_, anchor_; Workspace* ws_;
Map<std::string, const Argument*> args_; OperatorDef def_;
SubGraph subgraph_; SubGraph subgraph_;
string phase_, anchor_;
string dtype_, data_format_;
vector<Tensor*> inputs_, outputs_; vector<Tensor*> inputs_, outputs_;
OperatorDef def_; Map<string, const Argument*> args_;
Workspace* ws_;
}; };
template <class Context> template <class Context>
...@@ -134,29 +141,30 @@ class Operator : public OperatorBase { ...@@ -134,29 +141,30 @@ class Operator : public OperatorBase {
public: public:
/*! \brief Default constructor */ /*! \brief Default constructor */
Operator(const OperatorDef& def, Workspace* ws) Operator(const OperatorDef& def, Workspace* ws)
: OperatorBase(def, ws), ctx_(def.device_option()), : OperatorBase(def, ws),
allow_recomputing_(OperatorBase::Arg<bool>( ctx_(def.device_option()),
"allow_recomputing", false)),
do_sync_(OperatorBase::Arg<bool>( do_sync_(OperatorBase::Arg<bool>(
"do_sync", false)) { "do_sync", false)),
allow_recomp_(OperatorBase::Arg<bool>(
"allow_recomp", false)) {
allow_run_ = true; allow_run_ = true;
allow_run_ &= MPICheck(); allow_run_ &= MPICheck();
allow_run_ &= (!(OutputSize() == 1 && allow_run_ &= (!(YSize() == 1 &&
Output(0)->name() == "NULL")); Y(0)->name() == "NULL"));
} }
/*! \brief Run this operator on the specified stream */ /*! \brief Run this operator on the specified stream */
void Run(int stream_id = 0) final { void Run(int stream_id = 0) final {
if (!allow_run_) return; if (!allow_run_) return;
if (allow_recomputing_) PrepareResource(); if (allow_recomp_) PrepareResource();
ctx()->SwitchToDevice(stream_id); ctx()->SwitchToDevice(stream_id);
MemorySwitch(); MemorySwitch();
RunOnDevice(); RunOnDevice();
if (do_sync_ || stream_id > 0) { if (do_sync_ || stream_id > 0) {
// We will sync the stream 0 at the specific time // Sync the stream(0) at the specific time
ctx()->FinishDeviceCompution(); ctx()->FinishDeviceCompution();
} }
if (allow_recomputing_) ReleaseResource(); if (allow_recomp_) ReleaseResource();
} }
/*! \brief Prepare the content of inputs */ /*! \brief Prepare the content of inputs */
...@@ -187,7 +195,7 @@ class Operator : public OperatorBase { ...@@ -187,7 +195,7 @@ class Operator : public OperatorBase {
protected: protected:
/*! \brief Store the internal context */ /*! \brief Store the internal context */
Context ctx_; Context ctx_;
bool allow_run_, allow_recomputing_, do_sync_; bool allow_run_, allow_recomp_, do_sync_;
private: private:
/*! \brief Check the MPI conditions */ /*! \brief Check the MPI conditions */
...@@ -195,7 +203,7 @@ class Operator : public OperatorBase { ...@@ -195,7 +203,7 @@ class Operator : public OperatorBase {
#ifndef WITH_MPI #ifndef WITH_MPI
return true; return true;
#else #else
vector<int> allow_ranks = vec32_t allow_ranks =
OperatorBase::Args<int>("mpi_ranks"); OperatorBase::Args<int>("mpi_ranks");
if (allow_ranks.empty()) return true; if (allow_ranks.empty()) return true;
int cur_rank; int cur_rank;
...@@ -215,25 +223,30 @@ OperatorBase* NewOperator( ...@@ -215,25 +223,30 @@ OperatorBase* NewOperator(
/*! Macros */ /*! Macros */
#define USE_SIMPLE_CTOR_DTOR(name) \ #define OpArg OperatorBase::Arg
#define OpArgs OperatorBase::Args
#define SIMPLE_CTOR_DTOR(name) \
name(const OperatorDef& def, Workspace* ws) \ name(const OperatorDef& def, Workspace* ws) \
: Operator<Context>(def, ws) {} \ : Operator<Context>(def, ws) {} \
virtual ~name() {} virtual ~name() {}
#define USE_OPERATOR_BASE_FUNCTIONS \ #define USE_OPERATOR_BASE_FUNCTIONS \
using OperatorBase::Input; \
using OperatorBase::Output; \
using OperatorBase::ws; \ using OperatorBase::ws; \
using OperatorBase::name; \ using OperatorBase::name; \
using OperatorBase::type; \ using OperatorBase::type; \
using OperatorBase::phase; \ using OperatorBase::phase; \
using OperatorBase::anchor; \ using OperatorBase::anchor; \
using OperatorBase::mount_name; \ using OperatorBase::dtype; \
using OperatorBase::data_format; \
using OperatorBase::unique_name; \
using OperatorBase::def; \ using OperatorBase::def; \
using OperatorBase::InputSize; \ using OperatorBase::X; \
using OperatorBase::OutputSize; \ using OperatorBase::Y; \
using OperatorBase::XSize; \
using OperatorBase::YSize; \
using OperatorBase::DebugString; \ using OperatorBase::DebugString; \
using OperatorBase::DTypeHelper; \ using OperatorBase::DTypeString; \
using OperatorBase::SwitchToPhase using OperatorBase::SwitchToPhase
#define USE_OPERATOR_FUNCTIONS \ #define USE_OPERATOR_FUNCTIONS \
...@@ -322,63 +335,63 @@ DECLARE_REGISTRY( ...@@ -322,63 +335,63 @@ DECLARE_REGISTRY(
name = mp->template data<T, Context>(); \ name = mp->template data<T, Context>(); \
} }
#define DECLARE_ARGUMENT_WITH_DESC(type, argument) \ #define DECLARE_ARG_WITH_DESC(type, arg) \
type argument##_value; \ type arg##_; \
string argument##_desc; \ string arg##_desc_; \
type argument() type arg()
#define DECLARE_ARGUMENTS_WITH_DESC(type, argument) \ #define DECLARE_ARGS_WITH_DESC(type, arg) \
vector<type> argument##_value; \ vector<type> arg##_; \
vector<string> argument##_desc; \ vector<string> arg##_desc_; \
type argument(int idx) type arg(int i)
#define GET_ARGUMENT_WITH_DESC(type, argument, default_value) \ #define GET_ARG_WITH_DESC(type, arg, default_value) \
argument##_value = OperatorBase::Arg<type>(#argument, default_value); \ arg##_ = OpArg<type>(#arg, default_value); \
argument##_desc = OperatorBase::Arg<string>(string(#argument) + "_desc", "") arg##_desc_ = OpArg<string>(string(#arg) + "_desc", "")
#define GET_ARGUMENTS_WITH_DESC(type, argument) \ #define GET_ARGS_WITH_DESC(type, arg) \
argument##_value = OperatorBase::Args<type>(#argument); \ arg##_ = OpArgs<type>(#arg); \
argument##_desc = OperatorBase::Args<string>(string(#argument) + "_desc") arg##_desc_ = OpArgs<string>(string(#arg) + "_desc")
#define DEFINE_ARGUMENT_WITH_DESC(type, classname, argument) \ #define DEFINE_ARG_WITH_DESC(type, classname, arg) \
template <class Context> \ template <class Context> \
type classname<Context>::argument() { \ type classname<Context>::arg() { \
if (argument##_desc.empty()) return argument##_value; \ if (arg##_desc_.empty()) return arg##_; \
Tensor* argument##_tensor = ws()->GetTensor(argument##_desc); \ auto* arg##T = ws()->GetTensor(arg##_desc_); \
CHECK(argument##_tensor->IsType<type>()) \ CHECK(arg##T->template IsType<type>()) \
<< "\nThe type of " << #argument << " should be " << #type << "."; \ << "\nThe type of " << #arg << " should be " << #type << "."; \
CHECK_EQ(argument##_tensor->count(), 1) \ CHECK_EQ(arg##T->count(), 1) \
<< "\nThe argument of " << #argument << " should be a scalar."; \ << "\nThe argument of " << #arg << " should be a scalar."; \
return argument##_tensor->template data<type, CPUContext>()[0]; \ return arg##T->template data<type, CPUContext>()[0]; \
} }
#define DEFINE_ARGUMENTS_WITH_DESC(type, classname, argument) \ #define DEFINE_ARGS_WITH_DESC(type, classname, arg) \
template <class Context> \ template <class Context> \
type classname<Context>::argument(int idx) { \ type classname<Context>::arg(int i) { \
if (argument##_desc.empty()) { \ if (arg##_desc_.empty()) { \
CHECK_LT(idx, argument##_value.size()) \ CHECK_LT(i, arg##_.size()) \
<< "\nExcepted the size of " << #argument \ << "\nExcepted the size of " << #arg \
<< " > " << idx << ". (Got " \ << " > " << i << ". (Got " \
<< argument##_value.size() << ")."; \ << arg##_.size() << ")."; \
return argument##_value[idx]; \ return arg##_[i]; \
} \ } \
CHECK_LT(idx, argument##_desc.size()) \ CHECK_LT(i, arg##_desc_.size()) \
<< "\nExcepted the size of " << #argument \ << "\nExcepted the size of " << #arg \
<< " > " << idx << ". (Got " \ << " > " << i << ". (Got " \
<< argument##_desc.size() << ")."; \ << arg##_desc_.size() << ")."; \
Tensor* argument##_tensor = ws()->GetTensor( \ auto* arg##T = ws()->GetTensor( \
str::replace_first(argument##_desc[idx], \ str::replace_first(arg##_desc_[i], \
"${ANCHOR}", anchor())); \ "${ANCHOR}", anchor())); \
CHECK(argument##_tensor->IsType<type>()) \ CHECK(arg##T->template IsType<type>()) \
<< "\nThe type of " << #argument << " should be " << #type << "."; \ << "\nThe type of " << #arg << " should be " << #type << "."; \
CHECK_EQ(argument##_tensor->count(), 1) \ CHECK_EQ(arg##T->count(), 1) \
<< "\nThe argument of " << #argument << " at pos(" \ << "\nThe argument of " << #arg << " at pos(" \
<< idx << ") should be a scalar."; \ << i << ") should be a scalar."; \
return argument##_tensor->template data<type, CPUContext>()[0]; \ return arg##T->template data<type, CPUContext>()[0]; \
} }
#define GET_ARGUMENTS_SIZE(argument) \ #define GET_ARGS_SIZE(arg) \
(int)std::max(argument##_value.size(), argument##_desc.size()) (int)std::max(arg##_.size(), arg##_desc_.size())
#define XIsType(x, dtype) \ #define XIsType(x, dtype) \
x.template IsType<dtype>() x.template IsType<dtype>()
......
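The renamed *_ARG_WITH_DESC macros above work as a trio: declare the members, read the argument in the constructor, and generate the accessor that falls back to a scalar tensor named by "<arg>_desc". A minimal sketch of that pattern, following the dropout_op.h usage below; MyScaleOp and its alpha argument are hypothetical names for illustration only:

    template <class Context>
    class MyScaleOp final : public Operator<Context> {
     public:
        MyScaleOp(const OperatorDef& def, Workspace* ws)
            : Operator<Context>(def, ws) {
            // Reads "alpha" (literal) and "alpha_desc" (tensor name)
            // into alpha_ / alpha_desc_.
            GET_ARG_WITH_DESC(float, alpha, 1.f);
        }
        USE_OPERATOR_FUNCTIONS;

        void RunOnDevice() override {
            // alpha() returns alpha_ directly, or fetches the scalar tensor
            // named by alpha_desc_ from the workspace when a desc was given.
            float scale = alpha();
            // ... scale X(0) into Y(0) ...
        }

     protected:
        DECLARE_ARG_WITH_DESC(float, alpha);  // declares alpha_, alpha_desc_, alpha()
    };

    DEFINE_ARG_WITH_DESC(float, MyScaleOp, alpha);  // out-of-line body of alpha()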
...@@ -45,7 +45,7 @@ class GradientMakerBase { ...@@ -45,7 +45,7 @@ class GradientMakerBase {
virtual bool CopyArguments() const { return true; } virtual bool CopyArguments() const { return true; }
virtual Gradient Make() { virtual Gradient Make() {
vector<OperatorDef> new_defs = MakeDefs(); vector<OperatorDef> new_defs = MakeDef();
if (def.has_uid()) { if (def.has_uid()) {
// Attach the anchor to the name if having UID // Attach the anchor to the name if having UID
for (int i = 0; i < new_defs.size(); i++) for (int i = 0; i < new_defs.size(); i++)
...@@ -60,8 +60,7 @@ class GradientMakerBase { ...@@ -60,8 +60,7 @@ class GradientMakerBase {
return Gradient(new_defs, g_inputs_, DefaultValues()); return Gradient(new_defs, g_inputs_, DefaultValues());
}; };
virtual vector<OperatorDef> MakeDefs() { virtual vector<OperatorDef> MakeDef() {
NOT_IMPLEMENTED;
return vector<OperatorDef>(); return vector<OperatorDef>();
} }
...@@ -106,18 +105,21 @@ Gradient MakeGradientForOp( ...@@ -106,18 +105,21 @@ Gradient MakeGradientForOp(
const OperatorDef& op_def, const OperatorDef& op_def,
const vector<string>& g_outputs); const vector<string>& g_outputs);
# define GRADIENT_MAKER_CTOR(name) \ #define GRADIENT_MAKER_CTOR(name) \
name(const OperatorDef& def, const vector<string>& g_output) \ name(const OperatorDef& def, \
const vector<string>& g_output) \
: GradientMakerBase(def, g_output) {} : GradientMakerBase(def, g_output) {}
class NoGradient : public GradientMakerBase { class NoGradient : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(NoGradient); GRADIENT_MAKER_CTOR(NoGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDef() override {
return vector<OperatorDef>(); return vector<OperatorDef>();
} }
}; };
namespace {
// Here we define some common gradient makers // Here we define some common gradient makers
// Reuse them to make the codes cleaner // Reuse them to make the codes cleaner
...@@ -131,7 +133,7 @@ class SimpleGradientMaker final : public GradientMakerBase { ...@@ -131,7 +133,7 @@ class SimpleGradientMaker final : public GradientMakerBase {
* *
*/ */
GRADIENT_MAKER_CTOR(SimpleGradientMaker); GRADIENT_MAKER_CTOR(SimpleGradientMaker);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDef() override {
vector<string> inputs, outputs; vector<string> inputs, outputs;
for (const auto& input : def.input()) { for (const auto& input : def.input()) {
inputs.push_back(input); inputs.push_back(input);
...@@ -155,7 +157,7 @@ class InplaceGradientMaker final : public GradientMakerBase { ...@@ -155,7 +157,7 @@ class InplaceGradientMaker final : public GradientMakerBase {
* *
*/ */
GRADIENT_MAKER_CTOR(InplaceGradientMaker); GRADIENT_MAKER_CTOR(InplaceGradientMaker);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDef() override {
return SingleDef( return SingleDef(
def.type() + "Gradient", /*! OpType */ def.type() + "Gradient", /*! OpType */
"", /*! OpName */ "", /*! OpName */
...@@ -164,17 +166,21 @@ class InplaceGradientMaker final : public GradientMakerBase { ...@@ -164,17 +166,21 @@ class InplaceGradientMaker final : public GradientMakerBase {
} }
}; };
} // namespace
DECLARE_REGISTRY( DECLARE_REGISTRY(
GradientRegistry, GradientRegistry,
GradientMakerBase, GradientMakerBase,
const OperatorDef&, const OperatorDef&,
const vector<string>&); const vector<string>&
);
DECLARE_REGISTRY( DECLARE_REGISTRY(
NoGradientRegistry, NoGradientRegistry,
GradientMakerBase, GradientMakerBase,
const OperatorDef&, const OperatorDef&,
const vector<string>&); const vector<string>&
);
// Defined in the operator.cc // Defined in the operator.cc
#define REGISTER_GRADIENT(name, ...) \ #define REGISTER_GRADIENT(name, ...) \
......
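For reference, a hedged sketch of how an operator plugs into this registry with the renamed MakeDef() hook. The SingleDef(...) argument list past OpType/OpName is elided in this hunk, so the input/output name vectors below are placeholders rather than the real helper calls:

    class MyScaleGradientMaker final : public GradientMakerBase {
     public:
        GRADIENT_MAKER_CTOR(MyScaleGradientMaker);
        vector<OperatorDef> MakeDef() override {
            return SingleDef(
                def.type() + "Gradient",  /*! OpType */
                "",                       /*! OpName */
                vector<string>({ def.input(0), "dY" }),  // placeholder inputs
                vector<string>({ "dX" }));               // placeholder outputs
        }
    };

    // Registered the same way as the built-in makers (see REGISTER_GRADIENT above).
    REGISTER_GRADIENT(MyScale, MyScaleGradientMaker);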
...@@ -27,15 +27,15 @@ class Tensor { ...@@ -27,15 +27,15 @@ class Tensor {
Tensor(const string& name) : name_(name) {} Tensor(const string& name) : name_(name) {}
/*! \brief Constructor with the known int64 dimensions */ /*! \brief Constructor with the known int64 dimensions */
Tensor(const vector<int64_t>& dims) { Reshape(dims); } Tensor(const vec64_t& dims) { Reshape(dims); }
/*! \brief Constructor with the known int32 dimensions */ /*! \brief Constructor with the known int32 dimensions */
Tensor(const vector<int>& dims) { Tensor(const vec32_t& dims) {
Reshape(vector<int64_t>(dims.begin(), dims.end())); Reshape(vec64_t(dims.begin(), dims.end()));
} }
/*! \brief Reshape to the given dimensions */ /*! \brief Reshape to the given dimensions */
Tensor* Reshape(const vector<int64_t>& dims) { Tensor* Reshape(const vec64_t& dims) {
dims_ = dims; strides_.resize(dims.size()); dims_ = dims; strides_.resize(dims.size());
size_t new_size = 1; int64_t d; size_t new_size = 1; int64_t d;
for (int i = (int)dims.size() - 1; i >= 0; i--) { for (int i = (int)dims.size() - 1; i >= 0; i--) {
...@@ -61,7 +61,9 @@ class Tensor { ...@@ -61,7 +61,9 @@ class Tensor {
} }
/*! \brief Reshape the dimensions like the given tensor */ /*! \brief Reshape the dimensions like the given tensor */
Tensor* ReshapeLike(const Tensor& other) { return Reshape(other.dims_); } Tensor* ReshapeLike(const Tensor& other) {
return Reshape(other.dims_);
}
/*! \brief Return the tensor name */ /*! \brief Return the tensor name */
const string& name() const { return name_; } const string& name() const { return name_; }
...@@ -83,7 +85,7 @@ class Tensor { ...@@ -83,7 +85,7 @@ class Tensor {
int64_t dim(int64_t i) const{ return dims_[axis(i)]; } int64_t dim(int64_t i) const{ return dims_[axis(i)]; }
/*! \brief Return all the dimensions */ /*! \brief Return all the dimensions */
const vector<int64_t>& dims() const { return dims_; } const vec64_t& dims() const { return dims_; }
/*! \brief Return the total number of elements of this tensor */ /*! \brief Return the total number of elements of this tensor */
size_t size() const { return size_; } size_t size() const { return size_; }
...@@ -111,7 +113,7 @@ class Tensor { ...@@ -111,7 +113,7 @@ class Tensor {
int64_t stride(int64_t i) const { return strides_[axis(i)]; } int64_t stride(int64_t i) const { return strides_[axis(i)]; }
/*! \brief Return all the strides */ /*! \brief Return all the strides */
const vector<int64_t>& strides() const { return strides_; } const vec64_t& strides() const { return strides_; }
/*! \brief Return a string to describe the given dimensions */ /*! \brief Return a string to describe the given dimensions */
static string DimString(const vector<int64_t>& dims) { static string DimString(const vector<int64_t>& dims) {
...@@ -178,15 +180,16 @@ class Tensor { ...@@ -178,15 +180,16 @@ class Tensor {
if (TypeMeta::Id<Context>() == if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CPUContext>()) { TypeMeta::Id<CPUContext>()) {
*data_ptr = mem->mutable_cpu_data(nbytes()); *data_ptr = mem->mutable_cpu_data(nbytes());
} else if (TypeMeta::Id<Context>() == } else if (
TypeMeta::Id<Context>() ==
TypeMeta::Id<CUDAContext>()) { TypeMeta::Id<CUDAContext>()) {
*data_ptr = mem->mutable_cuda_data(nbytes()); *data_ptr = mem->mutable_cuda_data(nbytes());
} else if (TypeMeta::Id<Context>() == } else if (
TypeMeta::Id<Context>() ==
TypeMeta::Id<CNMLContext>()) { TypeMeta::Id<CNMLContext>()) {
*data_ptr = mem->mutable_cnml_data(); *data_ptr = mem->mutable_cnml_data();
} else { } else {
LOG(FATAL) << "Unknown memory type.\n" LOG(FATAL) << "Unknown memory type.";
<< "Only CPU, CUDA and CNML are supported.";
} }
} }
} }
...@@ -199,15 +202,16 @@ class Tensor { ...@@ -199,15 +202,16 @@ class Tensor {
if (TypeMeta::Id<Context>() == if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CPUContext>()) { TypeMeta::Id<CPUContext>()) {
return mem->cpu_data(nbytes()); return mem->cpu_data(nbytes());
} else if (TypeMeta::Id<Context>() == } else if (
TypeMeta::Id<Context>() ==
TypeMeta::Id<CUDAContext>()) { TypeMeta::Id<CUDAContext>()) {
return mem->cuda_data(nbytes()); return mem->cuda_data(nbytes());
} else if (TypeMeta::Id<Context>() == } else if (
TypeMeta::Id<Context>() ==
TypeMeta::Id<CNMLContext>()) { TypeMeta::Id<CNMLContext>()) {
return mem->cnml_data(); return mem->cnml_data();
} else { } else {
LOG(FATAL) << "Unknown memory type.\n" LOG(FATAL) << "Unknown memory type.";
<< "Only CPU, CUDA, and CNML are supported.";
return nullptr; return nullptr;
} }
} }
...@@ -268,8 +272,9 @@ class Tensor { ...@@ -268,8 +272,9 @@ class Tensor {
return static_cast<T*>(data_ptr); return static_cast<T*>(data_ptr);
} }
} }
return static_cast<T*>(raw_mutable_data return static_cast<T*>(
<Context>(TypeMeta::Make<T>())); raw_mutable_data<Context>
(TypeMeta::Make<T>()));
} }
/*! \brief Get the typed const data pointer */ /*! \brief Get the typed const data pointer */
...@@ -284,13 +289,15 @@ class Tensor { ...@@ -284,13 +289,15 @@ class Tensor {
/*! \brief Copy the contents from the given tensor */ /*! \brief Copy the contents from the given tensor */
template <class Context> template <class Context>
void CopyFrom(const Tensor& other, Context* ctx) { Tensor* CopyFrom(const Tensor& other, Context* ctx) {
if ((void*)&other == (void*)this) return; if ((void*)&other == (void*)this) return this;
CHECK_EQ(size_, other.size_); CHECK_EQ(size_, other.size_);
auto* src = other.template raw_data<Context>(); auto* src = other.template raw_data<Context>();
auto* dst = raw_mutable_data<Context>(other.meta_); auto* dst = raw_mutable_data<Context>(other.meta_);
ctx->template MemcpyAsync<Context, Context>( ctx->template MemcpyAsync
nbytes(), dst, src); <Context, Context>(
nbytes(), dst, src);
return this;
} }
/*! \brief Move the external memory */ /*! \brief Move the external memory */
...@@ -337,7 +344,7 @@ class Tensor { ...@@ -337,7 +344,7 @@ class Tensor {
int version_ = -1; int version_ = -1;
/*! \brief Store the dimensions and strides */ /*! \brief Store the dimensions and strides */
vector<int64_t> dims_, strides_; vec64_t dims_, strides_;
/*! \brief The internal memory */ /*! \brief The internal memory */
shared_ptr<MixedMemory> memory_; shared_ptr<MixedMemory> memory_;
......
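CopyFrom() now returns Tensor* instead of void, so it chains with the (also pointer-returning) reshape helpers. A small sketch; the tensor names and the stack CPUContext are illustrative:

    CPUContext ctx;
    Tensor src("src"), dst("dst");
    src.Reshape(vec64_t({ 2, 3 }));
    src.mutable_data<float, CPUContext>();  // allocate and type the source storage
    // Chain: reshape like the source, then copy its contents.
    dst.ReshapeLike(src)->CopyFrom<CPUContext>(src, &ctx);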
...@@ -14,12 +14,16 @@ ...@@ -14,12 +14,16 @@
#define DRAGON_CORE_TYPES_H_ #define DRAGON_CORE_TYPES_H_
#include <cstdint> #include <cstdint>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include "core/typeid.h" #include "core/typeid.h"
namespace dragon { namespace dragon {
typedef std::vector<int> vec32_t;
typedef std::vector<int64_t> vec64_t;
#ifdef _MSC_VER #ifdef _MSC_VER
typedef struct __declspec(align(2)) { typedef struct __declspec(align(2)) {
......
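The new aliases shorten the ubiquitous vector<int> / vector<int64_t> spellings used by Tensor and the graph code. A trivial usage sketch:

    vec32_t shape32 = { 2, 3, 4 };                    // std::vector<int>
    vec64_t shape64(shape32.begin(), shape32.end());  // std::vector<int64_t>
    Tensor t(shape64);                                // matches Tensor(const vec64_t&)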
...@@ -80,13 +80,13 @@ class Workspace { ...@@ -80,13 +80,13 @@ class Workspace {
/*! \brief Return the specified filler */ /*! \brief Return the specified filler */
const TensorFillerProto* GetFiller(const string& name) const; const TensorFillerProto* GetFiller(const string& name) const;
/*! \brief Create temporal cache segments */ /*! \brief Create temporal data segments */
template <class Context> template <class Context>
vector<void*> caches(const vector<size_t>& segments) { vector<void*> data(const vector<size_t>& segments) {
int64_t nbytes = 0; int64_t nbytes = 0;
vector<void*> ret(segments.size()); vector<void*> ret(segments.size());
for (auto& segment : segments) nbytes += (int64_t)segment; for (auto& segment : segments) nbytes += (int64_t)segment;
auto* T = CreateTensor("/share/cache")->Reshape({ nbytes }); auto* T = CreateTensor("/share/data")->Reshape({ nbytes });
ret[0] = T->template mutable_data<uint8_t, Context>(); ret[0] = T->template mutable_data<uint8_t, Context>();
for (int i = 1; i < segments.size(); i++) for (int i = 1; i < segments.size(); i++)
ret[i] = (uint8_t*)ret[i - 1] + segments[i - 1]; ret[i] = (uint8_t*)ret[i - 1] + segments[i - 1];
...@@ -95,12 +95,12 @@ class Workspace { ...@@ -95,12 +95,12 @@ class Workspace {
/*! \brief Create temporal cache segments with the specified type */ /*! \brief Create temporal cache segments with the specified type */
template <typename T, class Context> template <typename T, class Context>
vector<T*> caches(const vector<int64_t>& segments) { vector<T*> data(const vector<int64_t>& segments) {
vector<size_t> segments_in_byte; vector<size_t> segments_in_byte;
vector<T*> ret(segments.size()); vector<T*> ret(segments.size());
for (const auto& e : segments) for (const auto& e : segments)
segments_in_byte.emplace_back(e * sizeof(T)); segments_in_byte.emplace_back(e * sizeof(T));
auto ret_in_byte = caches<Context>(segments_in_byte); auto ret_in_byte = data<Context>(segments_in_byte);
for (int i = 0; i < segments.size(); i++) for (int i = 0; i < segments.size(); i++)
ret[i] = (T*)ret_in_byte[i]; ret[i] = (T*)ret_in_byte[i];
return ret; return ret;
......
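The caches<...>() helpers are renamed to data<...>() and now back onto the "/share/data" tensor; the calling pattern is unchanged. A sketch of carving two scratch segments inside an operator (ws here is whatever OperatorBase::ws() returns):

    // Two float scratch buffers of 128 and 256 elements, both carved out of
    // the single "/share/data" tensor owned by the workspace.
    auto scratch = ws->data<float, CPUContext>({ 128, 256 });
    float* buf0 = scratch[0];
    float* buf1 = scratch[1];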
...@@ -23,18 +23,18 @@ class DropoutOp final : public Operator<Context> { ...@@ -23,18 +23,18 @@ class DropoutOp final : public Operator<Context> {
public: public:
DropoutOp(const OperatorDef& def, Workspace* ws) DropoutOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
use_scale(OperatorBase::Arg<bool>("scale", true)) { use_scale_(OpArg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5f); SwitchToPhase(OpArg<string>("phase", ""));
SwitchToPhase(OperatorBase::Arg<string>("phase", "")); GET_ARG_WITH_DESC(float, prob, 0.5f);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
bool use_scale; bool use_scale_;
DECLARE_ARGUMENT_WITH_DESC(float, prob); DECLARE_ARG_WITH_DESC(float, prob);
}; };
template <class Context> template <class Context>
...@@ -42,22 +42,22 @@ class DropoutGradientOp final : public Operator<Context> { ...@@ -42,22 +42,22 @@ class DropoutGradientOp final : public Operator<Context> {
public: public:
DropoutGradientOp(const OperatorDef& def, Workspace* ws) DropoutGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
use_scale(OperatorBase::Arg<bool>("scale", true)) { use_scale_(OpArg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5f); SwitchToPhase(OpArg<string>("phase", ""));
SwitchToPhase(OperatorBase::Arg<string>("phase", "")); GET_ARG_WITH_DESC(float, prob, 0.5f);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
bool use_scale; bool use_scale_;
DECLARE_ARGUMENT_WITH_DESC(float, prob); DECLARE_ARG_WITH_DESC(float, prob);
}; };
DEFINE_ARGUMENT_WITH_DESC(float, DropoutOp, prob); DEFINE_ARG_WITH_DESC(float, DropoutOp, prob);
DEFINE_ARGUMENT_WITH_DESC(float, DropoutGradientOp, prob); DEFINE_ARG_WITH_DESC(float, DropoutGradientOp, prob);
#ifdef WITH_CUDNN #ifdef WITH_CUDNN
...@@ -65,68 +65,70 @@ DEFINE_ARGUMENT_WITH_DESC(float, DropoutGradientOp, prob); ...@@ -65,68 +65,70 @@ DEFINE_ARGUMENT_WITH_DESC(float, DropoutGradientOp, prob);
template <class Context> template <class Context>
class CuDNNDropoutOp final : public Operator<Context> { class CuDNNDropoutOp final : public Operator<Context> {
public: public:
CuDNNDropoutOp(const OperatorDef& def, Workspace* ws) CuDNNDropoutOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), states_initialized(false), : Operator<Context>(def, ws),
use_scale(OperatorBase::Arg<bool>("scale", true)), states_initialized_(false),
random_seed(DEFAULT_RNG_SEED) { rng_seed_(DEFAULT_RNG_SEED),
GET_ARGUMENT_WITH_DESC(float, prob, 0.5f); use_scale_(OpArg<bool>("scale", true)) {
SwitchToPhase(OperatorBase::Arg<string>("phase", "")); SwitchToPhase(OpArg<string>("phase", ""));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); GET_ARG_WITH_DESC(float, prob, 0.5f);
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc_));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNDropoutOp() { ~CuDNNDropoutOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc)); CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
bool use_scale, states_initialized; bool use_scale_, states_initialized_;
cudnnTensorDescriptor_t input_desc; cudnnTensorDescriptor_t input_desc_;
cudnnDropoutDescriptor_t dropout_desc; cudnnDropoutDescriptor_t dropout_desc_;
size_t states_size, reserve_space_size; size_t states_size_, reserve_size_;
unsigned long long random_seed; unsigned long long rng_seed_;
DECLARE_ARGUMENT_WITH_DESC(float, prob); DECLARE_ARG_WITH_DESC(float, prob);
}; };
template <class Context> template <class Context>
class CuDNNDropoutGradientOp final : public Operator<Context> { class CuDNNDropoutGradientOp final : public Operator<Context> {
public: public:
CuDNNDropoutGradientOp(const OperatorDef& def, Workspace* ws) CuDNNDropoutGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), states_initialized(false), : Operator<Context>(def, ws),
use_scale(OperatorBase::Arg<bool>("scale", true)), states_initialized_(false),
random_seed(DEFAULT_RNG_SEED) { rng_seed_(DEFAULT_RNG_SEED),
GET_ARGUMENT_WITH_DESC(float, prob, 0.5f); use_scale_(OpArg<bool>("scale", true)) {
SwitchToPhase(OperatorBase::Arg<string>("phase", "")); SwitchToPhase(OpArg<string>("phase", ""));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); GET_ARG_WITH_DESC(float, prob, 0.5f);
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc_));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNDropoutGradientOp() { ~CuDNNDropoutGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc)); CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
bool use_scale, states_initialized; bool use_scale_, states_initialized_;
cudnnTensorDescriptor_t input_desc; cudnnTensorDescriptor_t input_desc_;
cudnnDropoutDescriptor_t dropout_desc; cudnnDropoutDescriptor_t dropout_desc_;
size_t states_size, reserve_space_size; size_t states_size_, reserve_size_;
unsigned long long random_seed; unsigned long long rng_seed_;
DECLARE_ARGUMENT_WITH_DESC(float, prob); DECLARE_ARG_WITH_DESC(float, prob);
}; };
DEFINE_ARGUMENT_WITH_DESC(float, CuDNNDropoutOp, prob); DEFINE_ARG_WITH_DESC(float, CuDNNDropoutOp, prob);
DEFINE_ARGUMENT_WITH_DESC(float, CuDNNDropoutGradientOp, prob); DEFINE_ARG_WITH_DESC(float, CuDNNDropoutGradientOp, prob);
#endif #endif
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ACTIVATION_DROPPATH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_DROPPATH_OP_H_
#include "core/operator.h"
#include "utils/math_functions.h"
namespace dragon {
template <class Context>
class DropPathOp final : public Operator<Context> {
public:
DropPathOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
inc_(OpArg<float>("increment", 0.f)) {
SwitchToPhase(OpArg<string>("phase", ""));
GET_ARG_WITH_DESC(float, prob, 0.2f);
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunImpl();
protected:
int64_t rows_, cols_;
float inc_, drop_prob_ = 0.f;
DECLARE_ARG_WITH_DESC(float, prob);
};
template <class Context>
class DropPathGradientOp final : public Operator<Context> {
public:
DropPathGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
SwitchToPhase(OpArg<string>("phase", ""));
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunImpl();
protected:
int64_t rows_, cols_;
};
DEFINE_ARG_WITH_DESC(float, DropPathOp, prob);
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_DROPPATH_OP_H_
\ No newline at end of file
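A conceptual sketch of what DropPathOp computes, based on the fields above and the index entry ("randomly set an example of the batch to zero"): each of the rows_ samples is zeroed with probability drop_prob_, which inc_ raises gradually across training steps. Scaling the kept rows by 1/(1-p) is an assumption here, mirroring the Dropout scale option; this is illustrative code, not the real kernel:

    #include <cstdlib>
    #include <vector>

    void DropPathCPU(
        const float* x, float* y,
        int64_t rows, int64_t cols, float drop_prob) {
      const float scale = 1.f / (1.f - drop_prob);
      for (int64_t i = 0; i < rows; ++i) {
        // Keep or drop the whole i-th example of the batch.
        const bool keep = (std::rand() / float(RAND_MAX)) >= drop_prob;
        for (int64_t j = 0; j < cols; ++j)
          y[i * cols + j] = keep ? x[i * cols + j] * scale : 0.f;
      }
    }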
...@@ -22,14 +22,14 @@ class EluOp : public Operator<Context> { ...@@ -22,14 +22,14 @@ class EluOp : public Operator<Context> {
public: public:
EluOp(const OperatorDef& def, Workspace* ws) EluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.f)) {} alpha_(OpArg<float>("alpha", 1.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float alpha; float alpha_;
}; };
template <class Context> template <class Context>
...@@ -37,14 +37,14 @@ class EluGradientOp : public Operator<Context> { ...@@ -37,14 +37,14 @@ class EluGradientOp : public Operator<Context> {
public: public:
EluGradientOp(const OperatorDef& def, Workspace* ws) EluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.f)) {} alpha_(OpArg<float>("alpha", 1.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float alpha; float alpha_;
}; };
#ifdef WITH_CUDNN #ifdef WITH_CUDNN
...@@ -56,26 +56,28 @@ class CuDNNEluOp final : public EluOp<Context> { ...@@ -56,26 +56,28 @@ class CuDNNEluOp final : public EluOp<Context> {
public: public:
CuDNNEluOp(const OperatorDef& def, Workspace* ws) CuDNNEluOp(const OperatorDef& def, Workspace* ws)
: EluOp<Context>(def, ws) { : EluOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnSetActivationDescriptor(
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, act_desc_,
CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha)); CUDNN_ACTIVATION_ELU,
CUDNN_PROPAGATE_NAN,
this->alpha_
));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNEluOp() { ~CuDNNEluOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc_;
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc_;
}; };
template <class Context> template <class Context>
...@@ -83,26 +85,28 @@ class CuDNNEluGradientOp final : public EluGradientOp<Context> { ...@@ -83,26 +85,28 @@ class CuDNNEluGradientOp final : public EluGradientOp<Context> {
public: public:
CuDNNEluGradientOp(const OperatorDef& def, Workspace* ws) CuDNNEluGradientOp(const OperatorDef& def, Workspace* ws)
: EluGradientOp<Context>(def, ws) { : EluGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnSetActivationDescriptor(
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, act_desc_,
CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha)); CUDNN_ACTIVATION_ELU,
CUDNN_PROPAGATE_NAN,
this->alpha_
));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNEluGradientOp() { ~CuDNNEluGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc_;
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc_;
}; };
#endif #endif
......
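CuDNNCreateTensorDesc / CuDNNDestroyTensorDesc are the new helpers that replace the raw cudnnCreate/DestroyTensorDescriptor calls throughout these headers. Their real definitions live elsewhere in the tree; a plausible minimal form, stated here as an assumption, is simply:

    inline void CuDNNCreateTensorDesc(cudnnTensorDescriptor_t* desc) {
        CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
    }

    inline void CuDNNDestroyTensorDesc(cudnnTensorDescriptor_t* desc) {
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(*desc));
    }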
...@@ -22,16 +22,15 @@ class PReluOp final : public Operator<Context> { ...@@ -22,16 +22,15 @@ class PReluOp final : public Operator<Context> {
public: public:
PReluOp(const OperatorDef& def, Workspace* ws) PReluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
channel_shared(OperatorBase::Arg<bool>("channel_shared", false)), channel_shared_(OpArg<bool>(
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {} "channel_shared", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t channel_shared, channels, dim; int64_t channel_shared_, channels_, dim_;
string data_format;
}; };
template <class Context> template <class Context>
...@@ -39,16 +38,15 @@ class PReluGradientOp final : public Operator<Context> { ...@@ -39,16 +38,15 @@ class PReluGradientOp final : public Operator<Context> {
public: public:
PReluGradientOp(const OperatorDef& def, Workspace* ws) PReluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
channel_shared(OperatorBase::Arg<bool>("channel_shared", false)), channel_shared_(OpArg<bool>(
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {} "channel_shared", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t channel_shared, channels, dim; int64_t channel_shared_, channels_, dim_;
string data_format;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,14 +22,14 @@ class ReluOp : public Operator<Context> { ...@@ -22,14 +22,14 @@ class ReluOp : public Operator<Context> {
public: public:
ReluOp(const OperatorDef& def, Workspace* ws) ReluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
slope(OperatorBase::Arg<float>("slope", 0.f)) {} slope_(OpArg<float>("slope", 0.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float slope; float slope_;
}; };
template <class Context> template <class Context>
...@@ -37,14 +37,14 @@ class ReluGradientOp : public Operator<Context> { ...@@ -37,14 +37,14 @@ class ReluGradientOp : public Operator<Context> {
public: public:
ReluGradientOp(const OperatorDef& def, Workspace* ws) ReluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
slope(OperatorBase::Arg<float>("slope", 0.f)) {} slope_(OpArg<float>("slope", 0.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float slope; float slope_;
}; };
#ifdef WITH_CUDNN #ifdef WITH_CUDNN
...@@ -54,26 +54,27 @@ class CuDNNReluOp final : public ReluOp<Context> { ...@@ -54,26 +54,27 @@ class CuDNNReluOp final : public ReluOp<Context> {
public: public:
CuDNNReluOp(const OperatorDef& def, Workspace* ws) CuDNNReluOp(const OperatorDef& def, Workspace* ws)
: ReluOp<Context>(def, ws) { : ReluOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnSetActivationDescriptor(
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, act_desc_,
CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_RELU,
CUDNN_PROPAGATE_NAN, 0
));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNReluOp() { ~CuDNNReluOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc_;
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc_;
}; };
template <class Context> template <class Context>
...@@ -81,26 +82,27 @@ class CuDNNReluGradientOp final : public ReluGradientOp<Context> { ...@@ -81,26 +82,27 @@ class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
public: public:
CuDNNReluGradientOp(const OperatorDef& def, Workspace* ws) CuDNNReluGradientOp(const OperatorDef& def, Workspace* ws)
: ReluGradientOp<Context>(def, ws) { : ReluGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnSetActivationDescriptor(
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, act_desc_,
CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_RELU,
CUDNN_PROPAGATE_NAN, 0
));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNReluGradientOp() { ~CuDNNReluGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc_;
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc_;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
......
...@@ -20,21 +20,21 @@ namespace dragon { ...@@ -20,21 +20,21 @@ namespace dragon {
template <class Context> template <class Context>
class SEluOp final : public Operator<Context> { class SEluOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SEluOp); SIMPLE_CTOR_DTOR(SEluOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
template <class Context> template <class Context>
class SEluGradientOp final : public Operator<Context> { class SEluGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SEluGradientOp); SIMPLE_CTOR_DTOR(SEluGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
} // namespace dragon } // namespace dragon
......
...@@ -20,21 +20,21 @@ namespace dragon { ...@@ -20,21 +20,21 @@ namespace dragon {
template <class Context> template <class Context>
class SigmoidOp : public Operator<Context> { class SigmoidOp : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SigmoidOp); SIMPLE_CTOR_DTOR(SigmoidOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
template <class Context> template <class Context>
class SigmoidGradientOp : public Operator<Context> { class SigmoidGradientOp : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SigmoidGradientOp); SIMPLE_CTOR_DTOR(SigmoidGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
#ifdef WITH_CUDNN #ifdef WITH_CUDNN
...@@ -44,26 +44,27 @@ class CuDNNSigmoidOp final : public SigmoidOp<Context> { ...@@ -44,26 +44,27 @@ class CuDNNSigmoidOp final : public SigmoidOp<Context> {
public: public:
CuDNNSigmoidOp(const OperatorDef& def, Workspace* ws) CuDNNSigmoidOp(const OperatorDef& def, Workspace* ws)
: SigmoidOp<Context>(def, ws) { : SigmoidOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnSetActivationDescriptor(
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, act_desc_,
CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_SIGMOID,
CUDNN_PROPAGATE_NAN, 0
));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNSigmoidOp() { ~CuDNNSigmoidOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc_;
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc_;
}; };
template <class Context> template <class Context>
...@@ -71,26 +72,27 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> { ...@@ -71,26 +72,27 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
public: public:
CuDNNSigmoidGradientOp(const OperatorDef& def, Workspace* ws) CuDNNSigmoidGradientOp(const OperatorDef& def, Workspace* ws)
: SigmoidGradientOp<Context>(def, ws) { : SigmoidGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc_));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnSetActivationDescriptor(
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, act_desc_,
CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_SIGMOID,
CUDNN_PROPAGATE_NAN, 0
));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNSigmoidGradientOp() { ~CuDNNSigmoidGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc_;
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc_;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
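As background for the descriptor refactor above, this is roughly how a cuDNN activation descriptor is consumed at run time. The snippet is a standalone sketch with status checks and the framework's CuDNNCreateTensorDesc helpers omitted; it is not the operator's actual RunImpl:

    // Sketch: apply a cuDNN sigmoid activation to a float NCHW tensor.
    // Assumes a valid cudnnHandle_t and device buffers x/y of n*c*h*w floats;
    // return statuses should be checked (the framework wraps them in CUDNN_CHECK).
    #include <cudnn.h>

    void SigmoidForward(cudnnHandle_t handle,
                        const float* x, float* y,
                        int n, int c, int h, int w) {
        cudnnTensorDescriptor_t desc;
        cudnnActivationDescriptor_t act;
        cudnnCreateTensorDescriptor(&desc);
        cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW,
                                   CUDNN_DATA_FLOAT, n, c, h, w);
        cudnnCreateActivationDescriptor(&act);
        cudnnSetActivationDescriptor(act, CUDNN_ACTIVATION_SIGMOID,
                                     CUDNN_PROPAGATE_NAN, 0.0);
        const float alpha = 1.f, beta = 0.f;
        cudnnActivationForward(handle, act, &alpha, desc, x, &beta, desc, y);
        cudnnDestroyActivationDescriptor(act);
        cudnnDestroyTensorDescriptor(desc);
    }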
......
...@@ -22,14 +22,15 @@ class SoftmaxOp final : public Operator<Context> { ...@@ -22,14 +22,15 @@ class SoftmaxOp final : public Operator<Context> {
public: public:
SoftmaxOp(const OperatorDef& def, Workspace* ws) SoftmaxOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)) {} axis_(OpArg<int64_t>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; int64_t axis_, axis_dim_;
int64_t outer_dim_, inner_dim_;
}; };
template <class Context> template <class Context>
...@@ -37,14 +38,15 @@ class SoftmaxGradientOp final : public Operator<Context> { ...@@ -37,14 +38,15 @@ class SoftmaxGradientOp final : public Operator<Context> {
public: public:
SoftmaxGradientOp(const OperatorDef& def, Workspace* ws) SoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)) {} axis_(OpArg<int64_t>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; int64_t axis_, axis_dim_;
int64_t outer_dim_, inner_dim_;
}; };
#ifdef WITH_CUDNN #ifdef WITH_CUDNN
...@@ -54,23 +56,21 @@ class CuDNNSoftmaxOp final : public Operator<Context> { ...@@ -54,23 +56,21 @@ class CuDNNSoftmaxOp final : public Operator<Context> {
public: public:
CuDNNSoftmaxOp(const OperatorDef& def, Workspace* ws) CuDNNSoftmaxOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)) { axis_(OpArg<int64_t>("axis", 1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNSoftmaxOp() { ~CuDNNSoftmaxOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; int64_t axis_, outer_dim_, inner_dim_;
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc_;
}; };
template <class Context> template <class Context>
...@@ -78,23 +78,21 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> { ...@@ -78,23 +78,21 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> {
public: public:
CuDNNSoftmaxGradientOp(const OperatorDef& def, Workspace* ws) CuDNNSoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)) { axis_(OpArg<int64_t>("axis", 1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNSoftmaxGradientOp() { ~CuDNNSoftmaxGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; int64_t axis_, outer_dim_, inner_dim_;
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc_;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
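The outer_dim_/axis_dim_/inner_dim_ fields view the input as [outer, axis, inner] around the softmax axis. A plain CPU reference of that layout, assuming this decomposition (the operator itself dispatches a device kernel or cuDNN):

    // Reference softmax along the middle dimension of an
    // [outer_dim, axis_dim, inner_dim] view.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    void SoftmaxRef(const std::vector<float>& x, std::vector<float>& y,
                    int64_t outer_dim, int64_t axis_dim, int64_t inner_dim) {
        y.resize(x.size());
        for (int64_t o = 0; o < outer_dim; ++o) {
            for (int64_t i = 0; i < inner_dim; ++i) {
                const int64_t base = o * axis_dim * inner_dim + i;
                float vmax = x[base];
                for (int64_t a = 1; a < axis_dim; ++a)
                    vmax = std::max(vmax, x[base + a * inner_dim]);
                float sum = 0.f;
                for (int64_t a = 0; a < axis_dim; ++a) {
                    y[base + a * inner_dim] = std::exp(x[base + a * inner_dim] - vmax);
                    sum += y[base + a * inner_dim];
                }
                for (int64_t a = 0; a < axis_dim; ++a)
                    y[base + a * inner_dim] /= sum;
            }
        }
    }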
......
...@@ -20,21 +20,21 @@ namespace dragon { ...@@ -20,21 +20,21 @@ namespace dragon {
template <class Context> template <class Context>
class TanhOp : public Operator<Context> { class TanhOp : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(TanhOp); SIMPLE_CTOR_DTOR(TanhOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
template <class Context> template <class Context>
class TanhGradientOp : public Operator<Context> { class TanhGradientOp : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(TanhGradientOp); SIMPLE_CTOR_DTOR(TanhGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
#ifdef WITH_CUDNN #ifdef WITH_CUDNN
...@@ -44,53 +44,56 @@ class CuDNNTanhOp final : public TanhOp<Context> { ...@@ -44,53 +44,56 @@ class CuDNNTanhOp final : public TanhOp<Context> {
public: public:
CuDNNTanhOp(const OperatorDef& def, Workspace* ws) CuDNNTanhOp(const OperatorDef& def, Workspace* ws)
: TanhOp<Context>(def, ws) { : TanhOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnSetActivationDescriptor(
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, act_desc_,
CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_TANH,
CUDNN_PROPAGATE_NAN, 0
));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNTanhOp() { ~CuDNNTanhOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc_;
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc_;
}; };
template <class Context> template <class Context>
class CuDNNTanhGradientOp final : public TanhGradientOp<Context> { class CuDNNTanhGradientOp
final : public TanhGradientOp<Context> {
public: public:
CuDNNTanhGradientOp(const OperatorDef& def, Workspace* ws) CuDNNTanhGradientOp(const OperatorDef& def, Workspace* ws)
: TanhGradientOp<Context>(def, ws) { : TanhGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnSetActivationDescriptor(
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, act_desc_,
CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_TANH,
CUDNN_PROPAGATE_NAN, 0
));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNTanhGradientOp() { ~CuDNNTanhGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc));
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc_;
cudnnActivationDescriptor_t act_desc; cudnnActivationDescriptor_t act_desc_;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
......
...@@ -22,15 +22,17 @@ class AccumulateOp final : public Operator<Context> { ...@@ -22,15 +22,17 @@ class AccumulateOp final : public Operator<Context> {
public: public:
AccumulateOp(const OperatorDef& def, Workspace* ws) AccumulateOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.f)), alpha_(OpArg<float>("alpha", 1.f)),
beta(OperatorBase::Arg<float>("beta", 1.f)) {} beta_(OpArg<float>("beta", 1.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(Tensor* X, Tensor* Y);
template <typename T>
void RunImpl(Tensor* X, Tensor* Y);
protected: protected:
float alpha, beta; float alpha_, beta_;
}; };
} // namespace dragon } // namespace dragon
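A minimal sketch of the update the alpha_/beta_ arguments suggest, i.e. an axpby-style Y = alpha * X + beta * Y per tensor pair; this is an assumption based on the argument names, not taken from the kernel:

    // Reference accumulate: y = alpha * x + beta * y, element-wise.
    #include <vector>

    void AccumulateRef(const std::vector<float>& x, std::vector<float>& y,
                       float alpha, float beta) {
        for (size_t i = 0; i < x.size(); ++i)
            y[i] = alpha * x[i] + beta * y[i];
    }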
......
...@@ -22,16 +22,16 @@ class AffineOp final : public Operator<Context> { ...@@ -22,16 +22,16 @@ class AffineOp final : public Operator<Context> {
public: public:
AffineOp(const OperatorDef& def, Workspace* ws) AffineOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
num_axes(OperatorBase::Arg<int64_t>("num_axes", 1)) {} num_axes_(OpArg<int64_t>("num_axes", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, num_axes; int64_t outer_dim_, inner_dim_;
int64_t outer_dim, scale_dim, inner_dim; int64_t axis_, num_axes_, scale_dim_;
}; };
template <class Context> template <class Context>
...@@ -39,19 +39,18 @@ class AffineGradientOp final : public Operator<Context> { ...@@ -39,19 +39,18 @@ class AffineGradientOp final : public Operator<Context> {
public: public:
AffineGradientOp(const OperatorDef& def, Workspace* ws) AffineGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
num_axes(OperatorBase::Arg<int64_t>("num_axes", 1)) {} num_axes_(OpArg<int64_t>("num_axes", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void BiasRunWithType(); template <typename T> void Reduce(T* x, T* y);
template <typename T> void ScaleRunWithType(); template <typename T> void RunImpl();
template <typename T> void ComputeScaleGradient(T* dYxX, T* dA);
template <typename T> void RunWithType();
protected: protected:
int64_t axis, num_axes; int64_t axis_, num_axes_;
int64_t outer_dim, inner_dim, scale_dim, sum_dim, dim; int64_t outer_dim_, inner_dim_;
int64_t scale_dim_, reduce_dim_, dim_;
}; };
#ifdef WITH_CUDNN #ifdef WITH_CUDNN
...@@ -63,41 +62,42 @@ class CuDNNAffineOpBase : public Operator<Context> { ...@@ -63,41 +62,42 @@ class CuDNNAffineOpBase : public Operator<Context> {
public: public:
CuDNNAffineOpBase(const OperatorDef& def, Workspace* ws) CuDNNAffineOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
num_axes(OperatorBase::Arg<int64_t>("num_axes", 1)) { num_axes_(OpArg<int64_t>("num_axes", 1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CuDNNCreateTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc)); CuDNNCreateTensorDesc(&param_desc_);
CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&mul_desc)); CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&mul_op_));
CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&add_desc)); CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&add_op_));
CUDNN_CHECK(cudnnCreateReduceTensorDescriptor(&reduce_desc)); CUDNN_CHECK(cudnnCreateReduceTensorDescriptor(&reduce_desc_));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
virtual ~CuDNNAffineOpBase() { virtual ~CuDNNAffineOpBase() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CuDNNDestroyTensorDesc(&input_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(param_desc)); CuDNNDestroyTensorDesc(&param_desc_);
CUDNN_CHECK(cudnnDestroyOpTensorDescriptor(mul_desc)); CUDNN_CHECK(cudnnDestroyOpTensorDescriptor(mul_op_));
CUDNN_CHECK(cudnnDestroyReduceTensorDescriptor(reduce_desc)); CUDNN_CHECK(cudnnDestroyOpTensorDescriptor(add_op_));
CUDNN_CHECK(cudnnDestroyReduceTensorDescriptor(reduce_desc_));
} }
template <typename T> template <typename T>
void ResetDesc(const Tensor& X); void ResetDesc(const Tensor& X);
int64_t axis, num_axes; int64_t axis_, num_axes_;
cudnnTensorDescriptor_t input_desc, param_desc; cudnnTensorDescriptor_t input_desc_, param_desc_;
cudnnOpTensorDescriptor_t mul_desc, add_desc; cudnnOpTensorDescriptor_t mul_op_, add_op_;
cudnnReduceTensorDescriptor_t reduce_desc; cudnnReduceTensorDescriptor_t reduce_desc_;
}; };
#define USE_CUDNN_AFFINE_FUCNTIONS \ #define USE_CUDNN_AFFINE_FUCNTIONS \
USE_OPERATOR_FUNCTIONS; \ USE_OPERATOR_FUNCTIONS; \
using CuDNNAffineOpBase<Context>::axis; \ using CuDNNAffineOpBase<Context>::axis_; \
using CuDNNAffineOpBase<Context>::num_axes; \ using CuDNNAffineOpBase<Context>::num_axes_; \
using CuDNNAffineOpBase<Context>::input_desc; \ using CuDNNAffineOpBase<Context>::input_desc_; \
using CuDNNAffineOpBase<Context>::param_desc; \ using CuDNNAffineOpBase<Context>::param_desc_; \
using CuDNNAffineOpBase<Context>::mul_desc; \ using CuDNNAffineOpBase<Context>::mul_op_; \
using CuDNNAffineOpBase<Context>::add_desc; \ using CuDNNAffineOpBase<Context>::add_op_; \
using CuDNNAffineOpBase<Context>::reduce_desc using CuDNNAffineOpBase<Context>::reduce_desc_
template <class Context> template <class Context>
class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> { class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> {
...@@ -106,7 +106,7 @@ class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> { ...@@ -106,7 +106,7 @@ class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> {
: CuDNNAffineOpBase<Context>(def, ws) {} : CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename DT, typename CT> void RunWithType(); template <typename DT, typename CT> void RunImpl();
protected: protected:
USE_CUDNN_AFFINE_FUCNTIONS; USE_CUDNN_AFFINE_FUCNTIONS;
...@@ -124,13 +124,14 @@ public: ...@@ -124,13 +124,14 @@ public:
void RunOnDevice() override; void RunOnDevice() override;
template <typename DT, typename CT> template <typename DT, typename CT>
void ComputeScaleGradient(DT* dYxX, DT* dA); void CuDNNReduce(DT* x, DT* y);
template <typename T> void ComputeScaleGradient_v2(T* dYxX, T* dA); template <typename T> void Reduce(T* x, T* y);
template <typename DT, typename CT> void RunWithType(); template <typename DT, typename CT> void RunImpl();
protected: protected:
USE_CUDNN_AFFINE_FUCNTIONS; USE_CUDNN_AFFINE_FUCNTIONS;
int64_t outer_dim, inner_dim, scale_dim, dim, sum_dim; int64_t outer_dim_, inner_dim_;
int64_t scale_dim_, dim_, reduce_dim_;
}; };
#endif #endif
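A reference sketch of the broadcasting implied by axis_/num_axes_ and the outer/scale/inner split: the scale (and optional bias) covers only the "scale_dim" axes. Loop form is illustrative only:

    // Reference affine: y = x * alpha (+ beta), with alpha/beta of length scale_dim
    // broadcast over an [outer_dim, scale_dim, inner_dim] view of x.
    #include <cstdint>
    #include <vector>

    void AffineRef(const std::vector<float>& x,
                   const std::vector<float>& alpha,  // size: scale_dim
                   const std::vector<float>& beta,   // size: scale_dim, or empty
                   std::vector<float>& y,
                   int64_t outer_dim, int64_t scale_dim, int64_t inner_dim) {
        y.resize(x.size());
        for (int64_t o = 0; o < outer_dim; ++o)
            for (int64_t s = 0; s < scale_dim; ++s)
                for (int64_t i = 0; i < inner_dim; ++i) {
                    const int64_t idx = (o * scale_dim + s) * inner_dim + i;
                    y[idx] = x[idx] * alpha[s] + (beta.empty() ? 0.f : beta[s]);
                }
    }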
......
...@@ -22,15 +22,15 @@ class ClipOp final : public Operator<Context> { ...@@ -22,15 +22,15 @@ class ClipOp final : public Operator<Context> {
public: public:
ClipOp(const OperatorDef& def, Workspace* ws) ClipOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
low(OperatorBase::Arg<float>("low", -FLT_MAX)), low_(OpArg<float>("low", -FLT_MAX)),
high(OperatorBase::Arg<float>("high", FLT_MAX)) {} high_(OpArg<float>("high", FLT_MAX)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float low, high, lowT, highT; float low_, high_, lowT_, highT_;
}; };
template <class Context> template <class Context>
...@@ -38,15 +38,15 @@ class ClipGradientOp final : public Operator<Context> { ...@@ -38,15 +38,15 @@ class ClipGradientOp final : public Operator<Context> {
public: public:
ClipGradientOp(const OperatorDef& def, Workspace* ws) ClipGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
low(OperatorBase::Arg<float>("low", -FLT_MAX)), low_(OpArg<float>("low", -FLT_MAX)),
high(OperatorBase::Arg<float>("high", FLT_MAX)) {} high_(OpArg<float>("high", FLT_MAX)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float low, high, lowT, highT; float low_, high_, lowT_, highT_;
}; };
} // namespace dragon } // namespace dragon
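The forward pass these arguments imply is a plain element-wise clamp; lowT_/highT_ presumably hold the bounds cast to the input dtype. Sketch:

    // Reference clip: y = min(max(x, low), high).
    #include <algorithm>
    #include <vector>

    void ClipRef(const std::vector<float>& x, std::vector<float>& y,
                 float low, float high) {
        y.resize(x.size());
        for (size_t i = 0; i < x.size(); ++i)
            y[i] = std::min(std::max(x[i], low), high);
    }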
......
...@@ -22,18 +22,19 @@ class DotOp final : public Operator<Context> { ...@@ -22,18 +22,19 @@ class DotOp final : public Operator<Context> {
public: public:
DotOp(const OperatorDef& def, Workspace* ws) DotOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
transA(OperatorBase::Arg<bool>("transA", false)), transA_(OpArg<bool>("transA", false)),
transB(OperatorBase::Arg<bool>("transB", false)) {} transB_(OpArg<bool>("transB", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void DotRunWithType(); template <typename T> void DotRunImpl();
template <typename T> void GemmRunWithType(); template <typename T> void GemmRunImpl();
template <typename T> void GemvRunWithType(); template <typename T> void GemvRunImpl();
protected: protected:
int64_t M1, N1, M2, N2; int64_t transA_, transB_;
int64_t transA, transB, M, K1, K2, N; int64_t M_, K1_, K2_, N_;
int64_t M1_, N1_, M2_, N2_;
}; };
template <class Context> template <class Context>
...@@ -41,18 +42,19 @@ class DotGradientOp final : public Operator<Context> { ...@@ -41,18 +42,19 @@ class DotGradientOp final : public Operator<Context> {
public: public:
DotGradientOp(const OperatorDef& def, Workspace* ws) DotGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
transA(OperatorBase::Arg<bool>("transA", false)), transA_(OpArg<bool>("transA", false)),
transB(OperatorBase::Arg<bool>("transB", false)) {} transB_(OpArg<bool>("transB", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void DotRunWithType(); template <typename T> void DotRunImpl();
template <typename T> void GemmRunWithType(); template <typename T> void GemmRunImpl();
template <typename T> void GemvRunWithType(); template <typename T> void GemvRunImpl();
protected: protected:
int64_t M1, N1, M2, N2; int64_t transA_, transB_;
int64_t transA, transB, M, K1, K2, N; int64_t M_, K1_, K2_, N_;
int64_t M1_, N1_, M2_, N2_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,28 +22,30 @@ class EltwiseOp final : public Operator<Context> { ...@@ -22,28 +22,30 @@ class EltwiseOp final : public Operator<Context> {
public: public:
EltwiseOp(const OperatorDef& def, Workspace* ws) EltwiseOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
operation(OperatorBase::Arg<string>("operation", "SUM")), coef_(OpArgs<float>("coef")),
coeffs(OperatorBase::Args<float>("coefficients")) { operation_(OpArg<string>("operation", "SUM")) {
// Check the number of coefficients // Check the number of coefficients
if (coeffs.size() > 0) { if (coef_.size() > 0) {
CHECK_EQ(coeffs.size(), InputSize()) CHECK_EQ(coef_.size(), XSize())
<< "\nOp has " << InputSize() << " inputs, " << "\nOp has " << XSize() << " inputs, "
<< "but provided " << coeffs.size() << " coeffs."; << "while providing " << coef_.size() << " coefs.";
} else coeffs.resize(InputSize(), 1.f); } else {
coef_.resize((size_t)XSize(), 1.f);
}
// Compute the alpha for product operation // Compute the alpha for product operation
for (auto e : coeffs) { if (e != 1.f) alpha *= e; } for (auto e : coef_) { if (e != 1.f) alpha_ *= e; }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
template <typename T> void SumRunWithType(); template <typename T> void SumRunImpl();
template <typename T> void ProdRunWithType(); template <typename T> void ProdRunImpl();
protected: protected:
string operation; string operation_;
float alpha = 1.f; float alpha_ = 1.f;
vector<float> coeffs; vector<float> coef_;
}; };
template <class Context> template <class Context>
...@@ -51,27 +53,29 @@ class EltwiseGradientOp final : public Operator<Context> { ...@@ -51,27 +53,29 @@ class EltwiseGradientOp final : public Operator<Context> {
public: public:
EltwiseGradientOp(const OperatorDef& def, Workspace* ws) EltwiseGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
operation(OperatorBase::Arg<string>("operation", "SUM")), coef_(OpArgs<float>("coef")),
coeffs(OperatorBase::Args<float>("coefficients")) { operation_(OpArg<string>("operation", "SUM")) {
if (coeffs.size() > 0) { if (coef_.size() > 0) {
CHECK_EQ(coeffs.size(), OutputSize()) CHECK_EQ(coef_.size(), YSize())
<< "\nOp has " << OutputSize() << " inputs, " << "\nOp has " << YSize() << " inputs, "
<< "but provided " << coeffs.size() << " coeffs."; << "while providing " << coef_.size() << " coefs.";
} else coeffs.resize(InputSize(), 1.f); } else {
coef_.resize(YSize(), 1.f);
}
// Compute the alpha for product operation // Compute the alpha for product operation
for (auto e : coeffs) { if (e != 1.f) alpha *= e; } for (auto e : coef_) { if (e != 1.f) alpha_ *= e; }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
template <typename T> void SumRunWithType(); template <typename T> void SumRunImpl();
template <typename T> void ProdRunWithType(); template <typename T> void ProdRunImpl();
protected: protected:
string operation; string operation_;
float alpha = 1.f; float alpha_ = 1.f;
vector<float> coeffs; vector<float> coef_;
}; };
} // namespace dragon } // namespace dragon
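For the "SUM" mode, the coefficients weight each input; in "PROD" mode, alpha_ (the product of the non-unit coefficients) presumably scales the element-wise product once. A reference for the SUM path only:

    // Reference eltwise sum: y = sum_i coef[i] * x_i.
    // Assumes at least one input and equal sizes.
    #include <vector>

    void EltwiseSumRef(const std::vector<std::vector<float>>& xs,
                       const std::vector<float>& coef,
                       std::vector<float>& y) {
        y.assign(xs[0].size(), 0.f);
        for (size_t i = 0; i < xs.size(); ++i)
            for (size_t j = 0; j < y.size(); ++j)
                y[j] += coef[i] * xs[i][j];
    }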
......
...@@ -20,21 +20,21 @@ namespace dragon { ...@@ -20,21 +20,21 @@ namespace dragon {
template <class Context> template <class Context>
class ExpOp final : public Operator<Context> { class ExpOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ExpOp); SIMPLE_CTOR_DTOR(ExpOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
template <class Context> template <class Context>
class ExpGradientOp final : public Operator<Context> { class ExpGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ExpGradientOp); SIMPLE_CTOR_DTOR(ExpGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,34 +22,37 @@ class FullyConnectedOp final : public Operator<Context> { ...@@ -22,34 +22,37 @@ class FullyConnectedOp final : public Operator<Context> {
public: public:
FullyConnectedOp(const OperatorDef& def, Workspace *ws) FullyConnectedOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
N(OperatorBase::Arg<int64_t>("num_output", 0)), N_(OpArg<int64_t>("num_output", 0)),
transW(OperatorBase::Arg<bool>("transW", true)) {} transW_(OpArg<bool>("transW", true)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice(); void RunOnDevice();
template <typename T> void TransRunWithType(); template <typename T> void TransRunImpl();
template <typename T> void NoTransRunWithType(); template <typename T> void NoTransRunImpl();
protected: protected:
int64_t axis, transW, M, K, N; int64_t axis_, transW_, M_, K_, N_;
}; };
template <class Context> template <class Context>
class FullyConnectedGradientOp final : public Operator<Context> { class FullyConnectedGradientOp
final : public Operator<Context> {
public: public:
FullyConnectedGradientOp(const OperatorDef& def, Workspace *ws) FullyConnectedGradientOp(
const OperatorDef& def,
Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
N(OperatorBase::Arg<int64_t>("num_output", 0)), N_(OpArg<int64_t>("num_output", 0)),
transW(OperatorBase::Arg<bool>("transW", true)) {} transW_(OpArg<bool>("transW", true)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, transW, M, K, N; int64_t axis_, transW_, M_, K_, N_;
}; };
} // namespace dragon } // namespace dragon
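A sketch of how axis_ and num_output typically map to the GEMM shape: X is flattened to [M, K] around the axis and W is [N, K] when transW is true (so Y = X * W^T), or [K, N] otherwise. Hypothetical helper, shape arithmetic only:

    // Infer M, K, N for a fully-connected op from the input dims and axis.
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    void InferFCShape(const std::vector<int64_t>& x_dims,
                      int64_t axis, int64_t num_output,
                      int64_t* M, int64_t* K, int64_t* N) {
        *M = std::accumulate(x_dims.begin(), x_dims.begin() + axis,
                             int64_t(1), std::multiplies<int64_t>());
        *K = std::accumulate(x_dims.begin() + axis, x_dims.end(),
                             int64_t(1), std::multiplies<int64_t>());
        *N = num_output;
    }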
......
...@@ -18,17 +18,17 @@ ...@@ -18,17 +18,17 @@
namespace dragon { namespace dragon {
#define DECLARE_FUNDAMENTAL_OP(type) \ #define DECLARE_FUNDAMENTAL_OP(name) \
template <class Context> \ template <class Context> \
class type##Op final : public Operator<Context> { \ class name##Op final : public Operator<Context> { \
public: \ public: \
USE_SIMPLE_CTOR_DTOR(type##Op); \ SIMPLE_CTOR_DTOR(name##Op); \
USE_OPERATOR_FUNCTIONS; \ USE_OPERATOR_FUNCTIONS; \
void RunOnDevice() override; \ void RunOnDevice() override; \
template <typename T> void EltwiseRunWithType(); \ template <typename T> void EltwiseRunImpl(); \
template <typename T> void BroadcastRunWithType(int type); \ template <typename T> void BroadcastRunImpl(int type); \
protected: \ protected: \
int rows, cols; \ int rows_, cols_; \
}; };
DECLARE_FUNDAMENTAL_OP(Add); DECLARE_FUNDAMENTAL_OP(Add);
...@@ -51,50 +51,48 @@ DECLARE_FUNDAMENTAL_OP(RSubGradient); ...@@ -51,50 +51,48 @@ DECLARE_FUNDAMENTAL_OP(RSubGradient);
DECLARE_FUNDAMENTAL_OP(RMulGradient); DECLARE_FUNDAMENTAL_OP(RMulGradient);
DECLARE_FUNDAMENTAL_OP(RDivGradient); DECLARE_FUNDAMENTAL_OP(RDivGradient);
#define DECLARE_FUNDAMENTAL_OP_X1X2 \ #define DECLARE_INPUT_DESC \
ws()->CreateTensor(mount_name( \ ws()->CreateTensor(unique_name("A"))->ReshapeLike(X(0)); \
"fundamental/X1"))->ReshapeLike(Input(0)); \ ws()->CreateTensor(unique_name("B"))->ReshapeLike(X(1));
ws()->CreateTensor(mount_name( \
"fundamental/X2"))->ReshapeLike(Input(1));
#define DEFINE_FUNDAMENTAL_OP_X1X2 \ #define DEFINE_INPUT_DESC \
Tensor* X1 = ws()->GetTensor(mount_name("fundamental/X1")); \ auto* A = ws()->GetTensor(unique_name("A")); \
Tensor* X2 = ws()->GetTensor(mount_name("fundamental/X2")); auto* B = ws()->GetTensor(unique_name("B"));
#define DEFINE_FUNDAMENTAL_TYPED_CALLER(dtype) \ #define DEFINE_FUNDAMENTAL_TYPED_IMPL(dtype) \
DEFINE_FUNDAMENTAL_OP_X1X2; \ DEFINE_INPUT_DESC; \
if (X2->count() < X1->count() && \ if (B->count() < A->count() && \
utils::IsRowwiseBroadcast( \ utils::IsRowwiseBroadcast( \
X1->dims(), X2->dims(), &rows, &cols)) { \ A->dims(), B->dims(), &rows_, &cols_)) { \
BroadcastRunWithType<dtype>(0); \ BroadcastRunImpl<dtype>(0); \
} else if (X2->count() < X1->count() && \ } else if (B->count() < A->count() && \
utils::IsColwiseBroadcast( \ utils::IsColwiseBroadcast( \
X1->dims(), X2->dims(), &rows, &cols)) { \ A->dims(), B->dims(), &rows_, &cols_)) { \
BroadcastRunWithType<dtype>(1); \ BroadcastRunImpl<dtype>(1); \
} else if (X1->count() == X2->count()) { \ } else if (A->count() == B->count()) { \
EltwiseRunWithType<dtype>(); \ EltwiseRunImpl<dtype>(); \
} else { \ } else { \
LOG(FATAL) << "Could not broadcast with shapes: " \ LOG(FATAL) << "Could not broadcast with shapes: " \
<< X1->DimString() << " and " \ << A->DimString() << " and " \
<< X2->DimString(); \ << B->DimString(); \
} }
#define DEFINE_FUNDAMENTAL_TYPED_RCALLER(dtype) \ #define DEFINE_RFUNDAMENTAL_TYPED_IMPL(dtype) \
DEFINE_FUNDAMENTAL_OP_X1X2; \ DEFINE_INPUT_DESC; \
if (X2->count() > X1->count() && \ if (B->count() > A->count() && \
utils::IsRowwiseBroadcast( \ utils::IsRowwiseBroadcast( \
X1->dims(), X2->dims(), &rows, &cols)) { \ A->dims(), B->dims(), &rows_, &cols_)) { \
BroadcastRunWithType<dtype>(2); \ BroadcastRunImpl<dtype>(2); \
} else if (X2->count() > X1->count() && \ } else if (B->count() > A->count() && \
utils::IsColwiseBroadcast( \ utils::IsColwiseBroadcast( \
X1->dims(), X2->dims(), &rows, &cols)) { \ A->dims(), B->dims(), &rows_, &cols_)) { \
BroadcastRunWithType<dtype>(3); \ BroadcastRunImpl<dtype>(3); \
} else if (X1->count() == X2->count()) { \ } else if (A->count() == B->count()) { \
EltwiseRunWithType<dtype>(); \ EltwiseRunImpl<dtype>(); \
} else { \ } else { \
LOG(FATAL) << "Could not broadcast with shapes: " \ LOG(FATAL) << "Could not broadcast with shapes: " \
<< X1->DimString() << " and " \ << A->DimString() << " and " \
<< X2->DimString(); \ << B->DimString(); \
} }
#undef DECLARE_FUNDAMENTAL_OP #undef DECLARE_FUNDAMENTAL_OP
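To make the dispatch in DEFINE_FUNDAMENTAL_TYPED_IMPL concrete: with A of shape [2, 3], a B of shape [3] takes the rowwise branch and a B of shape [2, 1] the colwise branch; equal counts fall through to the element-wise path. A reference rowwise broadcast add, assuming rows/cols carry the obvious row-count/row-length meaning (the exact convention of utils::IsRowwiseBroadcast is not shown here):

    // Rowwise broadcast add: b has "cols" elements and is added to every row of a.
    #include <vector>

    void RowwiseAdd(const std::vector<float>& a, const std::vector<float>& b,
                    std::vector<float>& y, int rows, int cols) {
        y.resize(a.size());
        for (int r = 0; r < rows; ++r)
            for (int c = 0; c < cols; ++c)
                y[r * cols + c] = a[r * cols + c] + b[c];
    }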
......
...@@ -22,15 +22,16 @@ class GramMatrixOp final : public Operator<Context> { ...@@ -22,15 +22,16 @@ class GramMatrixOp final : public Operator<Context> {
public: public:
GramMatrixOp(const OperatorDef& def, Workspace* ws) GramMatrixOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)) {} axis_(OpArg<int64_t>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, dim, inner_dim; int64_t x_ofs_, y_ofs_;
int64_t x_offset, y_offset; int64_t axis_, axis_dim_;
int64_t outer_dim_, inner_dim_;
}; };
template <class Context> template <class Context>
...@@ -38,15 +39,16 @@ class GramMatrixGradientOp final : public Operator<Context> { ...@@ -38,15 +39,16 @@ class GramMatrixGradientOp final : public Operator<Context> {
public: public:
GramMatrixGradientOp(const OperatorDef& def, Workspace* ws) GramMatrixGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)) {} axis_(OpArg<int64_t>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, dim, inner_dim; int64_t x_ofs_, y_ofs_;
int64_t x_offset, y_offset; int64_t axis_, axis_dim_;
int64_t outer_dim_, inner_dim_;
}; };
} // namespace dragon } // namespace dragon
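Per outer slice, the Gram matrix is Y = X * X^T with X viewed as [axis_dim, inner_dim]; x_ofs_ and y_ofs_ would then be the per-slice strides (axis_dim * inner_dim and axis_dim * axis_dim). Naive reference loops; the operator would use a GEMM:

    // Reference Gram matrix over [outer_dim, axis_dim, inner_dim].
    #include <cstdint>
    #include <vector>

    void GramMatrixRef(const std::vector<float>& x, std::vector<float>& y,
                       int64_t outer_dim, int64_t axis_dim, int64_t inner_dim) {
        y.assign((size_t)(outer_dim * axis_dim * axis_dim), 0.f);
        for (int64_t o = 0; o < outer_dim; ++o)
            for (int64_t i = 0; i < axis_dim; ++i)
                for (int64_t j = 0; j < axis_dim; ++j)
                    for (int64_t k = 0; k < inner_dim; ++k)
                        y[(o * axis_dim + i) * axis_dim + j] +=
                            x[(o * axis_dim + i) * inner_dim + k] *
                            x[(o * axis_dim + j) * inner_dim + k];
    }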
......
...@@ -20,21 +20,21 @@ namespace dragon { ...@@ -20,21 +20,21 @@ namespace dragon {
template <class Context> template <class Context>
class LogOp final : public Operator<Context> { class LogOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(LogOp); SIMPLE_CTOR_DTOR(LogOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
template <class Context> template <class Context>
class LogGradientOp final : public Operator<Context> { class LogGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(LogGradientOp); SIMPLE_CTOR_DTOR(LogGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,17 +22,19 @@ class MatmulOp final : public Operator<Context> { ...@@ -22,17 +22,19 @@ class MatmulOp final : public Operator<Context> {
public: public:
MatmulOp(const OperatorDef& def, Workspace* ws) MatmulOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
transA(OperatorBase::Arg<bool>("transA", false)), transA_(OpArg<bool>("transA", false)),
transB(OperatorBase::Arg<bool>("transB", false)) {} transB_(OpArg<bool>("transB", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t M1, N1, M2, N2; int64_t batch_size_;
int64_t transA, transB, M, K1, K2, N; int64_t transA_, transB_;
int64_t batch_size, A_stride, B_stride, C_stride; int64_t M_, K1_, K2_, N_;
int64_t M1_, N1_, M2_, N2_;
int64_t A_stride_, B_stride_, Y_stride_;
}; };
template <class Context> template <class Context>
...@@ -40,17 +42,19 @@ class MatmulGradientOp final : public Operator<Context> { ...@@ -40,17 +42,19 @@ class MatmulGradientOp final : public Operator<Context> {
public: public:
MatmulGradientOp(const OperatorDef& def, Workspace* ws) MatmulGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
transA(OperatorBase::Arg<bool>("transA", false)), transA_(OpArg<bool>("transA", false)),
transB(OperatorBase::Arg<bool>("transB", false)) {} transB_(OpArg<bool>("transB", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t M1, N1, M2, N2; int64_t batch_size_;
int64_t transA, transB, M, K1, K2, N; int64_t transA_, transB_;
int64_t batch_size, A_stride, B_stride, C_stride; int64_t M_, K1_, K2_, N_;
int64_t M1_, N1_, M2_, N2_;
int64_t A_stride_, B_stride_, Y_stride_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -20,25 +20,25 @@ namespace dragon { ...@@ -20,25 +20,25 @@ namespace dragon {
template <class Context> template <class Context>
class MaximumOp final : public Operator<Context> { class MaximumOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(MaximumOp); SIMPLE_CTOR_DTOR(MaximumOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunImpl();
template <typename T> void BroadcastRunWithType(); template <typename T> void BroadcastRunImpl();
}; };
template <class Context> template <class Context>
class MaximumGradientOp final : public Operator<Context> { class MaximumGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(MaximumGradientOp); SIMPLE_CTOR_DTOR(MaximumGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunImpl();
template <typename T> void BroadcastRunWithType(); template <typename T> void BroadcastRunImpl();
}; };
} // namespace dragon } // namespace dragon
......
...@@ -20,25 +20,25 @@ namespace dragon { ...@@ -20,25 +20,25 @@ namespace dragon {
template <class Context> template <class Context>
class MinimumOp final : public Operator<Context> { class MinimumOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(MinimumOp); SIMPLE_CTOR_DTOR(MinimumOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunImpl();
template <typename T> void BroadcastRunWithType(); template <typename T> void BroadcastRunImpl();
}; };
template <class Context> template <class Context>
class MinimumGradientOp final : public Operator<Context> { class MinimumGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(MinimumGradientOp); SIMPLE_CTOR_DTOR(MinimumGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunImpl();
template <typename T> void BroadcastRunWithType(); template <typename T> void BroadcastRunImpl();
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,17 +22,17 @@ class MomentsOp final : public Operator<Context> { ...@@ -22,17 +22,17 @@ class MomentsOp final : public Operator<Context> {
public: public:
MomentsOp(const OperatorDef& def, Workspace* ws) MomentsOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axes(OperatorBase::Args<int64_t>("axes")), axes_(OpArgs<int64_t>("axes")),
keep_dims(OperatorBase::Arg<int64_t>("keep_dims", 0)) {} keep_dims_(OpArg<int64_t>("keep_dims", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunImpl();
protected: protected:
int64_t keep_dims; int64_t keep_dims_;
vector<int64_t> dims, axes; vec64_t dims_, axes_;
vector<int> dims32, axes32; vec32_t dims32_, axes32_;
}; };
} // namespace dragon } // namespace dragon
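The two outputs a Moments-style op produces over the reduced axes are the mean and the (biased) variance, E[x] and E[x^2] - E[x]^2. Flat-vector sketch, ignoring the axes/keep_dims bookkeeping:

    // Reference moments over a non-empty flat vector.
    #include <vector>

    void MomentsRef(const std::vector<float>& x, float* mean, float* var) {
        double s = 0.0, sq = 0.0;
        for (float v : x) { s += v; sq += double(v) * v; }
        const double m = s / x.size();
        *mean = (float)m;
        *var  = (float)(sq / x.size() - m * m);
    }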
......
...@@ -22,18 +22,19 @@ class PowOp final : public Operator<Context> { ...@@ -22,18 +22,19 @@ class PowOp final : public Operator<Context> {
public: public:
PowOp(const OperatorDef& def, Workspace* ws) PowOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.f)), scale_(OpArg<float>("scale", 1.f)),
shift(OperatorBase::Arg<float>("shift", 0.f)), shift_(OpArg<float>("shift", 0.f)),
power(OperatorBase::Arg<float>("power", 1.f)) { power_(OpArg<float>("power", 1.f)) {
power_scale = power * scale; power_scale_ = power_ * scale_;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float scale, shift, power, power_scale; float scale_, shift_;
float power_, power_scale_;
}; };
template <class Context> template <class Context>
...@@ -41,18 +42,19 @@ class PowGradientOp final : public Operator<Context> { ...@@ -41,18 +42,19 @@ class PowGradientOp final : public Operator<Context> {
public: public:
PowGradientOp(const OperatorDef& def, Workspace* ws) PowGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.f)), scale_(OpArg<float>("scale", 1.f)),
shift(OperatorBase::Arg<float>("shift", 0.f)), shift_(OpArg<float>("shift", 0.f)),
power(OperatorBase::Arg<float>("power", 1.f)) { power_(OpArg<float>("power", 1.f)) {
power_scale = power * scale; power_scale_ = power_ * scale_;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float scale, shift, power, power_scale; float scale_, shift_;
float power_, power_scale_;
}; };
} // namespace dragon } // namespace dragon
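The arguments describe y = (shift + scale * x)^power; its derivative is power * scale * (shift + scale * x)^(power - 1), which is why power_scale_ = power * scale is cached in both classes. Reference sketch:

    // Reference pow forward and its pointwise derivative.
    #include <cmath>
    #include <vector>

    void PowRef(const std::vector<float>& x, std::vector<float>& y,
                std::vector<float>& dydx, float scale, float shift, float power) {
        y.resize(x.size());
        dydx.resize(x.size());
        for (size_t i = 0; i < x.size(); ++i) {
            const float t = shift + scale * x[i];
            y[i] = std::pow(t, power);
            dydx[i] = power * scale * std::pow(t, power - 1.f);  // power_scale * t^(power-1)
        }
    }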
......
...@@ -20,21 +20,21 @@ namespace dragon { ...@@ -20,21 +20,21 @@ namespace dragon {
template <class Context> template <class Context>
class SqrtOp final : public Operator<Context> { class SqrtOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SqrtOp); SIMPLE_CTOR_DTOR(SqrtOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
template <class Context> template <class Context>
class SqrtGradientOp final : public Operator<Context> { class SqrtGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SqrtGradientOp); SIMPLE_CTOR_DTOR(SqrtGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
} // namespace dragon } // namespace dragon
......
...@@ -20,21 +20,21 @@ namespace dragon { ...@@ -20,21 +20,21 @@ namespace dragon {
template <class Context> template <class Context>
class SquareOp final : public Operator<Context> { class SquareOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SquareOp); SIMPLE_CTOR_DTOR(SquareOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
template <class Context> template <class Context>
class SquareGradientOp final : public Operator<Context> { class SquareGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SquareGradientOp); SIMPLE_CTOR_DTOR(SquareGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
} // namespace dragon } // namespace dragon
......
...@@ -21,28 +21,26 @@ template <class Context> ...@@ -21,28 +21,26 @@ template <class Context>
class ArangeOp final : public Operator<Context> { class ArangeOp final : public Operator<Context> {
public: public:
ArangeOp(const OperatorDef& def, Workspace* ws) ArangeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws) {
dtype(OperatorBase::Arg<string>("dtype", "float32")) { GET_ARG_WITH_DESC(int64_t, start, 0);
GET_ARGUMENT_WITH_DESC(int64_t, start, 0); GET_ARG_WITH_DESC(int64_t, stop, 0);
GET_ARGUMENT_WITH_DESC(int64_t, stop, 0); GET_ARG_WITH_DESC(int64_t, step, 1);
GET_ARGUMENT_WITH_DESC(int64_t, step, 1);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
string dtype; int64_t astart_, astop_, astep_, dim_;
int64_t astart, astop, astep, dim; DECLARE_ARG_WITH_DESC(int64_t, start);
DECLARE_ARGUMENT_WITH_DESC(int64_t, start); DECLARE_ARG_WITH_DESC(int64_t, stop);
DECLARE_ARGUMENT_WITH_DESC(int64_t, stop); DECLARE_ARG_WITH_DESC(int64_t, step);
DECLARE_ARGUMENT_WITH_DESC(int64_t, step);
}; };
DEFINE_ARGUMENT_WITH_DESC(int64_t, ArangeOp, start); DEFINE_ARG_WITH_DESC(int64_t, ArangeOp, start);
DEFINE_ARGUMENT_WITH_DESC(int64_t, ArangeOp, stop); DEFINE_ARG_WITH_DESC(int64_t, ArangeOp, stop);
DEFINE_ARGUMENT_WITH_DESC(int64_t, ArangeOp, step); DEFINE_ARG_WITH_DESC(int64_t, ArangeOp, step);
} // namespace dragon } // namespace dragon
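For the element count (dim_), an arange-style op normally produces ceil((stop - start) / step) values; the operator's exact handling of a missing "stop" argument is not shown here. Sketch, assuming a positive step:

    // Element count of a half-open range [start, stop) with positive step.
    #include <cstdint>

    int64_t ArangeSize(int64_t start, int64_t stop, int64_t step) {
        const int64_t span = stop - start;
        return span <= 0 ? 0 : (span + step - 1) / step;
    }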
......
...@@ -22,19 +22,20 @@ class ArgReduceOp final : public Operator<Context> { ...@@ -22,19 +22,20 @@ class ArgReduceOp final : public Operator<Context> {
public: public:
ArgReduceOp(const OperatorDef& def, Workspace* ws) ArgReduceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", INT_MAX)), top_k_(OpArg<int64_t>("top_k", 1)),
operation(OperatorBase::Arg<string>("operation", "NONE")), axis_(OpArg<int64_t>("axis", INT_MAX)),
keep_dims(OperatorBase::Arg<bool>("keep_dims", false)), keep_dims_(OpArg<int64_t>("keep_dims", 0)),
top_k(OperatorBase::Arg<int64_t>("top_k", 1)) {} operation_(OpArg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
bool keep_dims; CPUContext cctx_;
string operation; string operation_;
int64_t axis, top_k, outer_dim, axis_dim, inner_dim; int64_t axis_, top_k_, keep_dims_;
int64_t outer_dim_, axis_dim_, inner_dim_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,17 +22,15 @@ class ConcatOp : public Operator<Context> { ...@@ -22,17 +22,15 @@ class ConcatOp : public Operator<Context> {
public: public:
ConcatOp(const OperatorDef& def, Workspace* ws) ConcatOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)) {} axis_(OpArg<int64_t>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; int64_t axis_, cat_dim_;
int64_t x_concat_dim, y_concat_dim; int64_t outer_dim_, inner_dim_;
int64_t x_offset, y_offset, concat_offset;
vector<int64_t> concat_dims;
}; };
template <class Context> template <class Context>
...@@ -40,17 +38,15 @@ class ConcatGradientOp : public Operator<Context> { ...@@ -40,17 +38,15 @@ class ConcatGradientOp : public Operator<Context> {
public: public:
ConcatGradientOp(const OperatorDef& def, Workspace* ws) ConcatGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)) {} axis_(OpArg<int64_t>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; int64_t axis_, cat_dim_;
int64_t x_concat_dim, y_concat_dim; int64_t outer_dim_, inner_dim_;
int64_t x_offset, y_offset, concat_offset;
vector<int64_t> concat_dims;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,45 +22,44 @@ class CropOp final : public Operator<Context> { ...@@ -22,45 +22,44 @@ class CropOp final : public Operator<Context> {
public: public:
CropOp(const OperatorDef& def, Workspace* ws) CropOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
start_axis(OperatorBase::Arg<int64_t>("start_axis", -1)), ofs_(OpArgs<int64_t>("offsets")),
offsets(OperatorBase::Args<int64_t>("offsets")), start_axis_(OpArg<int64_t>("start_axis", -1)),
shape_like(OperatorBase::Arg<string>("shape_like", "")) { shape_desc_(OpArg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int64_t, starts); GET_ARGS_WITH_DESC(int64_t, starts);
GET_ARGUMENTS_WITH_DESC(int64_t, sizes); GET_ARGS_WITH_DESC(int64_t, sizes);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t start_axis; string shape_desc_;
string shape_like; int64_t start_axis_;
vector<int64_t> offsets; vec64_t st_, ed_, ofs_, keep_;
vector<int64_t> st, ed, keep_dims, y_dimsV; Tensor X_starts_, X_strides_, Y_dims_;
Tensor startsT, x_stridesT, y_dimsT; DECLARE_ARGS_WITH_DESC(int64_t, starts);
DECLARE_ARGUMENTS_WITH_DESC(int64_t, starts); DECLARE_ARGS_WITH_DESC(int64_t, sizes);
DECLARE_ARGUMENTS_WITH_DESC(int64_t, sizes);
}; };
DEFINE_ARGUMENTS_WITH_DESC(int64_t, CropOp, starts);
DEFINE_ARGUMENTS_WITH_DESC(int64_t, CropOp, sizes);
template <class Context> template <class Context>
class CropGradientOp final : public Operator<Context> { class CropGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(CropGradientOp); SIMPLE_CTOR_DTOR(CropGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
vector<int64_t> st, ed, y_dimsV; vec64_t st_, ed_;
Tensor startsT, x_stridesT, y_dimsT; Tensor X_starts_, X_strides_, Y_dims_;
}; };
DEFINE_ARGS_WITH_DESC(int64_t, CropOp, starts);
DEFINE_ARGS_WITH_DESC(int64_t, CropOp, sizes);
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_CROP_OP_H_ #endif // DRAGON_OPERATORS_ARRAY_CROP_OP_H_
\ No newline at end of file
...@@ -17,16 +17,12 @@ ...@@ -17,16 +17,12 @@
namespace dragon { namespace dragon {
/********************************************* /* Base */
* *
* Base *
* *
**********************************************/
template <class Context> template <class Context>
class DimOpBase : public Operator<Context> { class DimOpBase : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(DimOpBase); SIMPLE_CTOR_DTOR(DimOpBase);
void MemorySwitch() override { void MemorySwitch() override {
/* Disable the Memory Activation */ /* Disable the Memory Activation */
...@@ -36,115 +32,102 @@ class DimOpBase : public Operator<Context> { ...@@ -36,115 +32,102 @@ class DimOpBase : public Operator<Context> {
template <class Context> template <class Context>
class DimGradientOpBase : public Operator<Context> { class DimGradientOpBase : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(DimGradientOpBase); SIMPLE_CTOR_DTOR(DimGradientOpBase);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override { void RunOnDevice() override {
// Simply copy the dY to dX // Simply copy the dY to dX
Output(0)->ReshapeLike(Input(0)); Y(0)->ReshapeLike(X(0));
Output(0)->template CopyFrom<Context>(Input(-1), ctx()); Y(0)->CopyFrom(X(-1), ctx());
} }
}; };
#define DEFINE_DIMENSION_GRADIENT_OP(name) \ #define DEFINE_DIMENSION_GRADIENT_OP(name) \
template <class Context> \ template <class Context> \
class name##GradientOp final : public DimGradientOpBase<Context> { \ class name##GradientOp final : \
public DimGradientOpBase<Context> { \
public: \ public: \
name##GradientOp(const OperatorDef& def, Workspace* ws) \ name##GradientOp( \
const OperatorDef& def, \
Workspace* ws) \
: DimGradientOpBase<Context>(def, ws) {} \ : DimGradientOpBase<Context>(def, ws) {} \
}; };
/********************************************* /* Reshape */
* *
* Reshape *
* *
**********************************************/
template <class Context> template <class Context>
class ReshapeOp final : public DimOpBase<Context> { class ReshapeOp final : public DimOpBase<Context> {
public: public:
ReshapeOp(const OperatorDef& def, Workspace* ws) ReshapeOp(const OperatorDef& def, Workspace* ws)
: DimOpBase<Context>(def, ws), : DimOpBase<Context>(def, ws),
shape_like_desc(OperatorBase::Arg<string>("shape_like", "")) { shape_desc_(OpArg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int64_t, dims); GET_ARGS_WITH_DESC(int64_t, dims);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
protected: protected:
string shape_like_desc; string shape_desc_;
vector<int64_t> require_shape, new_shape; vec64_t req_shape_, new_shape_;
DECLARE_ARGUMENTS_WITH_DESC(int64_t, dims); DECLARE_ARGS_WITH_DESC(int64_t, dims);
}; };
DEFINE_DIMENSION_GRADIENT_OP(Reshape); DEFINE_DIMENSION_GRADIENT_OP(Reshape);
DEFINE_ARGUMENTS_WITH_DESC(int64_t, ReshapeOp, dims); DEFINE_ARGS_WITH_DESC(int64_t, ReshapeOp, dims);
/********************************************* /* Flatten */
* *
* Flatten *
* *
**********************************************/
template <class Context> template <class Context>
class FlattenOp final : public DimOpBase<Context> { class FlattenOp final : public DimOpBase<Context> {
public: public:
FlattenOp(const OperatorDef& def, Workspace* ws) FlattenOp(const OperatorDef& def, Workspace* ws)
: DimOpBase<Context>(def, ws), : DimOpBase<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)), axis_(OpArg<int64_t>("axis", 0)),
num_axes(OperatorBase::Arg<int64_t>("num_axes", -1)), num_axes_(OpArg<int64_t>("num_axes", -1)),
keep_axes(OperatorBase::Arg<int64_t>("keep_axes", INT_MAX)) {} keep_axes_(OpArg<int64_t>("keep_axes", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
protected: protected:
int64_t axis, num_axes, keep_axes; int64_t axis_, num_axes_, keep_axes_;
}; };
DEFINE_DIMENSION_GRADIENT_OP(Flatten); DEFINE_DIMENSION_GRADIENT_OP(Flatten);
/********************************************* /* ExpandDims */
* *
* Expand Dims *
* *
**********************************************/
template <class Context> template <class Context>
class ExpandDimsOp final : public DimOpBase<Context> { class ExpandDimsOp final : public DimOpBase<Context> {
public: public:
ExpandDimsOp(const OperatorDef& def, Workspace* ws) ExpandDimsOp(const OperatorDef& def, Workspace* ws)
: DimOpBase<Context>(def, ws), : DimOpBase<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)) {} axis_(OpArg<int64_t>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
protected: protected:
int64_t axis; int64_t axis_;
}; };
DEFINE_DIMENSION_GRADIENT_OP(ExpandDims); DEFINE_DIMENSION_GRADIENT_OP(ExpandDims);
/********************************************* /* Squeeze */
* *
* Squeeze *
* *
**********************************************/
template <class Context> template <class Context>
class SqueezeOp final : public DimOpBase<Context> { class SqueezeOp final : public DimOpBase<Context> {
public: public:
SqueezeOp(const OperatorDef& def, Workspace* ws) SqueezeOp(const OperatorDef& def, Workspace* ws)
: DimOpBase<Context>(def, ws), : DimOpBase<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", INT_MAX)) {} axis_(OpArg<int64_t>("axis", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
protected: protected:
int64_t axis; int64_t axis_;
}; };
DEFINE_DIMENSION_GRADIENT_OP(Squeeze); DEFINE_DIMENSION_GRADIENT_OP(Squeeze);
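As an illustration of the Flatten arguments, dimensions [axis, axis + num_axes) are merged into one, with num_axes < 0 meaning "to the end"; the keep_axes path is omitted. Hypothetical shape helper, not the operator's code:

    // Merge dims [axis, end) into a single dimension.
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> FlattenShape(const std::vector<int64_t>& dims,
                                      int64_t axis, int64_t num_axes) {
        const int64_t end = num_axes < 0 ? (int64_t)dims.size() : axis + num_axes;
        std::vector<int64_t> out(dims.begin(), dims.begin() + axis);
        int64_t merged = 1;
        for (int64_t i = axis; i < end; ++i) merged *= dims[i];
        out.push_back(merged);
        out.insert(out.end(), dims.begin() + end, dims.end());
        return out;  // e.g. [N, C, H, W], axis=1, num_axes=-1 -> [N, C*H*W]
    }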
......
...@@ -10,44 +10,45 @@ ...@@ -10,44 +10,45 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_OPERATORS_ARRAY_GATHER_OP_H_ #ifndef DRAGON_OPERATORS_ARRAY_INDEX_SELECT_OP_H_
#define DRAGON_OPERATORS_ARRAY_GATHER_OP_H_ #define DRAGON_OPERATORS_ARRAY_INDEX_SELECT_OP_H_
#include "core/operator.h" #include "core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class GatherOp final : public Operator<Context> { class IndexSelectOp final : public Operator<Context> {
public: public:
GatherOp(const OperatorDef& def, Workspace* ws) IndexSelectOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)) {} axis_(OpArg<int64_t>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim, x_slice_dim, y_slice_dim; int64_t outer_dim_, inner_dim_;
vector<int64_t> output_dims; int64_t axis_, axis_dim_, nindices_;
}; };
template <class Context> template <class Context>
class GatherGradientOp final : public Operator<Context> { class IndexSelectGradientOp final : public Operator<Context> {
public: public:
GatherGradientOp(const OperatorDef& def, Workspace* ws) IndexSelectGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)) {} axis_(OpArg<int64_t>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim, x_slice_dim, y_slice_dim; int64_t outer_dim_, inner_dim_;
int64_t axis_, axis_dim_, nindices_;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_GATHER_OP_H_ #endif // DRAGON_OPERATORS_ARRAY_INDEX_SELECT_OP_H_
\ No newline at end of file \ No newline at end of file
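IndexSelectOp (renamed from GatherOp) caches outer_dim_, axis_dim_, inner_dim_ and nindices_ before dispatching RunImpl. A sketch, assuming the Tensor count()/dim() helpers behave as in similar frameworks, of how these would typically be derived from the input X(0) and the indices X(1):

    // Illustrative only; the member names come from the header above.
    outer_dim_ = X(0).count(0, axis_);   // elements before the axis
    axis_dim_  = X(0).dim(axis_);        // length of the selected axis
    inner_dim_ = X(0).count(axis_ + 1);  // elements after the axis
    nindices_  = X(1).count();           // number of selected indices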
...@@ -22,19 +22,19 @@ class MultinomialOp final : public Operator<Context> { ...@@ -22,19 +22,19 @@ class MultinomialOp final : public Operator<Context> {
public: public:
MultinomialOp(const OperatorDef& def, Workspace* ws) MultinomialOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
normalize(OperatorBase::Arg<int64_t>("normalize", 0)), normalize_(OpArg<int64_t>("normalize", 0)),
num_samples(OperatorBase::Arg<int64_t>("num_samples", 1)) {} num_samples_(OpArg<int64_t>("num_samples", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void SoftmaxRun(); void SoftmaxRun();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
Tensor* prob; int64_t outer_dim_, axis_;
int64_t normalize, num_samples, outer_dim, axis; int64_t normalize_, num_samples_;
unique_ptr<OperatorBase> softmax_op; unique_ptr<OperatorBase> softmax_op_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,16 +22,16 @@ class OneHotOp final : public Operator < Context > { ...@@ -22,16 +22,16 @@ class OneHotOp final : public Operator < Context > {
public: public:
OneHotOp(const OperatorDef& def, Workspace* ws) OneHotOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
depth(OperatorBase::Arg<int64_t>("depth", -1)), depth_(OpArg<int64_t>("depth", -1)),
on_value(OperatorBase::Arg<int64_t>("on_value", 1)), on_value_(OpArg<int64_t>("on_value", 1)),
off_value(OperatorBase::Arg<int64_t>("off_value", 0)) {} off_value_(OpArg<int64_t>("off_value", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t depth, on_value, off_value; int64_t depth_, on_value_, off_value_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,27 +22,31 @@ class PadOp final : public Operator<Context> { ...@@ -22,27 +22,31 @@ class PadOp final : public Operator<Context> {
public: public:
PadOp(const OperatorDef& def, Workspace* ws) PadOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
pad_l(OperatorBase::Args<int64_t>("pad_l")), pad_l_(OpArgs<int64_t>("pad_l")),
pad_r(OperatorBase::Args<int64_t>("pad_r")), pad_r_(OpArgs<int64_t>("pad_r")),
mode(OperatorBase::Arg<string>("mode", "CONSTANT")), mode_(OpArg<string>("mode", "CONSTANT")),
value(OperatorBase::Arg<float>("value", 0.f)) { value_(OpArg<float>("value", 0.f)) {
if (pad_r.size() == 0) pad_r = pad_l; if (pad_r_.empty()) {
else CHECK_EQ(pad_l.size(), pad_r.size()) pad_r_ = pad_l_;
<< "The pad_l and pad_r should have the same length."; } else {
CHECK_EQ(pad_l_.size(), pad_r_.size())
<< "\nThe pad_l and pad_r "
<< "should have the same length.";
}
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
template <typename T> void ConstRunWithType(); template <typename T> void ConstRunImpl();
template <typename T> void ReflectRunWithType(); template <typename T> void ReflectRunImpl();
template <typename T> void EdgeRunWithType(); template <typename T> void EdgeRunImpl();
protected: protected:
float value; float value_;
string mode; string mode_;
vector<int64_t> pad_l, pad_r, y_dimsV; vec64_t pad_l_, pad_r_;
Tensor l_padsT, x_dimsT, x_stridesT, y_dimsT; Tensor pads_, X_dims_, X_strides_, Y_dims_;
}; };
template <class Context> template <class Context>
...@@ -50,25 +54,29 @@ class PadGradientOp final : public Operator<Context> { ...@@ -50,25 +54,29 @@ class PadGradientOp final : public Operator<Context> {
public: public:
PadGradientOp(const OperatorDef& def, Workspace* ws) PadGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
pad_l(OperatorBase::Args<int64_t>("pad_l")), pad_l_(OpArgs<int64_t>("pad_l")),
pad_r(OperatorBase::Args<int64_t>("pad_r")), pad_r_(OpArgs<int64_t>("pad_r")),
mode(OperatorBase::Arg<string>("mode", "CONSTANT")) { mode_(OpArg<string>("mode", "CONSTANT")) {
if (pad_r.size() == 0) pad_r = pad_l; if (pad_r_.empty()) {
else CHECK_EQ(pad_l.size(), pad_r.size()) pad_r_ = pad_l_;
<< "The pad_l and pad_r should have the same length."; } else {
CHECK_EQ(pad_l_.size(), pad_r_.size())
<< "\nThe pad_l and pad_r "
<< "should have the same length.";
}
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
template <typename T> void ConstRunWithType(); template <typename T> void ConstRunImpl();
template <typename T> void ReflectRunWithType(); template <typename T> void ReflectRunImpl();
template <typename T> void EdgeRunWithType(); template <typename T> void EdgeRunImpl();
protected: protected:
string mode; string mode_;
vector<int64_t> pad_l, pad_r, x_dimsV; vec64_t pad_l_, pad_r_;
Tensor l_padsT, x_dimsT, y_stridesT; Tensor pads_, X_dims_, Y_strides_;
}; };
} // namespace dragon } // namespace dragon
......
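Both PadOp and PadGradientOp fall back to pad_r_ = pad_l_ when only pad_l is given, so a symmetric pad needs to be specified once. A worked example under that rule:

    // pad_l = {1, 2}, pad_r omitted  ->  pad_r becomes {1, 2}
    // an (H, W) input grows to (H + 2, W + 4); in CONSTANT mode the new
    // border elements are filled with `value`.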
...@@ -22,19 +22,19 @@ class ReduceOp final : public Operator<Context> { ...@@ -22,19 +22,19 @@ class ReduceOp final : public Operator<Context> {
public: public:
ReduceOp(const OperatorDef& def, Workspace* ws) ReduceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axes(OperatorBase::Args<int64_t>("axes")), axes_(OpArgs<int64_t>("axes")),
keep_dims(OperatorBase::Arg<bool>("keep_dims", false)), keep_dims_(OpArg<int64_t>("keep_dims", 0)),
operation(OperatorBase::Arg<string>("operation", "SUM")) {} operation_(OpArg<string>("operation", "SUM")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
string operation; string operation_;
int64_t keep_dims; int64_t keep_dims_;
vector<int64_t> dims, axes; vec64_t dims_, axes_;
vector<int> dims32, axes32; vec32_t dims32_, axes32_;
}; };
template <class Context> template <class Context>
...@@ -42,19 +42,18 @@ class ReduceGradientOp final : public Operator<Context> { ...@@ -42,19 +42,18 @@ class ReduceGradientOp final : public Operator<Context> {
public: public:
ReduceGradientOp(const OperatorDef& def, Workspace* ws) ReduceGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axes(OperatorBase::Args<int64_t>("axes")), axes_(OpArgs<int64_t>("axes")),
operation(OperatorBase::Arg<string>("operation", "SUM")) {} operation_(OpArg<string>("operation", "SUM")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
string operation; string operation_;
int64_t axis, outer_dim, inner_dim, axis_dim; vec32_t axes32_;
vector<int64_t> axes, y_dimsV, y_stridesV; vec64_t axes_, y_dims_, y_strides_;
vector<int> dims32, axes32; Tensor X_dims_, Y_dims_, Y_strides_;
Tensor x_dimsT, y_dimsT, y_stridesT;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,17 +22,18 @@ class RepeatOp final : public Operator<Context> { ...@@ -22,17 +22,18 @@ class RepeatOp final : public Operator<Context> {
public: public:
RepeatOp(const OperatorDef& def, Workspace* ws) RepeatOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", INT_MAX)) { axis_(OpArg<int64_t>("axis", INT_MAX)) {
GET_ARGUMENT_WITH_DESC(int64_t, repeats, 1); GET_ARG_WITH_DESC(int64_t, repeats, 1);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void RunWithType(); template<typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, repeat_dim, inner_dim; int64_t axis_, axis_dim_;
DECLARE_ARGUMENT_WITH_DESC(int64_t, repeats); int64_t outer_dim_, inner_dim_;
DECLARE_ARG_WITH_DESC(int64_t, repeats);
}; };
template <class Context> template <class Context>
...@@ -40,21 +41,22 @@ class RepeatGradientOp final : public Operator<Context> { ...@@ -40,21 +41,22 @@ class RepeatGradientOp final : public Operator<Context> {
public: public:
RepeatGradientOp(const OperatorDef& def, Workspace* ws) RepeatGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", INT_MAX)) { axis_(OpArg<int64_t>("axis", INT_MAX)) {
GET_ARGUMENT_WITH_DESC(int64_t, repeats, 1); GET_ARG_WITH_DESC(int64_t, repeats, 1);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void RunWithType(); template<typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, repeat_dim, inner_dim; int64_t axis_, axis_dim_;
DECLARE_ARGUMENT_WITH_DESC(int64_t, repeats); int64_t outer_dim_, inner_dim_;
DECLARE_ARG_WITH_DESC(int64_t, repeats);
}; };
DEFINE_ARGUMENT_WITH_DESC(int64_t, RepeatOp, repeats); DEFINE_ARG_WITH_DESC(int64_t, RepeatOp, repeats);
DEFINE_ARGUMENT_WITH_DESC(int64_t, RepeatGradientOp, repeats); DEFINE_ARG_WITH_DESC(int64_t, RepeatGradientOp, repeats);
} // namespace dragon } // namespace dragon
......
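The repeats argument is declared through the *_ARG_WITH_DESC macro family, which lets it be supplied either as a literal or as a descriptor resolved at run time. The actual macro bodies are not shown in this diff; the following is only an assumed sketch of the pattern:

    // Assumed expansion, for illustration only.
    #define DECLARE_ARG_WITH_DESC(type, arg) \
        type arg##_; \
        string arg##_desc_
    #define GET_ARG_WITH_DESC(type, arg, default_value) \
        arg##_ = OpArg<type>(#arg, default_value); \
        arg##_desc_ = OpArg<string>(string(#arg) + "_desc", "")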
...@@ -20,7 +20,7 @@ namespace dragon { ...@@ -20,7 +20,7 @@ namespace dragon {
template <class Context> template <class Context>
class ShapeOp final : public Operator<Context> { class ShapeOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ShapeOp); SIMPLE_CTOR_DTOR(ShapeOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -22,17 +22,17 @@ class SliceOp final : public Operator<Context> { ...@@ -22,17 +22,17 @@ class SliceOp final : public Operator<Context> {
public: public:
SliceOp(const OperatorDef& def, Workspace* ws) SliceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)), axis_(OpArg<int64_t>("axis", 0)),
slice_points(OperatorBase::Args<int64_t>("slice_points")) {} points_(OpArgs<int64_t>("slice_points")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, N, steps, slice_offset; vec64_t points_;
int64_t outer_dim, inner_dim, x_slice_dim, y_slice_dim; int64_t outer_dim_, inner_dim_;
vector<int64_t> slice_dims, slice_points; int64_t axis_, axis_dim_, slice_dim_, N_;
}; };
template <class Context> template <class Context>
...@@ -40,17 +40,17 @@ class SliceGradientOp final : public Operator<Context> { ...@@ -40,17 +40,17 @@ class SliceGradientOp final : public Operator<Context> {
public: public:
SliceGradientOp(const OperatorDef& def, Workspace* ws) SliceGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)), axis_(OpArg<int64_t>("axis", 0)),
slice_points(OperatorBase::Args<int64_t>("slice_points")) {} points_(OpArgs<int64_t>("slice_points")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, N, x_offset, y_offset, slice_offset; vec64_t points_;
int64_t outer_dim, inner_dim, x_slice_dim, y_slice_dim; int64_t outer_dim_, inner_dim_;
vector<int64_t> slice_dims, slice_points; int64_t axis_, axis_dim_, slice_dim_, N_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,15 +22,14 @@ class StackOp final : public Operator<Context> { ...@@ -22,15 +22,14 @@ class StackOp final : public Operator<Context> {
public: public:
StackOp(const OperatorDef& def, Workspace* ws) StackOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)) {} axis_(OpArg<int64_t>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; int64_t axis_, outer_dim_, inner_dim_;
vector<int64_t> stack_dims, concat_dims;
}; };
template <class Context> template <class Context>
...@@ -38,14 +37,14 @@ class StackGradientOp final : public Operator<Context> { ...@@ -38,14 +37,14 @@ class StackGradientOp final : public Operator<Context> {
public: public:
StackGradientOp(const OperatorDef& def, Workspace* ws) StackGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)) {} axis_(OpArg<int64_t>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; int64_t axis_, outer_dim_, inner_dim_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,19 +22,16 @@ class TileOp final : public Operator<Context> { ...@@ -22,19 +22,16 @@ class TileOp final : public Operator<Context> {
public: public:
TileOp(const OperatorDef& def, Workspace* ws) TileOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) { : Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int64_t, multiples); GET_ARGS_WITH_DESC(int64_t, multiples);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void RunWithType(); template<typename T> void RunImpl();
protected: protected:
int64_t axis, multiple, rows, cols; Tensor X_dims_, X_strides_, Y_dims_;
Tensor* dst, *src, nav; DECLARE_ARGS_WITH_DESC(int64_t, multiples);
vector<int64_t> y_dimsV;
Tensor x_dimsT, x_stridesT, y_dimsT;
DECLARE_ARGUMENTS_WITH_DESC(int64_t, multiples);
}; };
template <class Context> template <class Context>
...@@ -42,21 +39,21 @@ class TileGradientOp final : public Operator<Context> { ...@@ -42,21 +39,21 @@ class TileGradientOp final : public Operator<Context> {
public: public:
TileGradientOp(const OperatorDef& def, Workspace* ws) TileGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) { : Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int64_t, multiples); GET_ARGS_WITH_DESC(int64_t, multiples);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void RunWithType(); template<typename T> void RunImpl();
protected: protected:
int64_t axis, multiple, rows, cols; Tensor* dst_, * src_, nav_;
Tensor* dst, *src, nav; int64_t axis_, multiple_, rows_, cols_;
DECLARE_ARGUMENTS_WITH_DESC(int64_t, multiples); DECLARE_ARGS_WITH_DESC(int64_t, multiples);
}; };
DEFINE_ARGUMENTS_WITH_DESC(int64_t, TileOp, multiples); DEFINE_ARGS_WITH_DESC(int64_t, TileOp, multiples);
DEFINE_ARGUMENTS_WITH_DESC(int64_t, TileGradientOp, multiples); DEFINE_ARGS_WITH_DESC(int64_t, TileGradientOp, multiples);
} // namespace dragon } // namespace dragon
......
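TileOp receives per-axis multiples through the same *_WITH_DESC mechanism. Assuming the usual tiling semantics, the cached Y_dims_ relate to the input shape as:

    // Y.dim(i) = X.dim(i) * multiples[i]   for every axis i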
...@@ -22,16 +22,16 @@ class TransposeOp final: public Operator<Context> { ...@@ -22,16 +22,16 @@ class TransposeOp final: public Operator<Context> {
public: public:
TransposeOp(const OperatorDef& def, Workspace* ws) TransposeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) { : Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int64_t, perm); GET_ARGS_WITH_DESC(int64_t, perm);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
Tensor x_strides, y_dims; Tensor X_strides_, Y_dims_;
DECLARE_ARGUMENTS_WITH_DESC(int64_t, perm); DECLARE_ARGS_WITH_DESC(int64_t, perm);
}; };
template <class Context> template <class Context>
...@@ -39,20 +39,20 @@ class TransposeGradientOp final : public Operator<Context> { ...@@ -39,20 +39,20 @@ class TransposeGradientOp final : public Operator<Context> {
public: public:
TransposeGradientOp(const OperatorDef& def, Workspace* ws) TransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) { : Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int64_t, perm); GET_ARGS_WITH_DESC(int64_t, perm);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
Tensor x_strides, y_dims; Tensor X_strides_, Y_dims_;
DECLARE_ARGUMENTS_WITH_DESC(int64_t, perm); DECLARE_ARGS_WITH_DESC(int64_t, perm);
}; };
DEFINE_ARGUMENTS_WITH_DESC(int64_t, TransposeOp, perm); DEFINE_ARGS_WITH_DESC(int64_t, TransposeOp, perm);
DEFINE_ARGUMENTS_WITH_DESC(int64_t, TransposeGradientOp, perm); DEFINE_ARGS_WITH_DESC(int64_t, TransposeGradientOp, perm);
} // namespace dragon } // namespace dragon
......
...@@ -22,24 +22,24 @@ class AssignOp final : public Operator<Context> { ...@@ -22,24 +22,24 @@ class AssignOp final : public Operator<Context> {
public: public:
AssignOp(const OperatorDef& def, Workspace* ws) AssignOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) { : Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int64_t, starts); GET_ARGS_WITH_DESC(int64_t, starts);
GET_ARGUMENTS_WITH_DESC(int64_t, sizes); GET_ARGS_WITH_DESC(int64_t, sizes);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
vector<int64_t> st, ed, x_dimsV; vec64_t st_, ed_;
Tensor startsT, y_stridesT, x_dimsT, fake_x; Tensor X_, X_starts_, Y_strides_, X_dims_;
DECLARE_ARGUMENTS_WITH_DESC(int64_t, starts); DECLARE_ARGS_WITH_DESC(int64_t, starts);
DECLARE_ARGUMENTS_WITH_DESC(int64_t, sizes); DECLARE_ARGS_WITH_DESC(int64_t, sizes);
}; };
DEFINE_ARGUMENTS_WITH_DESC(int64_t, AssignOp, starts); DEFINE_ARGS_WITH_DESC(int64_t, AssignOp, starts);
DEFINE_ARGUMENTS_WITH_DESC(int64_t, AssignOp, sizes); DEFINE_ARGS_WITH_DESC(int64_t, AssignOp, sizes);
} // namespace dragon } // namespace dragon
......
...@@ -22,21 +22,20 @@ class CompareOp final : public Operator<Context> { ...@@ -22,21 +22,20 @@ class CompareOp final : public Operator<Context> {
public: public:
CompareOp(const OperatorDef& def, Workspace* ws) CompareOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
operation(OperatorBase::Arg<string>("operation", "NONE")), op_str_(OpArg<string>("operation", "NONE")),
to_uint8(OperatorBase::Arg<bool>("to_uint8", false)) {} to_uint8_(OpArg<bool>("to_uint8", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EqualRunWithType(); template <typename T> void EqualRunImpl();
template <typename T> void LessRunWithType(); template <typename T> void LessRunImpl();
template <typename T> void LessEqualRunWithType(); template <typename T> void LessEqualRunImpl();
template <typename T> void GreaterRunWithType(); template <typename T> void GreaterRunImpl();
template <typename T> void GreaterEqualRunWithType(); template <typename T> void GreaterEqualRunImpl();
protected: protected:
string operation; string op_str_;
bool to_uint8; bool to_uint8_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -20,11 +20,11 @@ namespace dragon { ...@@ -20,11 +20,11 @@ namespace dragon {
template <class Context> template <class Context>
class CopyOp final : public Operator<Context> { class CopyOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(CopyOp); SIMPLE_CTOR_DTOR(CopyOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
} // namespace dragon } // namespace dragon
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ScanOp final: public Operator<Context> {
public:
ScanOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)),
nsteps(OperatorBase::Arg<int64_t>("nsteps", 0)),
step_type(OperatorBase::Arg<string>("step_type", "Static")),
step_tensor(OperatorBase::Arg<string>("step_tensor", "")),
nseqs(OperatorBase::Arg<int64_t>("nseqs", 0)),
default_outputs(OperatorBase::Args<string>("default_outputs")),
nout((int)default_outputs.size()),
debug_mode(OperatorBase::Arg<bool>("debug_mode", false)) {
InitTemplate();
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void InitTemplate();
void UnrollTemplate();
void UpdateTerms(int cur_step);
protected:
GraphDef func_def, template_def, new_def;
Map<int, unique_ptr<Graph>> graphs;
Graph* cur_graph;
Map<string, string> terms;
vector<string> default_outputs;
int64_t axis, nseqs, nsteps, nrepeats, nout;
string step_type, step_tensor;
bool debug_mode;
};
template <class Context>
class ScanGradientOp final: public Operator<Context> {
public:
ScanGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)),
nsteps(OperatorBase::Arg<int64_t>("nsteps", 0)),
step_type(OperatorBase::Arg<string>("step_type", "Static")),
step_tensor(OperatorBase::Arg<string>("step_tensor", "")),
forward_inputs(OperatorBase::Args<string>("inputs_name")),
forward_outputs(OperatorBase::Args<string>("outputs_name")) {
// handle GO(x)
for (int i = 0; i < forward_outputs.size(); i++)
terms[forward_outputs[i] + "_grad"] = Input(i + (int)OutputSize()).name();
// handle GI(x)
for (int i = 0; i < forward_inputs.size(); i++)
terms[forward_inputs[i] + "_grad"] = Output(i)->name();
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void MakeOps(const GraphDef& forward_def, GraphDef& new_def);
protected:
Map<string, string> terms;
Map<int, unique_ptr<Graph>> graphs;
vector<string> forward_inputs, forward_outputs;
Graph* cur_graph;
int64_t axis, nsteps;
string step_type, step_tensor;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
\ No newline at end of file
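The two constructor loops in ScanGradientOp build the terms map that rewires tensor names when the unrolled graph is differentiated. Reading the code with assumed names, for a single forward input "X" and output "Y":

    // forward_outputs = {"Y"}, forward_inputs = {"X"}
    // terms["Y_grad"] = Input(OutputSize()).name();  // incoming GO(Y)
    // terms["X_grad"] = Output(0)->name();           // produced GI(X)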
...@@ -20,24 +20,24 @@ namespace dragon { ...@@ -20,24 +20,24 @@ namespace dragon {
template <class Context> template <class Context>
class CTCLossOp final : public Operator<Context> { class CTCLossOp final : public Operator<Context> {
public: public:
CTCLossOp(const OperatorDef& def, Workspace* ws) CTCLossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) { : Operator<Context>(def, ws) {
LOG(FATAL) << "CTCLoss requires CuDNN support."; LOG(FATAL) << "CTCLoss requires CuDNN support.";
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {} void RunOnDevice() override {}
}; };
template <class Context> template <class Context>
class CTCLossGradientOp final : public Operator<Context> { class CTCLossGradientOp final : public Operator<Context> {
public: public:
CTCLossGradientOp(const OperatorDef& def, Workspace* ws) CTCLossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {} : Operator<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
#ifdef WITH_CUDNN #ifdef WITH_CUDNN
...@@ -49,36 +49,34 @@ class CuDNNCTCLossOp final : public Operator<Context> { ...@@ -49,36 +49,34 @@ class CuDNNCTCLossOp final : public Operator<Context> {
public: public:
CuDNNCTCLossOp(const OperatorDef& def, Workspace* ws) CuDNNCTCLossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
blank_first(OperatorBase::Arg<bool>("blank_first", true)), blank_first_(OpArg<bool>("blank_first", true)),
padding_mask(OperatorBase::Arg<int64_t>("padding_mask", -1)) { padding_mask_(OpArg<int64_t>("padding_mask", -1)) {
CUDNN_CHECK(cudnnCreateCTCLossDescriptor(&ctc_desc)); CuDNNCreateTensorDesc(&prob_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&prob_desc)); CuDNNCreateTensorDesc(&grad_desc_);
CUDNN_CHECK(cudnnCreateTensorDescriptor(&grad_desc)); ctc_algo_ = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
ctc_algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC; CUDNN_CHECK(cudnnCreateCTCLossDescriptor(&ctc_desc_));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CuDNNCTCLossOp() { ~CuDNNCTCLossOp() {
CUDNN_CHECK(cudnnDestroyCTCLossDescriptor(ctc_desc)); CuDNNDestroyTensorDesc(&prob_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(prob_desc)); CuDNNDestroyTensorDesc(&grad_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(grad_desc)); CUDNN_CHECK(cudnnDestroyCTCLossDescriptor(ctc_desc_));
} }
void Reshape();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
void WrapIO();
protected: protected:
bool blank_first; bool blank_first_;
int64_t padding_mask; int64_t padding_mask_;
size_t workspace_size_;
cudnnCTCLossAlgo_t ctc_algo; cudnnCTCLossAlgo_t ctc_algo_;
cudnnCTCLossDescriptor_t ctc_desc; cudnnCTCLossDescriptor_t ctc_desc_;
cudnnTensorDescriptor_t prob_desc, grad_desc; cudnnTensorDescriptor_t prob_desc_, grad_desc_;
size_t workspace_size; vec32_t packed_labels_, label_lengths_, input_lengths_;
vector<int> packed_labels, label_lengths, input_lengths;
}; };
#endif #endif
......
...@@ -22,18 +22,18 @@ class L1LossOp final : public Operator<Context> { ...@@ -22,18 +22,18 @@ class L1LossOp final : public Operator<Context> {
public: public:
L1LossOp(const OperatorDef& def, Workspace* ws) L1LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.f)), scale_(OpArg<float>(
normalization(OperatorBase::Arg<string>( "scale", 1.f)),
"normalization", "BATCH_SIZE")) {} reduction_(OpArg<string>(
"reduction", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float scale; float scale_;
Tensor* diff; string reduction_;
string normalization;
}; };
template <class Context> template <class Context>
...@@ -41,18 +41,18 @@ class L1LossGradientOp final : public Operator<Context> { ...@@ -41,18 +41,18 @@ class L1LossGradientOp final : public Operator<Context> {
public: public:
L1LossGradientOp(const OperatorDef& def, Workspace* ws) L1LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.f)), scale_(OpArg<float>(
normalization(OperatorBase::Arg<string>( "scale", 1.f)),
"normalization", "BATCH_SIZE")) {} reduction_(OpArg<string>(
"reduction", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float scale; float scale_;
Tensor* diff; string reduction_;
string normalization;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,18 +22,18 @@ class L2LossOp final : public Operator<Context> { ...@@ -22,18 +22,18 @@ class L2LossOp final : public Operator<Context> {
public: public:
L2LossOp(const OperatorDef& def, Workspace* ws) L2LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.f)), scale_(OpArg<float>(
normalization(OperatorBase::Arg<string>( "scale", 1.f)),
"normalization", "BATCH_SIZE")) {} reduction_(OpArg<string>(
"reduction", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float scale; float scale_;
Tensor* diff; string reduction_;
string normalization;
}; };
template <class Context> template <class Context>
...@@ -41,18 +41,18 @@ class L2LossGradientOp final : public Operator<Context> { ...@@ -41,18 +41,18 @@ class L2LossGradientOp final : public Operator<Context> {
public: public:
L2LossGradientOp(const OperatorDef& def, Workspace* ws) L2LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.f)), scale_(OpArg<float>(
normalization(OperatorBase::Arg<string>( "scale", 1.f)),
"normalization", "BATCH_SIZE")) {} reduction_(OpArg<string>(
"reduction", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float scale; float scale_;
Tensor* diff; string reduction_;
string normalization;
}; };
} // namespace dragon } // namespace dragon
......
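L1LossOp and L2LossOp now read a reduction argument (formerly normalization) defaulting to BATCH_SIZE. Assuming the conventional meaning of that mode, the reported loss is the elementwise sum scaled by `scale` and divided by the batch size, e.g. for L1:

    L = \frac{\text{scale}}{N_{\text{batch}}} \sum_i |x_i - y_i|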
...@@ -18,59 +18,59 @@ ...@@ -18,59 +18,59 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class NLLLossOp : public Operator<Context> { class NLLLossOp final : public Operator<Context> {
public: public:
NLLLossOp( NLLLossOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
normalization(OperatorBase::Arg<string>( reduction_(OpArg<string>(
"normalization", "VALID")) { "reduction", "VALID")) {
auto xs = OperatorBase::Args<int64_t>("ignore_labels"); auto ivec = OpArgs<int64_t>("ignore_labels");
if (xs.size()) { if (!ivec.empty()) {
ignores.Reshape({ (int64_t)xs.size() }); ignore_.Reshape({ (int64_t)ivec.size() });
auto* Idata = ignores.mutable_data<int, CPUContext>(); auto* idata = ignore_.mutable_data<int, CPUContext>();
for (int i = 0; i < xs.size(); i++) Idata[i] = xs[i]; for (int i = 0; i < ivec.size(); i++) idata[i] = ivec[i];
} }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; string reduction_;
Tensor losses, flags, ignores; Tensor loss_, flag_, ignore_;
string normalization; int64_t axis_, outer_dim_, inner_dim_;
}; };
template <class Context> template <class Context>
class NLLLossGradientOp : public Operator<Context> { class NLLLossGradientOp final : public Operator<Context> {
public: public:
NLLLossGradientOp( NLLLossGradientOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
normalization(OperatorBase::Arg<string>( reduction_(OpArg<string>(
"normalization", "VALID")) { "reduction", "VALID")) {
auto xs = OperatorBase::Args<int64_t>("ignore_labels"); auto ivec = OpArgs<int64_t>("ignore_labels");
if (xs.size()) { if (!ivec.empty()) {
ignores.Reshape({ (int64_t)xs.size() }); ignore_.Reshape({ (int64_t)ivec.size() });
auto* Idata = ignores.mutable_data<int, CPUContext>(); auto* idata = ignore_.mutable_data<int, CPUContext>();
for (int i = 0; i < xs.size(); i++) Idata[i] = xs[i]; for (int i = 0; i < ivec.size(); i++) idata[i] = ivec[i];
} }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; string reduction_;
Tensor ignores, flags; Tensor ignore_, flag_;
string normalization; int64_t axis_, outer_dim_, inner_dim_;
}; };
} // namespace dragon } // namespace dragon
......
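NLLLossOp copies ignore_labels into an int tensor and defaults reduction to "VALID". Assuming the usual meaning, ignored positions contribute nothing and the sum is divided by the number of valid (non-ignored) elements, which the flag_ tensor presumably tracks:

    L = \frac{\sum_i f_i \, \ell_i}{\max(1, \sum_i f_i)}, \qquad
    f_i = \begin{cases} 0 & y_i \in \text{ignore\_labels} \\ 1 & \text{otherwise} \end{cases}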
...@@ -18,43 +18,43 @@ ...@@ -18,43 +18,43 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SigmoidCrossEntropyOp class SigmoidCrossEntropyOp final
final : public Operator<Context> { : public Operator<Context> {
public: public:
SigmoidCrossEntropyOp( SigmoidCrossEntropyOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
normalization(OperatorBase::Arg<string>( reduction_(OpArg<string>(
"normalization", "VALID")) {} "reduction", "VALID")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
Tensor losses, flags; string reduction_;
string normalization; Tensor loss_, flag_;
}; };
template <class Context> template <class Context>
class SigmoidCrossEntropyGradientOp class SigmoidCrossEntropyGradientOp final
final : public Operator<Context> { : public Operator<Context> {
public: public:
SigmoidCrossEntropyGradientOp( SigmoidCrossEntropyGradientOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
normalization(OperatorBase::Arg<string>( reduction_(OpArg<string>(
"normalization", "VALID")) {} "reduction", "VALID")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
Tensor flags; Tensor flag_;
string normalization; string reduction_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -18,61 +18,61 @@ ...@@ -18,61 +18,61 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SigmoidFocalLossOp class SigmoidFocalLossOp final
final : public Operator<Context> { : public Operator<Context> {
public: public:
SigmoidFocalLossOp( SigmoidFocalLossOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
normalization(OperatorBase::Arg<string>( neg_id_(OpArg<int64_t>("neg_id", 0)),
"normalization", "VALID")), alpha_(OpArg<float>("alpha", 0.25f)),
alpha(OperatorBase::Arg<float>("alpha", 0.25f)), gamma_(OpArg<float>("gamma", 2.f)),
gamma(OperatorBase::Arg<float>("gamma", 2.f)), reduction_(OpArg<string>(
neg_id(OperatorBase::Arg<int64_t>("neg_id", 0)) { "reduction", "VALID")) {
pos_alpha = alpha; pos_alpha_ = alpha_;
neg_alpha = 1.f - alpha; neg_alpha_ = 1.f - alpha_;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunImpl();
protected: protected:
float alpha, gamma, pos_alpha, neg_alpha; string reduction_;
int64_t axis, neg_id, outer_dim, axis_dim, inner_dim; Tensor loss_, flag_;
Tensor losses, flags; float alpha_, gamma_, pos_alpha_, neg_alpha_;
string normalization; int64_t axis_, neg_id_, outer_dim_, inner_dim_;
}; };
template <class Context> template <class Context>
class SigmoidFocalLossGradientOp class SigmoidFocalLossGradientOp final
final : public Operator<Context> { : public Operator<Context> {
public: public:
SigmoidFocalLossGradientOp( SigmoidFocalLossGradientOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
normalization(OperatorBase::Arg<string>( neg_id_(OpArg<int64_t>("neg_id", 0)),
"normalization", "VALID")), alpha_(OpArg<float>("alpha", 0.25f)),
alpha(OperatorBase::Arg<float>("alpha", 0.25f)), gamma_(OpArg<float>("gamma", 2.f)),
gamma(OperatorBase::Arg<float>("gamma", 2.f)), reduction_(OpArg<string>(
neg_id(OperatorBase::Arg<int64_t>("neg_id", 0)) { "reduction", "VALID")) {
pos_alpha = alpha; pos_alpha_ = alpha_;
neg_alpha = 1.f - alpha; neg_alpha_ = 1.f - alpha_;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunImpl();
protected: protected:
float alpha, gamma, pos_alpha, neg_alpha; Tensor flag_;
int64_t axis, neg_id, outer_dim, axis_dim, inner_dim; string reduction_;
Tensor flags; float alpha_, gamma_, pos_alpha_, neg_alpha_;
string normalization; int64_t axis_, neg_id_, outer_dim_, inner_dim_;
}; };
} // namespace dragon } // namespace dragon
......
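SigmoidFocalLossOp splits alpha into pos_alpha_ = alpha and neg_alpha_ = 1 - alpha, matching the focal loss of Lin et al., with gamma_ as the focusing exponent. Per element the loss has the familiar form:

    FL(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t), \qquad
    \alpha_t = \alpha \text{ for positive targets}, \; 1 - \alpha \text{ otherwise}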
...@@ -18,41 +18,42 @@ ...@@ -18,41 +18,42 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SmoothL1LossOp final : public Operator<Context> { class SmoothL1LossOp final
: public Operator<Context> {
public: public:
SmoothL1LossOp(const OperatorDef& def, Workspace* ws) SmoothL1LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
beta(OperatorBase::Arg<float>("beta", 1.f)), beta_(OpArg<float>("beta", 1.f)),
normalization(OperatorBase::Arg<string>( reduction_(OpArg<string>(
"normalization", "BATCH_SIZE")) {} "reduction", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float beta; float beta_;
Tensor* diff, *error; string reduction_;
string normalization; };
};
template <class Context> template <class Context>
class SmoothL1LossGradientOp final : public Operator<Context> { class SmoothL1LossGradientOp final
: public Operator<Context> {
public: public:
SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws) SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
beta(OperatorBase::Arg<float>("beta", 1.f)), beta_(OpArg<float>(
normalization(OperatorBase::Arg<string>( "beta", 1.f)),
"normalization", "BATCH_SIZE")) {} reduction_(OpArg<string>(
"reduction", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float beta; float beta_;
Tensor* diff; string reduction_;
string normalization;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -18,50 +18,48 @@ ...@@ -18,50 +18,48 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SoftmaxCrossEntropyOp class SoftmaxCrossEntropyOp final
final : public Operator<Context> { : public Operator<Context> {
public: public:
SoftmaxCrossEntropyOp( SoftmaxCrossEntropyOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
normalization(OperatorBase::Arg<string>( reduction_(OpArg<string>(
"normalization", "FULL")) {} "reduction", "MEAN")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void SoftmaxRun(); void SoftmaxRun();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; Tensor loss_;
Tensor losses; string reduction_;
Tensor* prob; int64_t axis_, outer_dim_, inner_dim_;
unique_ptr<OperatorBase> softmax_op; unique_ptr<OperatorBase> softmax_op_;
string normalization;
}; };
template <class Context> template <class Context>
class SoftmaxCrossEntropyGradientOp class SoftmaxCrossEntropyGradientOp final
final : public Operator<Context> { : public Operator<Context> {
public: public:
SoftmaxCrossEntropyGradientOp( SoftmaxCrossEntropyGradientOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
normalization(OperatorBase::Arg<string>( reduction_(OpArg<string>(
"normalization", "FULL")) {} "reduction", "MEAN")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; string reduction_;
Tensor* prob; int64_t axis_, outer_dim_, inner_dim_;
string normalization;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -18,61 +18,60 @@ ...@@ -18,61 +18,60 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SoftmaxFocalLossOp class SoftmaxFocalLossOp final :
final : public SparseSoftmaxCrossEntropyOp<Context> { public SparseSoftmaxCrossEntropyOp<Context> {
public: public:
SoftmaxFocalLossOp( SoftmaxFocalLossOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: SparseSoftmaxCrossEntropyOp<Context>(def, ws), : SparseSoftmaxCrossEntropyOp<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
normalization(OperatorBase::Arg<string>( neg_id_(OpArg<int64_t>("neg_id", 0)),
"normalization", "VALID")), alpha_(OpArg<float>("alpha", 0.25f)),
alpha(OperatorBase::Arg<float>("alpha", 0.25f)), gamma_(OpArg<float>("gamma", 2.f)),
gamma(OperatorBase::Arg<float>("gamma", 2.f)), reduction_(OpArg<string>(
neg_id(OperatorBase::Arg<int64_t>("neg_id", 0)) { "reduction", "VALID")) {
pos_alpha = alpha; pos_alpha_ = alpha_;
neg_alpha = 1.f - alpha; neg_alpha_ = 1.f - alpha_;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunImpl();
protected: protected:
float alpha, gamma, pos_alpha, neg_alpha; string reduction_;
int64_t axis, neg_id, outer_dim, inner_dim; Tensor loss_, flag_;
Tensor losses, flags; float alpha_, gamma_, pos_alpha_, neg_alpha_;
string normalization; int64_t axis_, neg_id_, outer_dim_, inner_dim_;
}; };
template <class Context> template <class Context>
class SoftmaxFocalLossGradientOp class SoftmaxFocalLossGradientOp final
final : public SparseSoftmaxCrossEntropyGradientOp<Context> { : public SparseSoftmaxCrossEntropyGradientOp<Context> {
public: public:
SoftmaxFocalLossGradientOp( SoftmaxFocalLossGradientOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: SparseSoftmaxCrossEntropyGradientOp<Context>(def, ws), : SparseSoftmaxCrossEntropyGradientOp<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
normalization(OperatorBase::Arg<string>( neg_id_(OpArg<int64_t>("neg_id", 0)),
"normalization", "VALID")), alpha_(OpArg<float>("alpha", 0.25f)),
alpha(OperatorBase::Arg<float>("alpha", 0.25f)), gamma_(OpArg<float>("gamma", 2.f)),
gamma(OperatorBase::Arg<float>("gamma", 2.f)), reduction_(OpArg<string>("reduction", "VALID")) {
neg_id(OperatorBase::Arg<int64_t>("neg_id", 0)) { pos_alpha_ = alpha_;
pos_alpha = alpha; neg_alpha_ = 1.f - alpha_;
neg_alpha = 1.f - alpha;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunImpl();
protected: protected:
float alpha, gamma, pos_alpha, neg_alpha; Tensor flag_;
int64_t axis, neg_id, outer_dim, inner_dim; string reduction_;
Tensor flags; float alpha_, gamma_, pos_alpha_, neg_alpha_;
string normalization; int64_t axis_, neg_id_, outer_dim_, inner_dim_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -18,20 +18,21 @@ ...@@ -18,20 +18,21 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SparseSoftmaxCrossEntropyOp : public Operator<Context> { class SparseSoftmaxCrossEntropyOp
: public Operator<Context> {
public: public:
SparseSoftmaxCrossEntropyOp( SparseSoftmaxCrossEntropyOp(
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
normalization(OperatorBase::Arg<string>( reduction_(OpArg<string>(
"normalization", "VALID")) { "reduction", "VALID")) {
auto xs = OperatorBase::Args<int64_t>("ignore_labels"); auto ivec = OpArgs<int64_t>("ignore_labels");
if (xs.size()) { if (!ivec.empty()) {
ignores.Reshape({ (int64_t)xs.size() }); ignore_.Reshape({ (int64_t)ivec.size() });
auto* Idata = ignores.mutable_data<int, CPUContext>(); auto* x = ignore_.mutable_data<int, CPUContext>();
for (int i = 0; i < xs.size(); i++) Idata[i] = xs[i]; for (int i = 0; i < ivec.size(); i++) x[i] = ivec[i];
} }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -39,13 +40,13 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> { ...@@ -39,13 +40,13 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
void SoftmaxRun(); void SoftmaxRun();
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; string reduction_;
Tensor* prob, losses, flags, ignores; Tensor loss_, flag_, ignore_;
unique_ptr<OperatorBase> softmax_op; int64_t axis_, outer_dim_, inner_dim_;
string normalization; unique_ptr<OperatorBase> softmax_op_;
}; };
template <class Context> template <class Context>
...@@ -56,25 +57,25 @@ class SparseSoftmaxCrossEntropyGradientOp ...@@ -56,25 +57,25 @@ class SparseSoftmaxCrossEntropyGradientOp
const OperatorDef& def, const OperatorDef& def,
Workspace* ws) Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 1)), axis_(OpArg<int64_t>("axis", 1)),
normalization(OperatorBase::Arg<string>( reduction_(OpArg<string>(
"normalization", "VALID")) { "reduction", "VALID")) {
auto xs = OperatorBase::Args<int64_t>("ignore_labels"); auto ivec = OpArgs<int64_t>("ignore_labels");
if (xs.size()) { if (!ivec.empty()) {
ignores.Reshape({ (int64_t)xs.size() }); ignore_.Reshape({ (int64_t)ivec.size() });
auto* Idata = ignores.mutable_data<int, CPUContext>(); auto* x = ignore_.mutable_data<int, CPUContext>();
for (int i = 0; i < xs.size(); i++) Idata[i] = xs[i]; for (int i = 0; i < ivec.size(); i++) x[i] = ivec[i];
} }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunImpl();
protected: protected:
int64_t axis, outer_dim, inner_dim; string reduction_;
Tensor* prob, ignores, flags; Tensor ignore_, flag_;
string normalization; int64_t axis_, outer_dim_, inner_dim_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,23 +22,25 @@ class AccuracyOp final : public Operator<Context> { ...@@ -22,23 +22,25 @@ class AccuracyOp final : public Operator<Context> {
public: public:
AccuracyOp(const OperatorDef& def, Workspace* ws) AccuracyOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
top_k(OperatorBase::Arg<int64_t>("top_k", 1)), axis_(OpArg<int64_t>("axis", 1)),
axis(OperatorBase::Arg<int64_t>("axis", 1)) { top_k_(OpArg<int64_t>("top_k", 1)) {
auto ignores = OperatorBase::Args<int>("ignore_labels"); auto ivec = OpArgs<int64_t>("ignore_labels");
if (ignores.size()) { if (!ivec.empty()) {
ignore.Reshape({ (int64_t)ignores.size() }); ignore_.Reshape({ (int64_t)ivec.size() });
auto* Idata = ignore.mutable_data<int, CPUContext>(); auto* x = ignore_.mutable_data<int, CPUContext>();
for (int i = 0; i < ignores.size(); i++) Idata[i] = ignores[i]; for (int i = 0; i < ivec.size(); i++) x[i] = ivec[i];
} }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunImpl();
protected: protected:
int64_t top_k, axis, outer_dim, inner_dim, num_classes; Tensor ignore_;
Tensor ignore; CPUContext cctx_;
int64_t outer_dim_, inner_dim_;
int64_t axis_, axis_dim_, top_k_;
}; };
} // namespace dragon } // namespace dragon
......
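AccuracyOp reads top_k_ and an optional ignore_ list; a sample presumably counts as correct when its label is not ignored and appears among the k highest scores. A self-contained sketch of that per-sample check (not the operator's actual kernel):

    #include <vector>

    // Returns true if `label` ranks within the top-k scores.
    bool TopKHit(const std::vector<float>& scores, int label, int top_k) {
      int higher = 0;
      for (float s : scores)
        if (s > scores[label]) ++higher;  // entries ranked strictly above
      return higher < top_k;
    }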
...@@ -22,27 +22,22 @@ class CastOp final : public Operator<Context> { ...@@ -22,27 +22,22 @@ class CastOp final : public Operator<Context> {
public: public:
CastOp(const OperatorDef& def, Workspace* ws) CastOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
dtype(OperatorBase::Arg<string>("dtype", "float32")), inplace_(OpArg<int64_t>("inplace", 0)) {}
inplace(OperatorBase::Arg<bool>("inplace", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
protected: protected:
string dtype; int64_t inplace_;
bool inplace;
}; };
template <class Context> template <class Context>
class CastGradientOp final : public Operator<Context> { class CastGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(CastGradientOp); SIMPLE_CTOR_DTOR(CastGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
protected:
string dtype;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,14 +22,14 @@ class GradientGenerateOp final: public Operator<Context> { ...@@ -22,14 +22,14 @@ class GradientGenerateOp final: public Operator<Context> {
public: public:
GradientGenerateOp(const OperatorDef& def, Workspace* ws) GradientGenerateOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
defaults(OperatorBase::Args<float>("defaults")) { defaults(OpArgs<float>("defaults")) {
CHECK_EQ(InputSize(), OutputSize()); CHECK_EQ(XSize(), YSize());
CHECK_EQ(defaults.size(), OutputSize()); CHECK_EQ(defaults.size(), YSize());
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
vector<float> defaults; vector<float> defaults;
...@@ -40,8 +40,8 @@ class GradientGatherOp final : public Operator<Context> { ...@@ -40,8 +40,8 @@ class GradientGatherOp final : public Operator<Context> {
public: public:
GradientGatherOp(const OperatorDef& def, Workspace* ws) GradientGatherOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) { : Operator<Context>(def, ws) {
for (int i = 0; i < InputSize(); i++) { for (int i = 0; i < XSize(); i++) {
if (Input(i).name() != "NULL") { if (X(i).name() != "NULL") {
indices.push_back(i); indices.push_back(i);
} }
} }
...@@ -49,26 +49,26 @@ class GradientGatherOp final : public Operator<Context> { ...@@ -49,26 +49,26 @@ class GradientGatherOp final : public Operator<Context> {
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
vector<int> indices; vec32_t indices;
}; };
template <class Context> template <class Context>
class GradientAddOp final : public Operator<Context> { class GradientAddOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(GradientAddOp); SIMPLE_CTOR_DTOR(GradientAddOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
template <class Context> template <class Context>
class StopGradientOp final : public Operator<Context> { class StopGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(StopGradientOp); SIMPLE_CTOR_DTOR(StopGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -22,35 +22,32 @@ class ImageDataOp final : public Operator<Context> { ...@@ -22,35 +22,32 @@ class ImageDataOp final : public Operator<Context> {
public: public:
ImageDataOp(const OperatorDef& def, Workspace* ws) ImageDataOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
dtype(OperatorBase::Arg<string>("dtype", "float32")), mean_vec_(OpArgs<float>("mean_values")),
mean_values(OperatorBase::Args<float>("mean_values")), std_vec_(OpArgs<float>("std_values")) {
std_values(OperatorBase::Args<float>("std_values")), if (mean_vec_.size() > 0) {
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) { CHECK_EQ((int)mean_vec_.size(), 3);
if (mean_values.size() > 0) { auto* mean = mean_.Reshape({ 3 })
CHECK_EQ((int)mean_values.size(), 3) ->mutable_data<float, CPUContext>();
<< "The mean values should be a list with length 3."; for (int i = 0; i < 3; ++i) mean[i] = mean_vec_[i];
mean.Reshape({ 3 });
for (int i = 0; i < 3; i++)
mean.mutable_data<float, CPUContext>()[i] = mean_values[i];
} }
if (std_values.size() > 0) { if (std_vec_.size() > 0) {
CHECK_EQ((int)std_values.size(), 3) CHECK_EQ((int)std_vec_.size(), 3);
<< "The std values should be a list with length 3."; auto* std = std_.Reshape({ 3 })
std.Reshape({ 3 }); ->mutable_data<float, CPUContext>();
for (int i = 0; i < 3; i++) for (int i = 0; i < 3; ++i) std[i] = std_vec_[i];
std.mutable_data<float, CPUContext>()[i] = std_values[i];
} }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType();
template <typename Tx, typename Ty>
void RunImpl();
protected: protected:
string dtype, data_format; Tensor mean_, std_;
vector<float> mean_values, std_values; int64_t n_, c_, h_, w_;
int64_t n, c, h, w; vector<float> mean_vec_, std_vec_;
Tensor mean, std;
}; };
} // namespace dragon } // namespace dragon
......
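ImageDataOp keeps 3-element mean_ and std_ tensors built from mean_values / std_values. Assuming the usual per-channel normalization (the real kernel also handles NCHW/NHWC layout and dtype casting), each pixel is transformed as:

    // Sketch of the per-channel transform implied by mean_ and std_.
    float Normalize(float x, int c, const float* mean, const float* std) {
      return (x - mean[c]) / std[c];
    }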
...@@ -23,19 +23,18 @@ class InitializeOp : public Operator<Context> { ...@@ -23,19 +23,18 @@ class InitializeOp : public Operator<Context> {
public: public:
InitializeOp(const OperatorDef& def, Workspace* ws) InitializeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
shape_desc(OperatorBase::Arg<string>("shape", "")), shape_desc_(OpArg<string>("shape", "")) {
dtype(OperatorBase::Arg<string>("dtype", "float32")) { GET_ARGS_WITH_DESC(int64_t, dims);
GET_ARGUMENTS_WITH_DESC(int64_t, dims);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
string shape_desc, dtype; string shape_desc_;
TensorFillerProto filler_proto; TensorFillerProto proto_;
DECLARE_ARGUMENTS_WITH_DESC(int64_t, dims); DECLARE_ARGS_WITH_DESC(int64_t, dims);
}; };
template <class Context> template <class Context>
...@@ -43,85 +42,92 @@ class FillOp final : public Operator<Context> { ...@@ -43,85 +42,92 @@ class FillOp final : public Operator<Context> {
public: public:
FillOp(const OperatorDef& def, Workspace* ws) FillOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
shape_desc(OperatorBase::Arg<string>("shape", "")), shape_desc_(OpArg<string>("shape", "")),
dtype(OperatorBase::Arg<string>("dtype", "float32")), value_(OpArg<float>("value", 0.f)) {
value(OperatorBase::Arg<float>("value", 0.f)) { GET_ARGS_WITH_DESC(int64_t, dims);
GET_ARGUMENTS_WITH_DESC(int64_t, dims);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
protected: protected:
float value; float value_;
string shape_desc, dtype; string shape_desc_;
DECLARE_ARGUMENTS_WITH_DESC(int64_t, dims); DECLARE_ARGS_WITH_DESC(int64_t, dims);
}; };
namespace {
template<typename T>
struct TypeIdentity { typedef T type; };
} // namespace
template <class Context> template <class Context>
class GivenTensorFillOp final : public Operator<Context> { class GivenTensorFillOp final : public Operator<Context> {
public: public:
GivenTensorFillOp(const OperatorDef& def, Workspace* ws) GivenTensorFillOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
shape(OperatorBase::Args<int64_t>("shape")), shape_(OpArgs<int64_t>("shape")) {
dtype(OperatorBase::Arg<string>("dtype", "float32")) { GET_ARGS_WITH_DESC(int64_t, dims);
GET_ARGUMENTS_WITH_DESC(int64_t, dims);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
template<typename T>
struct TypeIdentity { typedef T type; };
template <typename T> template <typename T>
void ExtractValues() { ExtractValuesImpl(TypeIdentity<T>()); } void Extract() { ExtractImpl(TypeIdentity<T>()); }
template <typename T> void ExtractValuesImpl(TypeIdentity<T>) { template <typename T> void ExtractImpl(TypeIdentity<T>) {
auto source_values = OperatorBase::Args<T>("values"); auto raw_values = OpArgs<T>("values");
auto num_values = (int64_t)source_values.size(); auto nelements = (int64_t)raw_values.size();
values.Reshape(vector<int64_t>({ num_values })); auto nbytes = nelements * sizeof(T);
auto* Vdata = values.template mutable_data<T, CPUContext>(); auto* values = values_.Reshape({ nelements })
memcpy(Vdata, source_values.data(), num_values * sizeof(T)); ->template mutable_data<T, CPUContext>();
memcpy(values, raw_values.data(), nbytes);
} }
void ExtractValuesImpl(TypeIdentity<float16>) { void ExtractImpl(TypeIdentity<float16>) {
auto source_values = OperatorBase::Args<float>("values"); auto raw_values = OpArgs<float>("values");
auto num_values = (int64_t)source_values.size(); auto nelements = (int64_t)raw_values.size();
values.Reshape(vector<int64_t>({ num_values })); auto nbytes = nelements * sizeof(float16);
auto* Vdata = values.template mutable_data<float16, CPUContext>(); auto* values = values_.Reshape({ nelements })
memcpy(Vdata, source_values.data(), num_values * sizeof(float16)); ->template mutable_data<float16, CPUContext>();
memcpy(values, raw_values.data(), nbytes);
} }
protected: protected:
string dtype; Tensor values_;
vector<int64_t> shape; vector<int64_t> shape_;
Tensor values; DECLARE_ARGS_WITH_DESC(int64_t, dims);
DECLARE_ARGUMENTS_WITH_DESC(int64_t, dims);
}; };
template <class Context> template <class Context>
class RandomUniformOp final : public InitializeOp<Context> { class RandomUniformOp final : public InitializeOp<Context> {
public: public:
RandomUniformOp(const OperatorDef& def, Workspace* ws) RandomUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
this->filler_proto.set_type("uniform"); auto low = OpArg<float>("low", -1.f);
this->filler_proto.set_low(OperatorBase::Arg<float>("low", -1.f)); auto high = OpArg<float>("high", 1.f);
this->filler_proto.set_high(OperatorBase::Arg<float>("high", 1.f)); this->proto_.set_low(low);
this->proto_.set_high(high);
this->proto_.set_type("uniform");
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
class RandomNormalOp final : public InitializeOp<Context> { class RandomNormalOp final : public InitializeOp<Context> {
public: public:
RandomNormalOp(const OperatorDef& def, Workspace* ws) RandomNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
this->filler_proto.set_type("normal"); auto mu = OpArg<float>("mean", 0.f);
this->filler_proto.set_mean(OperatorBase::Arg<float>("mean", 0.f)); auto sigma = OpArg<float>("std", 1.f);
this->filler_proto.set_std(OperatorBase::Arg<float>("std", 1.f)); this->proto_.set_mean(mu);
this->proto_.set_std(sigma);
this->proto_.set_type("normal");
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
...@@ -131,66 +137,66 @@ class TruncatedNormalOp final : public InitializeOp<Context> { ...@@ -131,66 +137,66 @@ class TruncatedNormalOp final : public InitializeOp<Context> {
public: public:
TruncatedNormalOp(const OperatorDef& def, Workspace* ws) TruncatedNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
this->filler_proto.set_type("truncated_normal"); auto mu = OpArg<float>("mean", 0.f);
float mu = OperatorBase::Arg<float>("mean", 0.f); auto sigma = OpArg<float>("std", 1.f);
float sigma = OperatorBase::Arg<float>("std", 1.f); this->proto_.set_mean(mu);
this->filler_proto.set_mean(mu); this->proto_.set_std(sigma);
this->filler_proto.set_std(sigma); this->proto_.set_low(mu - 2 * sigma);
this->filler_proto.set_low(mu - 2 * sigma); this->proto_.set_high(mu + 2 * sigma);
this->filler_proto.set_high(mu + 2 * sigma); this->proto_.set_type("truncated_normal");
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
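TruncatedNormalOp above only fills the filler proto: the truncation bounds are derived from the mean and std as low = mu - 2 * sigma and high = mu + 2 * sigma. A standalone rejection-sampling sketch of that distribution (an illustration, not Dragon's filler implementation):

    // Standalone sketch: redraw from N(mu, sigma) until the sample lies in
    // [mu - 2 * sigma, mu + 2 * sigma].
    #include <random>

    float SampleTruncatedNormal(std::mt19937& rng, float mu, float sigma) {
      std::normal_distribution<float> dist(mu, sigma);
      const float low = mu - 2.f * sigma, high = mu + 2.f * sigma;
      float v = dist(rng);
      while (v < low || v > high) v = dist(rng);
      return v;
    }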
template <class Context> template <class Context>
class GlorotUniformOp final : public InitializeOp<Context> { class GlorotUniformOp final : public InitializeOp<Context> {
public: public:
GlorotUniformOp(const OperatorDef& def, Workspace* ws) GlorotUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
string mode = OperatorBase::Arg<string>("mode", "fan_in"); auto scale = OpArg<float>("scale", 3.f);
float scale = OperatorBase::Arg<float>("scale", 3.f); auto mode = OpArg<string>("mode", "fan_in");
this->filler_proto.set_type("xavier"); this->proto_.set_type("xavier");
if (mode == "fan_avg") { if (mode == "fan_avg") {
this->filler_proto.set_variance_norm( this->proto_.set_variance_norm(
TensorFillerProto_VarianceNorm_FAN_AVG); TensorFillerProto_VarianceNorm_FAN_AVG);
} else if (mode == "fan_out") { } else if (mode == "fan_out") {
this->filler_proto.set_variance_norm( this->proto_.set_variance_norm(
TensorFillerProto_VarianceNorm_FAN_OUT); TensorFillerProto_VarianceNorm_FAN_OUT);
} else { } else {
this->filler_proto.set_variance_norm( this->proto_.set_variance_norm(
TensorFillerProto_VarianceNorm_FAN_IN); TensorFillerProto_VarianceNorm_FAN_IN);
} }
this->filler_proto.set_scale(scale); this->proto_.set_scale(scale);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
class GlorotNormalOp final : public InitializeOp<Context> { class GlorotNormalOp final : public InitializeOp<Context> {
public: public:
GlorotNormalOp(const OperatorDef& def, Workspace* ws) GlorotNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) { : InitializeOp<Context>(def, ws) {
string mode = OperatorBase::Arg<string>("mode", "fan_in"); auto scale = OpArg<float>("scale", 2.f);
float scale = OperatorBase::Arg<float>("scale", 2.f); auto mode = OpArg<string>("mode", "fan_in");
this->filler_proto.set_type("msra"); this->proto_.set_type("msra");
if (mode == "fan_avg") { if (mode == "fan_avg") {
this->filler_proto.set_variance_norm( this->proto_.set_variance_norm(
TensorFillerProto_VarianceNorm_FAN_AVG); TensorFillerProto_VarianceNorm_FAN_AVG);
} else if (mode == "fan_out") { } else if (mode == "fan_out") {
this->filler_proto.set_variance_norm( this->proto_.set_variance_norm(
TensorFillerProto_VarianceNorm_FAN_OUT); TensorFillerProto_VarianceNorm_FAN_OUT);
} else { } else {
this->filler_proto.set_variance_norm( this->proto_.set_variance_norm(
TensorFillerProto_VarianceNorm_FAN_IN); TensorFillerProto_VarianceNorm_FAN_IN);
} }
this->filler_proto.set_scale(scale); this->proto_.set_scale(scale);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
DEFINE_ARGUMENTS_WITH_DESC(int64_t, InitializeOp, dims); DEFINE_ARGS_WITH_DESC(int64_t, InitializeOp, dims);
DEFINE_ARGUMENTS_WITH_DESC(int64_t, FillOp, dims); DEFINE_ARGS_WITH_DESC(int64_t, FillOp, dims);
DEFINE_ARGUMENTS_WITH_DESC(int64_t, GivenTensorFillOp, dims); DEFINE_ARGS_WITH_DESC(int64_t, GivenTensorFillOp, dims);
} // namespace dragon } // namespace dragon
......
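GivenTensorFillOp::Extract above uses the TypeIdentity tag to pick an overload at compile time, so float16 values (which arrive as float arguments) can take a separate path. A self-contained sketch of the same dispatch pattern, with a stand-in Float16 type rather than Dragon's:

    // Standalone sketch of TypeIdentity tag dispatch: the non-template
    // overload intercepts the special type, the template handles the rest.
    #include <cstdio>

    template <typename T>
    struct TypeIdentity { typedef T type; };

    struct Float16 { unsigned short bits; };  // stand-in half type

    template <typename T>
    void ExtractImpl(TypeIdentity<T>) { std::printf("generic path\n"); }

    void ExtractImpl(TypeIdentity<Float16>) { std::printf("float16 path\n"); }

    template <typename T>
    void Extract() { ExtractImpl(TypeIdentity<T>()); }

    int main() {
      Extract<float>();    // generic path
      Extract<Float16>();  // float16 path
      return 0;
    }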
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
#define DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
#ifdef WITH_MPI
#include <mpi.h>
#include "core/operator.h"
namespace dragon {
template <class Context>
class ModelMPIBase : public Operator<Context> {
public:
ModelMPIBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
comm((MPI_Comm)OperatorBase::Arg<int64_t>("comm", 0)),
group((MPI_Group)OperatorBase::Arg<int64_t>("group", 0)) {
if (comm == MPI_COMM_NULL) return;
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
MPI_Comm_size(comm, &comm_size);
MPI_Comm_rank(comm, &comm_rank);
MPI_Group world_group;
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
int world_root = OperatorBase::Arg<int64_t>("root", 0);
MPI_Group_translate_ranks(world_group, 1, &world_root, group, &comm_root);
CHECK(comm_root != MPI_UNDEFINED)
<< "\nMPI root is not included in layer group.";
}
template <typename T>
MPI_Datatype mpi_dtype() {
auto dtype = TypeMetaToString(TypeMeta::Make<T>());
if (dtype == "int8") return MPI_CHAR;
else if (dtype == "uint8") return MPI_UNSIGNED_CHAR;
else if (dtype == "int32") return MPI_INT;
else if (dtype == "int64") return MPI_LONG_LONG;
else if (dtype == "float16") return MPI_UNSIGNED_SHORT;
else if (dtype == "float32") return MPI_FLOAT;
else if (dtype == "float64") return MPI_DOUBLE;
return MPI_DATATYPE_NULL;
}
public:
MPI_Comm comm;
MPI_Group group;
int comm_size, comm_rank, comm_root;
int world_size, world_rank;
};
#define USE_MODEL_MPI_FUNCTIONS \
using ModelMPIBase<Context>::comm; \
using ModelMPIBase<Context>::comm_size; \
using ModelMPIBase<Context>::comm_rank; \
using ModelMPIBase<Context>::comm_root;
} // namespace dragon
#endif // WITH_MPI
#endif // DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
\ No newline at end of file
...@@ -15,32 +15,32 @@ ...@@ -15,32 +15,32 @@
#ifdef WITH_MPI #ifdef WITH_MPI
#include "operators/mpi/base_mpi_op.h" #include "operators/mpi/mpi_op_base.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class MPIBroadcastOp final : public ModelMPIBase<Context> { class MPIBroadcastOp final : public MPIOpBase<Context> {
public: public:
MPIBroadcastOp(const OperatorDef& def, Workspace* ws) MPIBroadcastOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(def, ws) {} : MPIOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MODEL_MPI_FUNCTIONS; USE_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
template <class Context> template <class Context>
class MPIBroadcastGradientOp final : public ModelMPIBase<Context> { class MPIBroadcastGradientOp final : public MPIOpBase<Context> {
public: public:
MPIBroadcastGradientOp(const OperatorDef& def, Workspace* ws) MPIBroadcastGradientOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(def, ws) {} : MPIOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MODEL_MPI_FUNCTIONS; USE_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
} // namespace dragon } // namespace dragon
......
...@@ -15,32 +15,32 @@ ...@@ -15,32 +15,32 @@
#ifdef WITH_MPI #ifdef WITH_MPI
#include "operators/mpi/base_mpi_op.h" #include "operators/mpi/mpi_op_base.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class MPIGatherOp final : public ModelMPIBase<Context> { class MPIGatherOp final : public MPIOpBase<Context> {
public: public:
MPIGatherOp(const OperatorDef& def, Workspace *ws) MPIGatherOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(def, ws) {} : MPIOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MODEL_MPI_FUNCTIONS; USE_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
template <class Context> template <class Context>
class MPIGatherGradientOp final : public ModelMPIBase<Context> { class MPIGatherGradientOp final : public MPIOpBase<Context> {
public: public:
MPIGatherGradientOp(const OperatorDef& def, Workspace *ws) MPIGatherGradientOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(def, ws) {} : MPIOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MODEL_MPI_FUNCTIONS; USE_MPI_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunImpl();
}; };
} // namespace dragon } // namespace dragon
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_MPI_MPI_OP_BASE_H_
#define DRAGON_OPERATORS_MPI_MPI_OP_BASE_H_
#ifdef WITH_MPI
#include <mpi.h>
#include "core/operator.h"
namespace dragon {
template <class Context>
class MPIOpBase : public Operator<Context> {
public:
MPIOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
comm_((MPI_Comm)OpArg<int64_t>("comm", 0)),
group_((MPI_Group)OpArg<int64_t>("group", 0)) {
if (comm_ == MPI_COMM_NULL) return;
MPI_Comm_size(MPI_COMM_WORLD, &world_size_);
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank_);
MPI_Comm_size(comm_, &comm_size_);
MPI_Comm_rank(comm_, &comm_rank_);
MPI_Group world_group;
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
int world_root = OpArg<int64_t>("root", 0);
MPI_Group_translate_ranks(
world_group, 1, &world_root,
group_, &comm_root_
);
CHECK(comm_root_ != MPI_UNDEFINED)
<< "\nRoot is not included in the group.";
}
template <typename T>
MPI_Datatype mpi_dtype() {
auto type_str = TypeMetaToString(TypeMeta::Make<T>());
if (type_str == "int8") return MPI_CHAR;
else if (type_str == "uint8") return MPI_UNSIGNED_CHAR;
else if (type_str == "int32") return MPI_INT;
else if (type_str == "int64") return MPI_LONG_LONG;
else if (type_str == "float16") return MPI_UNSIGNED_SHORT;
else if (type_str == "float32") return MPI_FLOAT;
else if (type_str == "float64") return MPI_DOUBLE;
return MPI_DATATYPE_NULL;
}
template <typename T>
void Recv(T* buf, int count, int from) {
MPI_Recv(
buf, count,
mpi_dtype<T>(),
from, 0, comm_,
MPI_STATUS_IGNORE
);
}
template <typename T>
void IRecv(
T* buf,
int count,
int from,
MPI_Request* req) {
MPI_Irecv(
buf, count,
mpi_dtype<T>(),
from, 0, comm_, req
);
}
template <typename T>
void Send(const T* buf, int count, int to) {
MPI_Send(
buf, count,
mpi_dtype<T>(),
to, 0, comm_
);
}
template <typename T>
void SendRecv(
const T* send_buf,
int send_count,
int to,
T* recv_buf,
int recv_count,
int from) {
MPI_Sendrecv(
send_buf,
send_count,
mpi_dtype<T>(), to, 0,
recv_buf, recv_count,
mpi_dtype<T>(), from, 0,
comm_,
MPI_STATUS_IGNORE
);
}
template <typename T>
void BCast(T* buf, int count) {
MPI_Bcast(
buf, count,
mpi_dtype<T>(),
comm_root_, comm_
);
}
public:
MPI_Comm comm_;
MPI_Group group_;
int world_size_, world_rank_;
int comm_size_, comm_rank_, comm_root_;
};
#define USE_MPI_FUNCTIONS \
using MPIOpBase<Context>::Recv; \
using MPIOpBase<Context>::IRecv; \
using MPIOpBase<Context>::Send; \
using MPIOpBase<Context>::SendRecv; \
using MPIOpBase<Context>::BCast; \
using MPIOpBase<Context>::comm_; \
using MPIOpBase<Context>::comm_size_; \
using MPIOpBase<Context>::comm_rank_; \
using MPIOpBase<Context>::comm_root_;
} // namespace dragon
#endif // WITH_MPI
#endif // DRAGON_OPERATORS_MPI_MPI_OP_BASE_H_
\ No newline at end of file
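The MPIOpBase helpers above are thin wrappers over the corresponding MPI calls on the op's communicator. A minimal standalone MPI program (plain MPI, independent of Dragon) exercising the same calls:

    // Standalone sketch: rank 0 broadcasts a buffer, then rank 1 sends a
    // value back to rank 0 point-to-point, mirroring BCast/Send/Recv.
    #include <mpi.h>
    #include <cstdio>
    #include <vector>

    int main(int argc, char** argv) {
      MPI_Init(&argc, &argv);
      int rank, size;
      MPI_Comm_rank(MPI_COMM_WORLD, &rank);
      MPI_Comm_size(MPI_COMM_WORLD, &size);

      std::vector<float> buf(4, rank == 0 ? 1.f : 0.f);
      // Equivalent of BCast(buf.data(), 4) with root 0.
      MPI_Bcast(buf.data(), 4, MPI_FLOAT, 0, MPI_COMM_WORLD);

      if (size > 1) {
        float v = 3.14f;
        if (rank == 1) {
          // Equivalent of Send(&v, 1, /* to = */ 0).
          MPI_Send(&v, 1, MPI_FLOAT, 0, 0, MPI_COMM_WORLD);
        } else if (rank == 0) {
          // Equivalent of Recv(&v, 1, /* from = */ 1).
          MPI_Recv(&v, 1, MPI_FLOAT, 1, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
          std::printf("rank 0 received %.2f\n", v);
        }
      }
      MPI_Finalize();
      return 0;
    }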