Commit 5dea1524 by Ting PAN

Refactor Context/Stream & Add CTC Loss

1 parent 081258b9
Showing with 1399 additions and 1045 deletions
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#ifndef DRAGON_CORE_COMMON_H_ #ifndef DRAGON_CORE_COMMON_H_
#define DRAGON_CORE_COMMON_H_ #define DRAGON_CORE_COMMON_H_
#include <ctime>
#include <climits> #include <climits>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -48,7 +49,26 @@ using Map = std::unordered_map<Key, Value>; ...@@ -48,7 +49,26 @@ using Map = std::unordered_map<Key, Value>;
template <typename Value> template <typename Value>
using Set = std::unordered_set<Value> ; using Set = std::unordered_set<Value> ;
#define DRAGON_VERSION 2205 /*
* Define the Kernel version.
*
* | Major(2) | Minor(2) | Patch(06) |
*/
#define DRAGON_VERSION 2206
/*
* Define the default random seed.
*/
#define DEFAULT_RNG_SEED 3
/*
* Define the common marcos.
*/
#ifdef _MSC_VER
#if _MSC_VER < 1900
#define thread_local __declspec(thread)
#endif
#endif
#define CONCATENATE_IMPL(s1, s2) s1##s2 #define CONCATENATE_IMPL(s1, s2) s1##s2
#define CONCATENATE(s1, s2) CONCATENATE_IMPL(s1,s2) #define CONCATENATE(s1, s2) CONCATENATE_IMPL(s1,s2)
......
...@@ -23,21 +23,18 @@ ...@@ -23,21 +23,18 @@
namespace dragon { namespace dragon {
class CPUObject {
public:
unique_ptr<std::mt19937> rand_generator;
};
class CPUContext { class CPUContext {
public: public:
CPUContext(): random_seed_(3) { generator(); } CPUContext(): random_seed_(3) {}
CPUContext(unsigned int random_seed): random_seed_(random_seed) { generator(); } CPUContext(unsigned int random_seed)
CPUContext(const DeviceOption& option): random_seed_(option.has_random_seed() ? : random_seed_(random_seed) {}
option.random_seed() : 3) { generator(); } CPUContext(const DeviceOption& option)
: random_seed_(option.has_random_seed() ?
option.random_seed() : DEFAULT_RNG_SEED) {}
virtual ~CPUContext() {} virtual ~CPUContext() {}
inline void SwitchToDevice() {} inline void SwitchToDevice() {}
inline void static FinishDeviceCompution() { return; } inline void FinishDeviceCompution() {}
inline static void* New(size_t nbytes) { inline static void* New(size_t nbytes) {
void* data; void* data;
...@@ -50,40 +47,54 @@ class CPUContext { ...@@ -50,40 +47,54 @@ class CPUContext {
return data; return data;
} }
inline static void Memset(size_t nbytes, void* ptr) { memset(ptr, 0, nbytes); } inline static void Memset(size_t nbytes, void* ptr) {
memset(ptr, 0, nbytes);
}
template<class DstContext, class SrcContext> template<class DstContext, class SrcContext>
inline static void Memcpy(size_t nbytes, void* dst, const void* src) { memcpy(dst, src, nbytes); } inline static void Memcpy(
size_t nbytes,
void* dst,
const void* src) {
memcpy(dst, src, nbytes);
}
inline static void Delete(void* data) { free(data); } inline static void Delete(void* data) { free(data); }
template<class DstContext, class SrcContext> template<class DstContext, class SrcContext>
inline static void MemcpyAsync(size_t nbytes, void* dst, const void* src) { NOT_IMPLEMENTED; } inline static void MemcpyAsync(
size_t nbytes,
void* dst,
const void* src) {
NOT_IMPLEMENTED;
}
template<typename T, class DstContext, class SrcContext> template<typename T, class DstContext, class SrcContext>
inline static void Copy(int n, T* dst, const T* src) { inline static void Copy(
int n,
T* dst,
const T* src) {
if (dst == src) return; if (dst == src) return;
// only the basic types(e.g. int/float) can memcpy correctly // only the basic types(e.g. int/float) can memcpy correctly
if (std::is_fundamental<T>::value) if (std::is_fundamental<T>::value)
Memcpy<DstContext, SrcContext>(n * sizeof(T), (void*)dst, (const void*)src); Memcpy<DstContext, SrcContext>(
n * sizeof(T), (void*)dst, (const void*)src);
else for (int i = 0; i < n; i++) dst[i] = src[i]; else for (int i = 0; i < n; i++) dst[i] = src[i];
} }
inline std::mt19937* generator() { inline int device_id() const { return 0; }
auto& generator = cpu_object_.rand_generator;
if (!generator.get())
generator.reset(new std::mt19937(random_seed_));
return generator.get();
}
static CPUObject cpu_object_; inline std::mt19937* rand_generator() {
if (!rand_generator_.get())
rand_generator_.reset(new std::mt19937(random_seed_));
return rand_generator_.get();
}
private: private:
unsigned int random_seed_; unsigned int random_seed_;
unique_ptr<std::mt19937> rand_generator_;
}; };
static inline std::mt19937* rand_generator() {
return CPUContext::cpu_object_.rand_generator.get();
}
#define CPU_FP16_NOT_SUPPORTED \ #define CPU_FP16_NOT_SUPPORTED \
LOG(FATAL) << "FP16 is unsupported for CPUContext."; LOG(FATAL) << "FP16 is unsupported for CPUContext.";
......
...@@ -26,11 +26,18 @@ class GraphBase { ...@@ -26,11 +26,18 @@ class GraphBase {
string op_type; string op_type;
}; };
GraphBase(const GraphDef& meta_graph, Workspace* ws); GraphBase(
const GraphDef& meta_graph,
Workspace* ws);
virtual ~GraphBase() {} virtual ~GraphBase() {}
virtual bool Create(const GraphDef& optimized_graph, Workspace* ws) = 0; virtual bool Create(
virtual bool Run(const string& include, const string& exclude) = 0; const GraphDef& optimized_graph,
Workspace* ws) = 0;
virtual bool Run(
const string& include,
const string& exclude) = 0;
inline string name() const { return name_; } inline string name() const { return name_; }
...@@ -45,21 +52,31 @@ class Graph final : public GraphBase { ...@@ -45,21 +52,31 @@ class Graph final : public GraphBase {
Graph(const GraphDef& meta_graph, Workspace* ws); Graph(const GraphDef& meta_graph, Workspace* ws);
~Graph() { for (auto* op : ops_) delete op; } ~Graph() { for (auto* op : ops_) delete op; }
bool Create(const GraphDef& optimized_graph, Workspace* ws) override; bool Create(
bool Run(const string& include, const string& exclude) override; const GraphDef& optimized_graph,
Workspace* ws) override;
bool Run(
const string& include,
const string& exclude) override;
GraphDef Prune(const GraphDef& meta_graph); GraphDef Prune(const GraphDef& meta_graph);
GraphDef MakeUpdate(const GraphDef& meta_graph); GraphDef MakeUpdate(const GraphDef& meta_graph);
GraphDef Share(const GraphDef& optimized_graph); GraphDef Share(const GraphDef& optimized_graph);
void ShareGrads(GraphDef& optimized_graph); void ShareGrads(GraphDef& optimized_graph);
void RecomputingAware(const GraphDef& optimized_graph, Workspace* ws); void RecomputingAware(
const GraphDef& optimized_graph,
Workspace* ws);
inline Workspace* ws() const { return ws_; } inline Workspace* ws() const { return ws_; }
private: private:
void ForwardShareDyeing(string u, string ancestor); void ForwardShareDyeing(string u, string ancestor);
void ForwardPruneDyeing(string u, string leaf, vector<string> path); void ForwardPruneDyeing(
string u,
string leaf,
vector<string> path);
void BackwardPruneDyeing(string v); void BackwardPruneDyeing(string v);
vector<OperatorBase*> ops_; vector<OperatorBase*> ops_;
...@@ -69,8 +86,15 @@ class Graph final : public GraphBase { ...@@ -69,8 +86,15 @@ class Graph final : public GraphBase {
Set<string> targets_; Set<string> targets_;
}; };
GraphBase* NewGraph(const GraphDef& meta_graph, Workspace* ws); GraphBase* NewGraph(
DECLARE_REGISTRY(GraphRegistry, GraphBase, const GraphDef&, Workspace*); const GraphDef& meta_graph,
Workspace* ws);
DECLARE_REGISTRY(
GraphRegistry,
GraphBase,
const GraphDef&,
Workspace*);
} // namespace dragon } // namespace dragon
......
...@@ -18,24 +18,32 @@ namespace dragon { ...@@ -18,24 +18,32 @@ namespace dragon {
class GraphGradientMaker { class GraphGradientMaker {
public: public:
GraphGradientMaker() : cur_op_idx_(0) {} GraphGradientMaker(): cur_op_idx_(0) {}
void Make(const GraphDef& forward_def, void Make(
const GraphDef& forward_def,
const vector<string>& targets, const vector<string>& targets,
GraphDef& new_def); GraphDef& new_def);
void Share(const string& grads_prefix, GraphDef& graph); void Share(const string& grads_prefix, GraphDef& graph);
inline void SetTerms(const Map<string, string>& terms) { terms_ = terms; } inline void SetTerms(
inline void SetOperatorPrefix(const string& prefix) { op_prefix_ = prefix; } const Map<string, string>& terms) { terms_ = terms; }
inline void SetOperatorSuffix(const string& suffix) { op_suffix_ = suffix; } inline void SetOperatorPrefix(
inline void AddExternalGrad(const string& name) { external_grads_.insert(name); } const string& prefix) { op_prefix_ = prefix; }
inline void AddIgnoreGrad(const string& name) { ignore_grads_.insert(name); } inline void SetOperatorSuffix(
const string& suffix) { op_suffix_ = suffix; }
inline void AddExternalGrad(
const string& name) { external_grads_.insert(name); }
inline void AddIgnoreGrad(
const string& name) { ignore_grads_.insert(name); }
private: private:
bool CheckGrad(const OperatorDef& forward_op, bool CheckGrad(
const OperatorDef& forward_op,
const Set<string>& targets, const Set<string>& targets,
vector< pair<string, int> >& gen_grads); vector< pair<string, int> >& gen_grads);
string GetOperatorName(); string GetOperatorName();
Map<string, string> terms_, inputs_to_grads_; Map<string, string> terms_, inputs_to_grads_;
......
...@@ -19,7 +19,13 @@ namespace dragon { ...@@ -19,7 +19,13 @@ namespace dragon {
class MixedMemory { class MixedMemory {
public: public:
enum State { UNINITIALIZED, STATE_AT_CPU, STATE_AT_CUDA, SWITCHED, SYNCED }; enum State {
UNINITIALIZED,
STATE_AT_CPU,
STATE_AT_CUDA,
SWITCHED,
SYNCED };
MixedMemory() : cpu_ptr_(nullptr), cuda_ptr_(nullptr) {} MixedMemory() : cpu_ptr_(nullptr), cuda_ptr_(nullptr) {}
MixedMemory(const TypeMeta& meta, const size_t nbytes) MixedMemory(const TypeMeta& meta, const size_t nbytes)
: meta_(meta), nbytes_(nbytes), : meta_(meta), nbytes_(nbytes),
...@@ -31,9 +37,6 @@ class MixedMemory { ...@@ -31,9 +37,6 @@ class MixedMemory {
void* mutable_cpu_data(); void* mutable_cpu_data();
void* mutable_cuda_data(); void* mutable_cuda_data();
void set_cpu_data(void* cpu_ptr, size_t nbytes); void set_cpu_data(void* cpu_ptr, size_t nbytes);
#ifdef WITH_CUDA
void async_cuda_data(const cudaStream_t& stream);
#endif
void SwitchToDevice(); void SwitchToDevice();
void SwitchToCUDADevice(int device_id); void SwitchToCUDADevice(int device_id);
......
...@@ -29,7 +29,7 @@ class Workspace; ...@@ -29,7 +29,7 @@ class Workspace;
class OperatorBase { class OperatorBase {
public: public:
OperatorBase(const OperatorDef& op_def, Workspace* ws); OperatorBase(const OperatorDef& def, Workspace* ws);
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
Tensor& Input(int idx); Tensor& Input(int idx);
...@@ -38,66 +38,73 @@ class OperatorBase { ...@@ -38,66 +38,73 @@ class OperatorBase {
inline size_t InputSize() { return inputs_.size(); } inline size_t InputSize() { return inputs_.size(); }
inline size_t OutputSize() { return outputs_.size(); } inline size_t OutputSize() { return outputs_.size(); }
void MutableOp(const OperatorDef& op_def); void MutableOp(const OperatorDef& def);
void MutableOp(const vector<string>& inputs, void MutableOp(const vector<string>& inputs,
const vector<string>& outputs, const vector<string>& outputs,
const string& anchor); const string& anchor);
inline void SwitchToPhase(const string& phase) { this->phase_ = phase; } inline void SwitchToPhase(const string& phase) { phase_ = phase; }
virtual void Run() { NOT_IMPLEMENTED; } virtual void Run() { NOT_IMPLEMENTED; }
inline const string& name() const { return op_def_.name(); } inline const string& name() const { return def_.name(); }
inline const string& type() const { return op_def_.type(); } inline const string& type() const { return def_.type(); }
inline const string& phase() const { return phase_; } inline const string& phase() const { return phase_; }
inline const string& anchor() { return anchor_; } inline const string& anchor() { return anchor_; }
inline Workspace* ws() const { return ws_; } inline Workspace* ws() const { return ws_; }
template <typename T> template <typename T>
T GetSingleArg(const string& name, const T& default_value); T Arg(const string& name, const T& default_value);
template <typename T> template <typename T>
vector<T> GetRepeatedArg(const string& name); vector<T> Args(const string& name);
inline const Map<std::string, const Argument*>& args() { return args_; } inline const Map<std::string, const Argument*>& args() { return args_; }
inline const Argument& arg(const string& name) { return *(args_[name]); } inline const Argument& arg(const string& name) { return *(args_[name]); }
typedef Map<string, vector<OperatorBase*> > RecomputeMap; typedef Map<string, vector<OperatorBase*> > RecomputeMap;
inline RecomputeMap& recompute_map() { return recompute_map_; } inline RecomputeMap& recompute_map() { return recompute_map_; }
void set_recompute_map(RecomputeMap recompute_map) { recompute_map_ = recompute_map; } void set_recompute_map(RecomputeMap recompute_map) {
recompute_map_ = recompute_map;
}
inline const OperatorDef& op_def() const { return op_def_; } inline const OperatorDef& def() const { return def_; }
inline string DebugString() const { return op_def_.DebugString(); } inline string DebugString() const { return def_.DebugString(); }
string DTypeHelper(const Tensor& tensor, const Set<string>& dtypes) const; string DTypeHelper(
string DTypeHelper(const string& dtype, const Set<string>& dtypes) const; const Tensor& tensor,
const Set<string>& dtypes) const;
string DTypeHelper(
const string& dtype,
const Set<string>& dtypes) const;
protected: protected:
string phase_, anchor_; string phase_, anchor_;
Map<std::string, const Argument*> args_; Map<std::string, const Argument*> args_;
Map<string, vector<OperatorBase*> > recompute_map_; Map<string, vector<OperatorBase*> > recompute_map_;
vector<Tensor*> inputs_, outputs_; vector<Tensor*> inputs_, outputs_;
OperatorDef op_def_; OperatorDef def_;
Workspace* ws_; Workspace* ws_;
}; };
template <class Context> template <class Context>
class Operator : public OperatorBase { class Operator : public OperatorBase {
public: public:
Operator(const OperatorDef& op_def, Workspace* ws) Operator(const OperatorDef& def, Workspace* ws)
: OperatorBase(op_def, ws), ctx_(op_def.device_option()), : OperatorBase(def, ws), ctx_(def.device_option()),
do_synchronize_(Operator::GetSingleArg<bool>("do_synchronize", false)), recomputing_aware_(OperatorBase::Arg<bool>(
recomputing_aware_(Operator::GetSingleArg<bool>("recomputing_aware", false)) { "recomputing_aware", false)) {
allow_run_ = true; allow_run_ = true;
allow_run_ &= _MPICheck(); allow_run_ &= _MPICheck();
allow_run_ &= (!(OutputSize() == 1 && Output(0)->name() == "ignore")); allow_run_ &= (!(OutputSize() == 1 &&
Output(0)->name() == "ignore"));
} }
virtual void Run() final { virtual void Run() final {
if (!allow_run_) return; if (!allow_run_) return;
if (recomputing_aware_) MakeResource(); if (recomputing_aware_) MakeResource();
ctx_.SwitchToDevice(); ctx().SwitchToDevice();
MemorySwitch(); MemorySwitch();
RunOnDevice(); RunOnDevice();
if (do_synchronize_) ctx_.FinishDeviceCompution(); ctx().FinishDeviceCompution();
if (recomputing_aware_) CleanResource(); if (recomputing_aware_) CleanResource();
} }
...@@ -106,8 +113,10 @@ class Operator : public OperatorBase { ...@@ -106,8 +113,10 @@ class Operator : public OperatorBase {
virtual void CleanResource(); virtual void CleanResource();
void MemorySwitch() { void MemorySwitch() {
for (auto* I : inputs_) if(I->name() != "ignore") I->SwitchToDevice(); for (auto* I : inputs_)
for (auto* O : outputs_) if(O->name() != "ignore") O->SwitchToDevice(); if(I->name() != "ignore") I->SwitchToDevice();
for (auto* O : outputs_)
if(O->name() != "ignore") O->SwitchToDevice();
} }
virtual void RunOnDevice() = 0; virtual void RunOnDevice() = 0;
...@@ -117,14 +126,15 @@ class Operator : public OperatorBase { ...@@ -117,14 +126,15 @@ class Operator : public OperatorBase {
protected: protected:
Context ctx_; Context ctx_;
bool allow_run_, recomputing_aware_, do_synchronize_; bool allow_run_, recomputing_aware_;
private: private:
bool _MPICheck() { bool _MPICheck() {
#ifndef WITH_MPI #ifndef WITH_MPI
return true; return true;
#else #else
vector<int> allow_ranks = Operator::GetRepeatedArg<int>("mpi_ranks"); vector<int> allow_ranks =
OperatorBase::Args<int>("mpi_ranks");
if (allow_ranks.empty()) return true; if (allow_ranks.empty()) return true;
int cur_rank; int cur_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &cur_rank); MPI_Comm_rank(MPI_COMM_WORLD, &cur_rank);
...@@ -135,11 +145,11 @@ class Operator : public OperatorBase { ...@@ -135,11 +145,11 @@ class Operator : public OperatorBase {
} }
}; };
OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws); OperatorBase* CreateOperator(const OperatorDef& def, Workspace* ws);
#define USE_SIMPLE_CTOR_DTOR(name) \ #define USE_SIMPLE_CTOR_DTOR(name) \
name(const OperatorDef& op_def, Workspace* ws) \ name(const OperatorDef& def, Workspace* ws) \
: Operator<Context>(op_def, ws) {} \ : Operator<Context>(def, ws) {} \
virtual ~name() {} virtual ~name() {}
#define USE_OPERATOR_BASE_FUNCTIONS \ #define USE_OPERATOR_BASE_FUNCTIONS \
...@@ -150,7 +160,7 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws); ...@@ -150,7 +160,7 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
using OperatorBase::type; \ using OperatorBase::type; \
using OperatorBase::phase; \ using OperatorBase::phase; \
using OperatorBase::anchor; \ using OperatorBase::anchor; \
using OperatorBase::op_def; \ using OperatorBase::def; \
using OperatorBase::InputSize; \ using OperatorBase::InputSize; \
using OperatorBase::OutputSize; \ using OperatorBase::OutputSize; \
using OperatorBase::DebugString; \ using OperatorBase::DebugString; \
...@@ -162,9 +172,23 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws); ...@@ -162,9 +172,23 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
using Operator<Context>::ctx; \ using Operator<Context>::ctx; \
using Operator<Context>::AllowRun using Operator<Context>::AllowRun
DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*); DECLARE_REGISTRY(
DECLARE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*); CPUOperatorRegistry,
DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*); OperatorBase,
const OperatorDef&,
Workspace*);
DECLARE_REGISTRY(
CUDAOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
DECLARE_REGISTRY(
CUDNNOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
#define TENSOR_FILL(tensor, shape) \ #define TENSOR_FILL(tensor, shape) \
if (tensor.count() == 0) { \ if (tensor.count() == 0) { \
...@@ -174,7 +198,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp ...@@ -174,7 +198,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
tensor.Reshape(shape); \ tensor.Reshape(shape); \
unique_ptr< Filler<T, Context> > filler( \ unique_ptr< Filler<T, Context> > filler( \
CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \ CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \
filler->Fill(&tensor); \ filler->Fill(&tensor, &ctx()); \
} else { \ } else { \
TIndex count = 1; \ TIndex count = 1; \
for(int i = 0; i < shape.size(); i++) count *= shape[i]; \ for(int i = 0; i < shape.size(); i++) count *= shape[i]; \
...@@ -189,7 +213,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp ...@@ -189,7 +213,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
#define INIT_MULTIPLIER(ptr_tensor, size) { \ #define INIT_MULTIPLIER(ptr_tensor, size) { \
ptr_tensor = ws()->CreateTensor("/share/multiplier"); \ ptr_tensor = ws()->CreateTensor("/share/multiplier"); \
if (size > ptr_tensor->count()) { \ if (size > ptr_tensor->count()) { \
ptr_tensor->Reshape(vector<TIndex>(1, size)); \ ptr_tensor->Reshape({ size }); \
math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \ math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \
ptr_tensor->template mutable_data<T, Context>()); \ ptr_tensor->template mutable_data<T, Context>()); \
} \ } \
...@@ -214,12 +238,12 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp ...@@ -214,12 +238,12 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
type argument(int idx) type argument(int idx)
#define GET_ARGUMENT_WITH_DESC(type, argument, default_value) \ #define GET_ARGUMENT_WITH_DESC(type, argument, default_value) \
argument##_value = OperatorBase::GetSingleArg<type>(#argument, default_value); \ argument##_value = OperatorBase::Arg<type>(#argument, default_value); \
argument##_desc = OperatorBase::GetSingleArg<string>(string(#argument) + "_desc", "") argument##_desc = OperatorBase::Arg<string>(string(#argument) + "_desc", "")
#define GET_ARGUMENTS_WITH_DESC(type, argument) \ #define GET_ARGUMENTS_WITH_DESC(type, argument) \
argument##_value = OperatorBase::GetRepeatedArg<type>(#argument); \ argument##_value = OperatorBase::Args<type>(#argument); \
argument##_desc = OperatorBase::GetRepeatedArg<string>(string(#argument) + "_desc") argument##_desc = OperatorBase::Args<string>(string(#argument) + "_desc")
#define DEFINE_ARGUMENT_WITH_DESC(type, classname, argument) \ #define DEFINE_ARGUMENT_WITH_DESC(type, classname, argument) \
template <class Context> \ template <class Context> \
......
...@@ -23,7 +23,8 @@ struct Gradient { ...@@ -23,7 +23,8 @@ struct Gradient {
vector<OperatorDef> ops; vector<OperatorDef> ops;
vector<string> g_inputs; vector<string> g_inputs;
vector<float> defaults; vector<float> defaults;
Gradient(const vector<OperatorDef>& ops, Gradient(
const vector<OperatorDef>& ops,
const vector<string>& g_inputs, const vector<string>& g_inputs,
const vector<float>& defaults) const vector<float>& defaults)
: ops(ops), g_inputs(g_inputs), defaults(defaults) {} : ops(ops), g_inputs(g_inputs), defaults(defaults) {}
...@@ -31,9 +32,11 @@ struct Gradient { ...@@ -31,9 +32,11 @@ struct Gradient {
class GradientMakerBase { class GradientMakerBase {
public: public:
GradientMakerBase(const OperatorDef& def, GradientMakerBase(
const OperatorDef& def,
const vector<string>& g_outputs) const vector<string>& g_outputs)
: def(def), g_outputs_(g_outputs), g_inputs_(def.input_size()) {} : def(def), g_outputs_(g_outputs),
g_inputs_(def.input_size()) {}
virtual ~GradientMakerBase() {} virtual ~GradientMakerBase() {}
inline virtual bool CopyDeviceOption() const { return true; } inline virtual bool CopyDeviceOption() const { return true; }
...@@ -80,7 +83,9 @@ class GradientMakerBase { ...@@ -80,7 +83,9 @@ class GradientMakerBase {
}; };
// implemented in operator.cc // implemented in operator.cc
Gradient MakeGradientForOp(const OperatorDef& op_def, const vector<string>& g_outputs); Gradient MakeGradientForOp(
const OperatorDef& op_def,
const vector<string>& g_outputs);
# define GRADIENT_MAKER_CTOR(name) \ # define GRADIENT_MAKER_CTOR(name) \
name(const OperatorDef& def, const vector<string>& g_output) \ name(const OperatorDef& def, const vector<string>& g_output) \
...@@ -94,12 +99,14 @@ class NoGradient : public GradientMakerBase { ...@@ -94,12 +99,14 @@ class NoGradient : public GradientMakerBase {
} }
}; };
DECLARE_REGISTRY(GradientRegistry, DECLARE_REGISTRY(
GradientRegistry,
GradientMakerBase, GradientMakerBase,
const OperatorDef&, const OperatorDef&,
const vector<string>&); const vector<string>&);
DECLARE_REGISTRY(NoGradientRegistry, DECLARE_REGISTRY(
NoGradientRegistry,
GradientMakerBase, GradientMakerBase,
const OperatorDef&, const OperatorDef&,
const vector<string>&); const vector<string>&);
......
...@@ -22,13 +22,24 @@ namespace dragon { ...@@ -22,13 +22,24 @@ namespace dragon {
class OpSchema { class OpSchema {
public: public:
OpSchema() OpSchema()
: op_type_("unknown"), file_("unknown"), line_(0) { Init(); } : op_type_("unknown"), file_("unknown"), line_(0) {
OpSchema(const string& op_type, const string& file, const int line) Init();
: op_type_(op_type), file_(file), line_(line) { Init(); } }
OpSchema(
const string& op_type,
const string& file,
const int line)
: op_type_(op_type), file_(file), line_(line) {
Init();
}
bool Verify(const OperatorDef& def) const; bool Verify(const OperatorDef& def) const;
inline OpSchema& IgnoreVerify() { ignore_verify_ = true; return *this; } inline OpSchema& IgnoreVerify() {
ignore_verify_ = true;
return *this;
}
OpSchema& Inplace(set<pair<int, int> > inplace); OpSchema& Inplace(set<pair<int, int> > inplace);
std::function<bool(int, int)> CheckInplace; std::function<bool(int, int)> CheckInplace;
...@@ -56,20 +67,26 @@ class OpSchema { ...@@ -56,20 +67,26 @@ class OpSchema {
class OpSchemaRegistry { class OpSchemaRegistry {
public: public:
static OpSchema& NewSchema(const string& op_type, const string& file, const int line) { static OpSchema& NewSchema(
const string& op_type,
const string& file,
const int line) {
auto& m = schema_map(); auto& m = schema_map();
CHECK(!m.count(op_type)) CHECK(!m.count(op_type))
<< "\nOpSchema(" << op_type << ") has registered before." << "\nOpSchema(" << op_type
<< ") has registered before."
<< "\nat file: " << file << "\nat file: " << file
<< "\n line: " << line; << "\n line: " << line;
m.emplace(std::make_pair(op_type, OpSchema(op_type, file, line))); m.emplace(std::make_pair(op_type,
OpSchema(op_type, file, line)));
return m[op_type]; return m[op_type];
} }
static const OpSchema* Schema(const string& op_type) { static const OpSchema* Schema(const string& op_type) {
auto& m = schema_map(); auto& m = schema_map();
if (m.count(op_type)) return &m[op_type]; if (m.count(op_type)) return &m[op_type];
else LOG(FATAL) << "OpSchema(" << op_type << ") has not registered yet."; else LOG(FATAL) << "OpSchema(" << op_type
<< ") has not registered yet.";
return nullptr; return nullptr;
} }
......
...@@ -20,35 +20,50 @@ namespace dragon { ...@@ -20,35 +20,50 @@ namespace dragon {
template <class SrcType, class ObjType, class... Args> template <class SrcType, class ObjType, class... Args>
class Registry { class Registry {
public: public:
typedef std::function<ObjType*(Args ...)> Creator; typedef std::function<ObjType*(Args ...)> Creator;
void Register(const SrcType& key, Creator creator) { void Register(const SrcType& key, Creator creator) {
CHECK(!registry_.count(key)) << "\nKey(" << key << ") has already registered."; CHECK(!registry_.count(key))
<< "\nKey(" << key << ") has already registered.";
registry_[key] = creator; registry_[key] = creator;
} }
ObjType* Create(const SrcType& key, Args ... args) { ObjType* Create(const SrcType& key, Args ... args) {
CHECK(registry_.count(key)) << "\nKey(" << key << ") has not registered yet."; CHECK(registry_.count(key))
<< "\nKey(" << key << ") has not registered yet.";
return registry_[key](args...); return registry_[key](args...);
} }
bool Has(const SrcType& key) { return (registry_.count(key)) != 0; }
bool Has(const SrcType& key) {
return (registry_.count(key)) != 0;
}
vector<SrcType> keys() { vector<SrcType> keys() {
vector<SrcType> ret; vector<SrcType> ret;
for (const auto& it : registry_) ret.push_back(it.first); for (const auto& it : registry_)
ret.push_back(it.first);
return ret; return ret;
} }
private:
private:
Map<SrcType, Creator> registry_; Map<SrcType, Creator> registry_;
}; };
template <class SrcType, class ObjType, class... Args> template <class SrcType, class ObjType, class... Args>
class Registerer { class Registerer {
public: public:
Registerer(const SrcType& key, Registry<SrcType, ObjType, Args...>* registry, Registerer(
typename Registry<SrcType, ObjType, Args...>::Creator creator, const string& help_msg = "") { const SrcType& key,
Registry<SrcType, ObjType, Args...>* registry,
typename Registry<SrcType, ObjType, Args...>::Creator creator,
const string& help_msg = "") {
registry->Register(key, creator); registry->Register(key, creator);
} }
template <class DerivedType> template <class DerivedType>
static ObjType* defaultCreator(Args ... args) {return new DerivedType(args...);} static ObjType* defaultCreator(Args ... args) {
return new DerivedType(args...);
}
}; };
// use in *.h files // use in *.h files
......
...@@ -74,10 +74,15 @@ class Tensor { ...@@ -74,10 +74,15 @@ class Tensor {
return ret; return ret;
} }
inline TIndex count() const { return size_; } inline TIndex count() const { return size_; }
inline TIndex count(const TIndex start) const { return count(start, ndim()); } inline TIndex count(const TIndex start) const {
return count(start, ndim());
}
inline TIndex offset(const TIndex n, const TIndex c = 0, inline TIndex offset(
const TIndex h = 0, const TIndex w = 0) { const TIndex n,
const TIndex c = 0,
const TIndex h = 0,
const TIndex w = 0) {
CHECK_LE(n, dim(0)); CHECK_LE(n, dim(0));
CHECK_LE(c, dim(1)); CHECK_LE(c, dim(1));
CHECK_LE(h, dim(2)); CHECK_LE(h, dim(2));
...@@ -95,7 +100,7 @@ class Tensor { ...@@ -95,7 +100,7 @@ class Tensor {
return offset; return offset;
} }
inline string dim_string() const { inline string DimString() const {
if (ndim() == 0) return "(0,)"; if (ndim() == 0) return "(0,)";
std::stringstream ss; std::stringstream ss;
ss << "("; ss << "(";
...@@ -108,9 +113,18 @@ class Tensor { ...@@ -108,9 +113,18 @@ class Tensor {
inline bool is_corrupted() const { return is_corrupted_; } inline bool is_corrupted() const { return is_corrupted_; }
inline void Corrupt() { is_corrupted_ = true; } inline void Corrupt() { is_corrupted_ = true; }
inline bool has_memory() const { return memory_ || ex_memory_ != nullptr; } inline bool has_memory() const {
MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; } return memory_ || ex_memory_ != nullptr;
void set_memory(MixedMemory* mem) { memory_.reset(mem); capacity_ = mem->nbytes(); } }
MixedMemory* memory() const {
return own_mem_ ? memory_.get() : ex_memory_;
}
void set_memory(MixedMemory* mem) {
memory_.reset(mem); capacity_ = mem->nbytes();
}
MixedMemory::State memory_state() const { MixedMemory::State memory_state() const {
MixedMemory* mem = memory(); MixedMemory* mem = memory();
CHECK(mem) << "\nMemory access before allowcating."; CHECK(mem) << "\nMemory access before allowcating.";
...@@ -124,7 +138,8 @@ class Tensor { ...@@ -124,7 +138,8 @@ class Tensor {
const TypeMeta& meta() const { return meta_; } const TypeMeta& meta() const { return meta_; }
void SetMeta(const TypeMeta& meta) { meta_ = meta; } void SetMeta(const TypeMeta& meta) { meta_ = meta; }
template <typename T> inline bool IsType() { return meta_.Match<T>(); } template <typename T>
inline bool IsType() { return meta_.Match<T>(); }
template <class Context> template <class Context>
void mutable_data_ptr(void** data_ptr) { void mutable_data_ptr(void** data_ptr) {
...@@ -132,12 +147,15 @@ class Tensor { ...@@ -132,12 +147,15 @@ class Tensor {
if (!mem) { if (!mem) {
*data_ptr = nullptr; *data_ptr = nullptr;
} else { } else {
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) { if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CPUContext>()) {
*data_ptr = mem->mutable_cpu_data(); *data_ptr = mem->mutable_cpu_data();
} else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) { } else if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CUDAContext>()) {
*data_ptr = mem->mutable_cuda_data(); *data_ptr = mem->mutable_cuda_data();
} else { } else {
LOG(FATAL) << "Unknown memory type. Only CPU or CUDA is supported."; LOG(FATAL) << "Unknown memory type. "
<< "Only CPU or CUDA is supported.";
} }
} }
} }
...@@ -146,12 +164,15 @@ class Tensor { ...@@ -146,12 +164,15 @@ class Tensor {
const void* const_data_ptr() const { const void* const_data_ptr() const {
MixedMemory* mem = memory(); MixedMemory* mem = memory();
CHECK(mem) << "\nMemory access before allowcating."; CHECK(mem) << "\nMemory access before allowcating.";
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) { if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CPUContext>()) {
return mem->cpu_data(); return mem->cpu_data();
} else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) { } else if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CUDAContext>()) {
return mem->cuda_data(); return mem->cuda_data();
} else { } else {
LOG(FATAL) << "Unknown memory type. Only CPU or CUDA are supported."; LOG(FATAL) << "Unknown memory type. "
<< "Only CPU or CUDA are supported.";
return nullptr; return nullptr;
} }
} }
...@@ -164,10 +185,17 @@ class Tensor { ...@@ -164,10 +185,17 @@ class Tensor {
if (meta_ != meta && data_ptr && !own_mem_) delete ex_memory_; if (meta_ != meta && data_ptr && !own_mem_) delete ex_memory_;
meta_ = meta; meta_ = meta;
CHECK_GT(size_, 0); CHECK_GT(size_, 0);
if (own_mem_) memory_.reset(new MixedMemory(meta, size_* meta_.itemsize())); if (own_mem_) {
else ex_memory_ = new MixedMemory(meta, size_* meta_.itemsize()); memory_.reset(new MixedMemory(
mutable_data_ptr<Context>(&data_ptr); // malloc memory meta, size_* meta_.itemsize()));
if (meta.ctor()) meta_.ctor()(data_ptr, size_); // call the constructor } else {
ex_memory_ = new MixedMemory(
meta, size_* meta_.itemsize());
}
// malloc memory
mutable_data_ptr<Context>(&data_ptr);
// call the constructors
if (meta.ctor()) meta_.ctor()(data_ptr, size_);
capacity_ = size_ * meta.itemsize(); capacity_ = size_ * meta.itemsize();
return data_ptr; return data_ptr;
} }
...@@ -189,8 +217,10 @@ class Tensor { ...@@ -189,8 +217,10 @@ class Tensor {
T* mutable_data() { T* mutable_data() {
void* data_ptr; void* data_ptr;
mutable_data_ptr<Context>(&data_ptr); mutable_data_ptr<Context>(&data_ptr);
if (data_ptr && meta_ == TypeMeta::Make<T>()) return static_cast<T*>(data_ptr); if (data_ptr && meta_ == TypeMeta::Make<T>())
return static_cast<T*>(raw_mutable_data<Context>(TypeMeta::Make<T>())); return static_cast<T*>(data_ptr);
return static_cast<T*>(
raw_mutable_data<Context>(TypeMeta::Make<T>()));
} }
template <typename T, class Context> template <typename T, class Context>
...@@ -208,9 +238,11 @@ class Tensor { ...@@ -208,9 +238,11 @@ class Tensor {
auto* src = other.template raw_data<SrcCTX>(); auto* src = other.template raw_data<SrcCTX>();
auto* dst = raw_mutable_data<DstCTX>(other.meta_); auto* dst = raw_mutable_data<DstCTX>(other.meta_);
if (dst == src) return; if (dst == src) return;
if (TypeMeta::Id<DstCTX>() == TypeMeta::Id<CPUContext>()) { if (TypeMeta::Id<DstCTX>() ==
TypeMeta::Id<CPUContext>()) {
CPUContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src); CPUContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src);
} else if (TypeMeta::Id<DstCTX>() == TypeMeta::Id<CUDAContext>()) { } else if (TypeMeta::Id<DstCTX>() ==
TypeMeta::Id<CUDAContext>()) {
CUDAContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src); CUDAContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src);
} }
} }
......
...@@ -82,38 +82,47 @@ class TypeMeta { ...@@ -82,38 +82,47 @@ class TypeMeta {
template <typename T> template <typename T>
static void Ctor(void* ptr, size_t n) { static void Ctor(void* ptr, size_t n) {
T* typed_ptr = static_cast<T*>(ptr); T* typed_ptr = static_cast<T*>(ptr);
for (unsigned int i = 0; i < n; i++) new(typed_ptr + i) T; for (unsigned int i = 0; i < n; i++)
new(typed_ptr + i) T;
} }
template <typename T> template <typename T>
static void Copy(const void* src, void* dst, size_t n) { static void Copy(const void* src, void* dst, size_t n) {
const T* typed_src = static_cast<const T*>(src); const T* typed_src = static_cast<const T*>(src);
T* typed_dst = static_cast<T*>(dst); T* typed_dst = static_cast<T*>(dst);
for (unsigned int i = 0; i < n; i++) typed_dst[i] = typed_src[i]; for (unsigned int i = 0; i < n; i++)
typed_dst[i] = typed_src[i];
} }
template <typename T> template <typename T>
static void Dtor(void* ptr, size_t n) { static void Dtor(void* ptr, size_t n) {
T* typed_ptr = static_cast<T*>(ptr); T* typed_ptr = static_cast<T*>(ptr);
for (unsigned int i = 0; i < n; i++) typed_ptr[i].~T(); for (unsigned int i = 0; i < n; i++)
typed_ptr[i].~T();
} }
#define FundMeta std::enable_if<std::is_fundamental<T>::value,TypeMeta>::type #define FundMeta std::enable_if<std::is_fundamental<T>::value,TypeMeta>::type
#define StructMeta std::enable_if<!std::is_fundamental<T>::value && std::is_copy_assignable<T>::value, TypeMeta>::type
template <typename T> template <typename T>
static typename FundMeta Make() { static typename FundMeta Make() {
return TypeMeta(Id<T>(), Itemsize<T>(), nullptr, nullptr, nullptr); return TypeMeta(Id<T>(), Itemsize<T>(),
nullptr, nullptr, nullptr);
} }
#define StructMeta std::enable_if<!std::is_fundamental<T>::value && std::is_copy_assignable<T>::value, TypeMeta>::type
template<typename T> template<typename T>
static typename StructMeta Make() { static typename StructMeta Make() {
return TypeMeta(Id<T>(), Itemsize<T>(), Ctor<T>, Copy<T>, Dtor<T>); return TypeMeta(Id<T>(), Itemsize<T>(),
Ctor<T>, Copy<T>, Dtor<T>);
} }
private: private:
TypeMeta(TypeId id, size_t itemsize, TypeMeta(
PlacementNew ctor, TypedCopy copy, TypedDestructor dtor) TypeId id,
size_t itemsize,
PlacementNew ctor,
TypedCopy copy,
TypedDestructor dtor)
: id_(id), itemsize_(itemsize), : id_(id), itemsize_(itemsize),
ctor_(ctor), copy_(copy), dtor_(dtor) {} ctor_(ctor), copy_(copy), dtor_(dtor) {}
......
...@@ -40,8 +40,10 @@ typedef struct { ...@@ -40,8 +40,10 @@ typedef struct {
#endif #endif
inline const TypeMeta& TypeStringToMeta(const std::string& str_type) { inline const TypeMeta& TypeStringToMeta(
static std::unordered_map<std::string, TypeMeta> s2m_type_map { const std::string& str_type) {
static std::unordered_map<std::string, TypeMeta>
s2m_type_map {
{ "float32", TypeMeta::Make<float>() }, { "float32", TypeMeta::Make<float>() },
{ "int32", TypeMeta::Make<int>() }, { "int32", TypeMeta::Make<int>() },
{ "int64", TypeMeta::Make<int64_t>() }, { "int64", TypeMeta::Make<int64_t>() },
...@@ -50,11 +52,14 @@ inline const TypeMeta& TypeStringToMeta(const std::string& str_type) { ...@@ -50,11 +52,14 @@ inline const TypeMeta& TypeStringToMeta(const std::string& str_type) {
{ "uint8", TypeMeta::Make<uint8_t>() } { "uint8", TypeMeta::Make<uint8_t>() }
}; };
static TypeMeta unknown_type; static TypeMeta unknown_type;
return s2m_type_map.count(str_type) ? s2m_type_map[str_type] : unknown_type; return s2m_type_map.count(str_type) ?
s2m_type_map[str_type] : unknown_type;
} }
inline const std::string TypeMetaToString(const TypeMeta& meta) { inline const std::string TypeMetaToString(
static std::unordered_map<TypeId, std::string> m2s_type_map { const TypeMeta& meta) {
static std::unordered_map<TypeId, std::string>
m2s_type_map {
{ TypeMeta::Id<float>(), "float32" }, { TypeMeta::Id<float>(), "float32" },
{ TypeMeta::Id<int>(), "int32" }, { TypeMeta::Id<int>(), "int32" },
{ TypeMeta::Id<int64_t>(), "int64" }, { TypeMeta::Id<int64_t>(), "int64" },
...@@ -62,7 +67,8 @@ inline const std::string TypeMetaToString(const TypeMeta& meta) { ...@@ -62,7 +67,8 @@ inline const std::string TypeMetaToString(const TypeMeta& meta) {
{ TypeMeta::Id<float16>(), "float16" }, { TypeMeta::Id<float16>(), "float16" },
{ TypeMeta::Id<uint8_t>(), "uint8" } { TypeMeta::Id<uint8_t>(), "uint8" }
}; };
return m2s_type_map.count(meta.id()) ? m2s_type_map[meta.id()] : "unknown"; return m2s_type_map.count(meta.id()) ?
m2s_type_map[meta.id()] : "unknown";
} }
} // namespace dragon } // namespace dragon
......
...@@ -39,13 +39,16 @@ class Workspace { ...@@ -39,13 +39,16 @@ class Workspace {
inline void InitWorkspace() { inline void InitWorkspace() {
CreateTensor("ignore"); CreateTensor("ignore");
Tensor* head = CreateTensor("/opt/mirror_stage/head"); Tensor* head = CreateTensor(
head->Reshape(vector<TIndex>(1, WORKSPACE_MAX_CORRUPTED_SIZE)); "/opt/mirror_stage/head");
Tensor* recompute_flag = CreateTensor("/opt/mirror_stage/recompute_flag"); head->Reshape({ WORKSPACE_MAX_CORRUPTED_SIZE });
recompute_flag->Reshape(vector<TIndex>(1, 1)); Tensor* recompute_flag = CreateTensor(
"/opt/mirror_stage/recompute_flag");
recompute_flag->Reshape({ 1 });
recompute_flag->mutable_data<bool, CPUContext>()[0] = false; recompute_flag->mutable_data<bool, CPUContext>()[0] = false;
for (int i = 0; i < WORKSPACE_MAX_CORRUPTED_SIZE; i++) { for (int i = 0; i < WORKSPACE_MAX_CORRUPTED_SIZE; i++) {
string name = "/opt/mirror_stage/buffer_" + dragon_cast<string, int>(i); string name = "/opt/mirror_stage/buffer_" +
dragon_cast<string, int>(i);
Tensor* buffer = CreateTensor(name); Tensor* buffer = CreateTensor(name);
head->mutable_data<string, CPUContext>()[i] = ""; head->mutable_data<string, CPUContext>()[i] = "";
} }
...@@ -72,7 +75,9 @@ class Workspace { ...@@ -72,7 +75,9 @@ class Workspace {
} else { return name; } } else { return name; }
} }
inline Tensor* TryGetTensor(const string& name, bool use_remote=true) { inline Tensor* TryGetTensor(
const string& name,
bool use_remote = true) {
string query = GetTensorName(name); string query = GetTensorName(name);
// search local workspace // search local workspace
if (tensor_map_.count(query) > 0) if (tensor_map_.count(query) > 0)
...@@ -87,7 +92,9 @@ class Workspace { ...@@ -87,7 +92,9 @@ class Workspace {
return nullptr; return nullptr;
} }
inline bool HasTensor(const string& name, bool use_remote=true) { inline bool HasTensor(
const string& name,
bool use_remote = true) {
return TryGetTensor(name, use_remote) ? true : false; return TryGetTensor(name, use_remote) ? true : false;
} }
...@@ -100,7 +107,9 @@ class Workspace { ...@@ -100,7 +107,9 @@ class Workspace {
return tensor; return tensor;
} }
inline Tensor* GetTensor(const string& name, bool use_remote=true) { inline Tensor* GetTensor(
const string& name,
bool use_remote = true) {
Tensor* tensor = TryGetTensor(name, use_remote); Tensor* tensor = TryGetTensor(name, use_remote);
CHECK(tensor) << "\nTensor(" << name << ") does not exist " CHECK(tensor) << "\nTensor(" << name << ") does not exist "
<< "in current workspace or sub-workspace."; << "in current workspace or sub-workspace.";
...@@ -122,14 +131,17 @@ class Workspace { ...@@ -122,14 +131,17 @@ class Workspace {
// serach remote workspace // serach remote workspace
for (auto& it : workspace_map_) { for (auto& it : workspace_map_) {
vector<string> sub_names = it.second->GetTensors(); vector<string> sub_names = it.second->GetTensors();
names.insert(names.end(), sub_names.begin(), sub_names.end()); names.insert(names.end(),
sub_names.begin(), sub_names.end());
} }
return names; return names;
} }
/******************** Filler ********************/ /******************** Filler ********************/
inline bool HasFiller(const string& name, bool use_remote=true) { inline bool HasFiller(
const string& name,
bool use_remote = true) {
// search local workspace // search local workspace
bool result = filler_map_.count(name) > 0; bool result = filler_map_.count(name) > 0;
if (!use_remote) return result; if (!use_remote) return result;
...@@ -140,14 +152,16 @@ class Workspace { ...@@ -140,14 +152,16 @@ class Workspace {
return result; return result;
} }
inline void CreateFiller(const TensorFiller filler) { inline void CreateFiller(
const TensorFiller filler) {
CHECK_GT(filler.tensor().size(), 0) CHECK_GT(filler.tensor().size(), 0)
<< "Tensor without a valid name can not be filled."; << "Tensor without a valid name can not be filled.";
if (HasFiller(filler.tensor())) return; if (HasFiller(filler.tensor())) return;
filler_map_[filler.tensor()] = filler; filler_map_[filler.tensor()] = filler;
} }
inline const TensorFiller* GetFiller(const string& name) { inline const TensorFiller* GetFiller(
const string& name) {
// search local workspace // search local workspace
if (filler_map_.count(name) > 0) if (filler_map_.count(name) > 0)
return &filler_map_[name]; return &filler_map_[name];
...@@ -163,11 +177,12 @@ class Workspace { ...@@ -163,11 +177,12 @@ class Workspace {
/******************** Cache ********************/ /******************** Cache ********************/
template <class Context> template <class Context>
inline vector<void*> caches(const vector<size_t>& segments) { inline vector<void*> caches(
const vector<size_t>& segments) {
TIndex total_size = 0; TIndex total_size = 0;
for (auto& segment : segments) total_size += (TIndex)segment; for (auto& segment : segments) total_size += (TIndex)segment;
Tensor* cacheT = CreateTensor("/share/cache"); Tensor* cacheT = CreateTensor("/share/cache");
cacheT->Reshape(vector<TIndex>(1, total_size)); cacheT->Reshape({ total_size });
vector<void*> caches(segments.size()); vector<void*> caches(segments.size());
caches[0] = cacheT->template mutable_data<uint8_t, Context>(); caches[0] = cacheT->template mutable_data<uint8_t, Context>();
for (int i = 1; i < segments.size(); i++) for (int i = 1; i < segments.size(); i++)
...@@ -176,11 +191,12 @@ class Workspace { ...@@ -176,11 +191,12 @@ class Workspace {
} }
template <typename T, class Context> template <typename T, class Context>
inline vector<T*> caches(const vector<TIndex>& segments) { inline vector<T*> caches(
const vector<TIndex>& segments) {
TIndex total_count = 0; TIndex total_count = 0;
for (auto& segment : segments) total_count += segment; for (auto& segment : segments) total_count += segment;
Tensor* cacheT = CreateTensor("/share/cache"); Tensor* cacheT = CreateTensor("/share/cache");
cacheT->Reshape(vector<TIndex>(1, total_count)); cacheT->Reshape({ total_count });
vector<T*> caches(segments.size()); vector<T*> caches(segments.size());
caches[0] = cacheT->template mutable_data<T, Context>(); caches[0] = cacheT->template mutable_data<T, Context>();
for (int i = 1; i < segments.size(); i++) for (int i = 1; i < segments.size(); i++)
...@@ -190,12 +206,14 @@ class Workspace { ...@@ -190,12 +206,14 @@ class Workspace {
/******************** Operator ********************/ /******************** Operator ********************/
inline void CreatePersistentOp(const OperatorDef& meta_op) { inline void CreatePersistentOp(
const OperatorDef& meta_op) {
string persistent_key; string persistent_key;
for (auto& arg : meta_op.arg()) for (auto& arg : meta_op.arg())
if (arg.name() == "persistent_key") if (arg.name() == "persistent_key")
persistent_key = arg.s(); persistent_key = arg.s();
CHECK(persistent_key.size() > 0) << "\nGot empty persistent key."; CHECK(persistent_key.size() > 0)
<< "\nGot empty persistent key.";
if (!op_map_.count(persistent_key)) { if (!op_map_.count(persistent_key)) {
for (auto& input : meta_op.input()) CreateTensor(input); for (auto& input : meta_op.input()) CreateTensor(input);
op_map_[persistent_key] = unique_ptr<OperatorBase>( op_map_[persistent_key] = unique_ptr<OperatorBase>(
...@@ -203,7 +221,9 @@ class Workspace { ...@@ -203,7 +221,9 @@ class Workspace {
} }
} }
inline void RunPersistentOp(const string& key, const string& anchor, inline void RunPersistentOp(
const string& key,
const string& anchor,
const vector<string>& inputs, const vector<string>& inputs,
const vector<string>& outputs) { const vector<string>& outputs) {
CHECK(op_map_.count(key) > 0) CHECK(op_map_.count(key) > 0)
...@@ -236,11 +256,13 @@ class Workspace { ...@@ -236,11 +256,13 @@ class Workspace {
GraphBase* CreateGraph(const GraphDef& meta_graph); GraphBase* CreateGraph(const GraphDef& meta_graph);
void RunGraph(const string& graph_name, void RunGraph(
const string& graph_name,
const string& include, const string& include,
const string& exclude) { const string& exclude) {
if (!graph_map_.count(graph_name)) if (!graph_map_.count(graph_name))
LOG(FATAL) << "Graph(" << graph_name << ") does not exist."; LOG(FATAL) << "Graph(" << graph_name
<< ") does not exist.";
graph_map_[graph_name]->Run(include, exclude); graph_map_[graph_name]->Run(include, exclude);
} }
...@@ -252,7 +274,8 @@ class Workspace { ...@@ -252,7 +274,8 @@ class Workspace {
/******************** Utility ********************/ /******************** Utility ********************/
inline void CreateRename(const string& old_tensor, inline void CreateRename(
const string& old_tensor,
const string& new_tensor) { const string& new_tensor) {
rename_map_[old_tensor] = new_tensor; rename_map_[old_tensor] = new_tensor;
} }
......
...@@ -20,11 +20,11 @@ namespace dragon { ...@@ -20,11 +20,11 @@ namespace dragon {
template <class Context> template <class Context>
class DropoutOp final : public Operator<Context> { class DropoutOp final : public Operator<Context> {
public: public:
DropoutOp(const OperatorDef& op_def, Workspace* ws) DropoutOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) { use_scale(OperatorBase::Arg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5); GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", "")); SwitchToPhase(OperatorBase::Arg<string>("phase", ""));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -39,11 +39,11 @@ class DropoutOp final : public Operator<Context> { ...@@ -39,11 +39,11 @@ class DropoutOp final : public Operator<Context> {
template <class Context> template <class Context>
class DropoutGradientOp final : public Operator<Context> { class DropoutGradientOp final : public Operator<Context> {
public: public:
DropoutGradientOp(const OperatorDef& op_def, Workspace* ws) DropoutGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) { use_scale(OperatorBase::Arg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5); GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", "")); SwitchToPhase(OperatorBase::Arg<string>("phase", ""));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -66,12 +66,12 @@ DEFINE_ARGUMENT_WITH_DESC(float, DropoutGradientOp, prob); ...@@ -66,12 +66,12 @@ DEFINE_ARGUMENT_WITH_DESC(float, DropoutGradientOp, prob);
template <class Context> template <class Context>
class CuDNNDropoutOp final : public Operator<Context> { class CuDNNDropoutOp final : public Operator<Context> {
public: public:
CuDNNDropoutOp(const OperatorDef& op_def, Workspace* ws) CuDNNDropoutOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), states_initialized(false), : Operator<Context>(def, ws), states_initialized(false),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)), use_scale(OperatorBase::Arg<bool>("scale", true)),
random_seed(op_def.device_option().random_seed()) { random_seed(DEFAULT_RNG_SEED) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5); GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", "")); SwitchToPhase(OperatorBase::Arg<string>("phase", ""));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc)); CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
} }
...@@ -97,12 +97,12 @@ public: ...@@ -97,12 +97,12 @@ public:
template <class Context> template <class Context>
class CuDNNDropoutGradientOp final : public Operator<Context> { class CuDNNDropoutGradientOp final : public Operator<Context> {
public: public:
CuDNNDropoutGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNDropoutGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), states_initialized(false), : Operator<Context>(def, ws), states_initialized(false),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)), use_scale(OperatorBase::Arg<bool>("scale", true)),
random_seed(op_def.device_option().random_seed()) { random_seed(DEFAULT_RNG_SEED) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5); GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", "")); SwitchToPhase(OperatorBase::Arg<string>("phase", ""));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc)); CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
} }
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class EluOp : public Operator<Context> { class EluOp : public Operator<Context> {
public: public:
EluOp(const OperatorDef& op_def, Workspace* ws) EluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {} alpha(OperatorBase::Arg<float>("alpha", 1.0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -34,9 +34,9 @@ class EluOp : public Operator<Context> { ...@@ -34,9 +34,9 @@ class EluOp : public Operator<Context> {
template <class Context> template <class Context>
class EluGradientOp : public Operator<Context> { class EluGradientOp : public Operator<Context> {
public: public:
EluGradientOp(const OperatorDef& op_def, Workspace* ws) EluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {} alpha(OperatorBase::Arg<float>("alpha", 1.0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -53,8 +53,8 @@ class EluGradientOp : public Operator<Context> { ...@@ -53,8 +53,8 @@ class EluGradientOp : public Operator<Context> {
template <class Context> template <class Context>
class CuDNNEluOp final : public EluOp<Context> { class CuDNNEluOp final : public EluOp<Context> {
public: public:
CuDNNEluOp(const OperatorDef& op_def, Workspace* ws) CuDNNEluOp(const OperatorDef& def, Workspace* ws)
: EluOp<Context>(op_def, ws) { : EluOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
...@@ -80,8 +80,8 @@ public: ...@@ -80,8 +80,8 @@ public:
template <class Context> template <class Context>
class CuDNNEluGradientOp final : public EluGradientOp<Context> { class CuDNNEluGradientOp final : public EluGradientOp<Context> {
public: public:
CuDNNEluGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNEluGradientOp(const OperatorDef& def, Workspace* ws)
: EluGradientOp<Context>(op_def, ws) { : EluGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class PReluOp : public Operator<Context> { class PReluOp final : public Operator<Context> {
public: public:
PReluOp(const OperatorDef& op_def, Workspace* ws) PReluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)), channel_shared(OperatorBase::Arg<bool>("channel_shared", false)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -34,12 +34,12 @@ class PReluOp : public Operator<Context> { ...@@ -34,12 +34,12 @@ class PReluOp : public Operator<Context> {
}; };
template <class Context> template <class Context>
class PReluGradientOp : public Operator<Context> { class PReluGradientOp final : public Operator<Context> {
public: public:
PReluGradientOp(const OperatorDef& op_def, Workspace* ws) PReluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)), channel_shared(OperatorBase::Arg<bool>("channel_shared", false)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class ReluOp : public Operator<Context> { class ReluOp : public Operator<Context> {
public: public:
ReluOp(const OperatorDef& op_def, Workspace* ws) ReluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {} slope(OperatorBase::Arg<float>("slope", 0.0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -34,9 +34,9 @@ class ReluOp : public Operator<Context> { ...@@ -34,9 +34,9 @@ class ReluOp : public Operator<Context> {
template <class Context> template <class Context>
class ReluGradientOp : public Operator<Context> { class ReluGradientOp : public Operator<Context> {
public: public:
ReluGradientOp(const OperatorDef& op_def, Workspace* ws) ReluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {} slope(OperatorBase::Arg<float>("slope", 0.0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -51,8 +51,8 @@ class ReluGradientOp : public Operator<Context> { ...@@ -51,8 +51,8 @@ class ReluGradientOp : public Operator<Context> {
template <class Context> template <class Context>
class CuDNNReluOp final : public ReluOp<Context> { class CuDNNReluOp final : public ReluOp<Context> {
public: public:
CuDNNReluOp(const OperatorDef& op_def, Workspace* ws) CuDNNReluOp(const OperatorDef& def, Workspace* ws)
: ReluOp<Context>(op_def, ws) { : ReluOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
...@@ -78,8 +78,8 @@ public: ...@@ -78,8 +78,8 @@ public:
template <class Context> template <class Context>
class CuDNNReluGradientOp final : public ReluGradientOp<Context> { class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
public: public:
CuDNNReluGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNReluGradientOp(const OperatorDef& def, Workspace* ws)
: ReluGradientOp<Context>(op_def, ws) { : ReluGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SEluOp : public Operator<Context> { class SEluOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SEluOp); USE_SIMPLE_CTOR_DTOR(SEluOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -27,7 +27,7 @@ class SEluOp : public Operator<Context> { ...@@ -27,7 +27,7 @@ class SEluOp : public Operator<Context> {
}; };
template <class Context> template <class Context>
class SEluGradientOp : public Operator<Context> { class SEluGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SEluGradientOp); USE_SIMPLE_CTOR_DTOR(SEluGradientOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
...@@ -41,8 +41,8 @@ class SigmoidGradientOp : public Operator<Context> { ...@@ -41,8 +41,8 @@ class SigmoidGradientOp : public Operator<Context> {
template <class Context> template <class Context>
class CuDNNSigmoidOp final : public SigmoidOp<Context> { class CuDNNSigmoidOp final : public SigmoidOp<Context> {
public: public:
CuDNNSigmoidOp(const OperatorDef& op_def, Workspace* ws) CuDNNSigmoidOp(const OperatorDef& def, Workspace* ws)
: SigmoidOp<Context>(op_def, ws) { : SigmoidOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
...@@ -68,8 +68,8 @@ public: ...@@ -68,8 +68,8 @@ public:
template <class Context> template <class Context>
class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> { class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
public: public:
CuDNNSigmoidGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNSigmoidGradientOp(const OperatorDef& def, Workspace* ws)
: SigmoidGradientOp<Context>(op_def, ws) { : SigmoidGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class SoftmaxOp final : public Operator<Context> { class SoftmaxOp final : public Operator<Context> {
public: public:
SoftmaxOp(const OperatorDef& op_def, Workspace* ws) SoftmaxOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -34,9 +34,9 @@ class SoftmaxOp final : public Operator<Context> { ...@@ -34,9 +34,9 @@ class SoftmaxOp final : public Operator<Context> {
template <class Context> template <class Context>
class SoftmaxGradientOp final : public Operator<Context> { class SoftmaxGradientOp final : public Operator<Context> {
public: public:
SoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws) SoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -53,9 +53,9 @@ class SoftmaxGradientOp final : public Operator<Context> { ...@@ -53,9 +53,9 @@ class SoftmaxGradientOp final : public Operator<Context> {
template <class Context> template <class Context>
class CuDNNSoftmaxOp final : public Operator<Context> { class CuDNNSoftmaxOp final : public Operator<Context> {
public: public:
CuDNNSoftmaxOp(const OperatorDef& op_def, Workspace* ws) CuDNNSoftmaxOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) { axis(OperatorBase::Arg<int>("axis", 1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
} }
...@@ -78,9 +78,9 @@ class CuDNNSoftmaxOp final : public Operator<Context> { ...@@ -78,9 +78,9 @@ class CuDNNSoftmaxOp final : public Operator<Context> {
template <class Context> template <class Context>
class CuDNNSoftmaxGradientOp final : public Operator<Context> { class CuDNNSoftmaxGradientOp final : public Operator<Context> {
public: public:
CuDNNSoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNSoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) { axis(OperatorBase::Arg<int>("axis", 1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
} }
......
...@@ -41,8 +41,8 @@ class TanhGradientOp : public Operator<Context> { ...@@ -41,8 +41,8 @@ class TanhGradientOp : public Operator<Context> {
template <class Context> template <class Context>
class CuDNNTanhOp final : public TanhOp<Context> { class CuDNNTanhOp final : public TanhOp<Context> {
public: public:
CuDNNTanhOp(const OperatorDef& op_def, Workspace* ws) CuDNNTanhOp(const OperatorDef& def, Workspace* ws)
: TanhOp<Context>(op_def, ws) { : TanhOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
...@@ -68,8 +68,8 @@ public: ...@@ -68,8 +68,8 @@ public:
template <class Context> template <class Context>
class CuDNNTanhGradientOp final : public TanhGradientOp<Context> { class CuDNNTanhGradientOp final : public TanhGradientOp<Context> {
public: public:
CuDNNTanhGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNTanhGradientOp(const OperatorDef& def, Workspace* ws)
: TanhGradientOp<Context>(op_def, ws) { : TanhGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc)); CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class AffineOp : public Operator<Context> { class AffineOp final : public Operator<Context> {
public: public:
AffineOp(const OperatorDef& op_def, Workspace* ws) AffineOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", 1)) {} num_axes(OperatorBase::Arg<int>("num_axes", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -36,10 +36,10 @@ class AffineOp : public Operator<Context> { ...@@ -36,10 +36,10 @@ class AffineOp : public Operator<Context> {
template <class Context> template <class Context>
class AffineGradientOp final : public Operator<Context> { class AffineGradientOp final : public Operator<Context> {
public: public:
AffineGradientOp(const OperatorDef& op_def, Workspace* ws) AffineGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {} num_axes(OperatorBase::Arg<int>("num_axes", -1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -60,10 +60,10 @@ class AffineGradientOp final : public Operator<Context> { ...@@ -60,10 +60,10 @@ class AffineGradientOp final : public Operator<Context> {
template <class Context> template <class Context>
class CuDNNAffineOpBase : public Operator<Context> { class CuDNNAffineOpBase : public Operator<Context> {
public: public:
CuDNNAffineOpBase(const OperatorDef& op_def, Workspace* ws) CuDNNAffineOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) { num_axes(OperatorBase::Arg<int>("num_axes", -1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc));
CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&mul_desc)); CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&mul_desc));
...@@ -120,10 +120,10 @@ class CuDNNAffineOpBase : public Operator<Context> { ...@@ -120,10 +120,10 @@ class CuDNNAffineOpBase : public Operator<Context> {
using CuDNNAffineOpBase<Context>::reduce_desc using CuDNNAffineOpBase<Context>::reduce_desc
template <class Context> template <class Context>
class CuDNNAffineOp : public CuDNNAffineOpBase<Context> { class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> {
public: public:
CuDNNAffineOp(const OperatorDef& op_def, Workspace* ws) CuDNNAffineOp(const OperatorDef& def, Workspace* ws)
: CuDNNAffineOpBase<Context>(op_def, ws) {} : CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -133,10 +133,10 @@ class CuDNNAffineOp : public CuDNNAffineOpBase<Context> { ...@@ -133,10 +133,10 @@ class CuDNNAffineOp : public CuDNNAffineOpBase<Context> {
}; };
template <class Context> template <class Context>
class CuDNNAffineGradientOp : public CuDNNAffineOpBase<Context> { class CuDNNAffineGradientOp final : public CuDNNAffineOpBase<Context> {
public: public:
CuDNNAffineGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNAffineGradientOp(const OperatorDef& def, Workspace* ws)
: CuDNNAffineOpBase<Context>(op_def, ws) {} : CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void ComputeScaleGradient(T* dYxX, T* dA); template <typename T> void ComputeScaleGradient(T* dYxX, T* dA);
......
...@@ -20,10 +20,10 @@ namespace dragon { ...@@ -20,10 +20,10 @@ namespace dragon {
template <class Context> template <class Context>
class ClipOp final : public Operator<Context> { class ClipOp final : public Operator<Context> {
public: public:
ClipOp(const OperatorDef& op_def, Workspace* ws) ClipOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
low(OperatorBase::GetSingleArg<float>("low", -FLT_MAX)), low(OperatorBase::Arg<float>("low", -FLT_MAX)),
high(OperatorBase::GetSingleArg<float>("high", FLT_MAX)) {} high(OperatorBase::Arg<float>("high", FLT_MAX)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,10 +19,10 @@ namespace dragon { ...@@ -19,10 +19,10 @@ namespace dragon {
template <class Context> template <class Context>
class DotOp final : public Operator<Context> { class DotOp final : public Operator<Context> {
public: public:
DotOp(const OperatorDef& op_def, Workspace* ws) DotOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)), TransA(OperatorBase::Arg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} TransB(OperatorBase::Arg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -37,10 +37,10 @@ class DotOp final : public Operator<Context> { ...@@ -37,10 +37,10 @@ class DotOp final : public Operator<Context> {
template <class Context> template <class Context>
class DotGradientOp final : public Operator<Context> { class DotGradientOp final : public Operator<Context> {
public: public:
DotGradientOp(const OperatorDef& op_def, Workspace* ws) DotGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)), TransA(OperatorBase::Arg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} TransB(OperatorBase::Arg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,10 +19,10 @@ namespace dragon { ...@@ -19,10 +19,10 @@ namespace dragon {
template <class Context> template <class Context>
class EltwiseOp final : public Operator<Context> { class EltwiseOp final : public Operator<Context> {
public: public:
EltwiseOp(const OperatorDef& op_def, Workspace* ws) EltwiseOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "SUM")), operation(OperatorBase::Arg<string>("operation", "SUM")),
coeffs(OperatorBase::GetRepeatedArg<float>("coeffs")) { coeffs(OperatorBase::Args<float>("coeffs")) {
if (coeffs.size() > 0) { if (coeffs.size() > 0) {
CHECK_EQ(coeffs.size(), InputSize()) CHECK_EQ(coeffs.size(), InputSize())
<< "\nOp has " << InputSize() << " inputs, " << "\nOp has " << InputSize() << " inputs, "
...@@ -43,10 +43,10 @@ class EltwiseOp final : public Operator<Context> { ...@@ -43,10 +43,10 @@ class EltwiseOp final : public Operator<Context> {
template <class Context> template <class Context>
class EltwiseGradientOp final : public Operator<Context> { class EltwiseGradientOp final : public Operator<Context> {
public: public:
EltwiseGradientOp(const OperatorDef& op_def, Workspace* ws) EltwiseGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "SUM")), operation(OperatorBase::Arg<string>("operation", "SUM")),
coeffs(OperatorBase::GetRepeatedArg<float>("coeff")) { coeffs(OperatorBase::Args<float>("coeff")) {
if (coeffs.size() > 0) { if (coeffs.size() > 0) {
CHECK_EQ(coeffs.size(), InputSize()) CHECK_EQ(coeffs.size(), InputSize())
<< "\nop has " << InputSize() << " inputs, " << "\nop has " << InputSize() << " inputs, "
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class GramMatrixOp final : public Operator<Context> { class GramMatrixOp final : public Operator<Context> {
public: public:
GramMatrixOp(const OperatorDef& op_def, Workspace* ws) GramMatrixOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -35,9 +35,9 @@ class GramMatrixOp final : public Operator<Context> { ...@@ -35,9 +35,9 @@ class GramMatrixOp final : public Operator<Context> {
template <class Context> template <class Context>
class GramMatrixGradientOp final : public Operator<Context> { class GramMatrixGradientOp final : public Operator<Context> {
public: public:
GramMatrixGradientOp(const OperatorDef& op_def, Workspace* ws) GramMatrixGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class InnerProductOp: public Operator<Context> { class InnerProductOp final : public Operator<Context> {
public: public:
InnerProductOp(const OperatorDef& op_def, Workspace *ws) InnerProductOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)), num_output(OperatorBase::Arg<int>("num_output", 0)),
TransW(OperatorBase::GetSingleArg<bool>("TransW", true)) {} TransW(OperatorBase::Arg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice(); void RunOnDevice();
...@@ -37,11 +37,11 @@ class InnerProductOp: public Operator<Context> { ...@@ -37,11 +37,11 @@ class InnerProductOp: public Operator<Context> {
template <class Context> template <class Context>
class InnerProductGradientOp final : public Operator<Context> { class InnerProductGradientOp final : public Operator<Context> {
public: public:
InnerProductGradientOp(const OperatorDef& op_def, Workspace *ws) InnerProductGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)), num_output(OperatorBase::Arg<int>("num_output", 0)),
TransW(OperatorBase::GetSingleArg<bool>("TransW", true)) {} TransW(OperatorBase::Arg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,10 +19,10 @@ namespace dragon { ...@@ -19,10 +19,10 @@ namespace dragon {
template <class Context> template <class Context>
class MatmulOp final : public Operator<Context> { class MatmulOp final : public Operator<Context> {
public: public:
MatmulOp(const OperatorDef& op_def, Workspace* ws) MatmulOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)), TransA(OperatorBase::Arg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} TransB(OperatorBase::Arg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -36,10 +36,10 @@ class MatmulOp final : public Operator<Context> { ...@@ -36,10 +36,10 @@ class MatmulOp final : public Operator<Context> {
template <class Context> template <class Context>
class MatmulGradientOp final : public Operator<Context> { class MatmulGradientOp final : public Operator<Context> {
public: public:
MatmulGradientOp(const OperatorDef& op_def, Workspace* ws) MatmulGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)), TransA(OperatorBase::Arg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} TransB(OperatorBase::Arg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class PowOp: public Operator<Context> { class PowOp final : public Operator<Context> {
public: public:
PowOp(const OperatorDef& op_def, Workspace* ws) PowOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::GetSingleArg<float>("scale", 1.0)), scale(OperatorBase::Arg<float>("scale", 1.0)),
shift(OperatorBase::GetSingleArg<float>("shift", 0.0)), shift(OperatorBase::Arg<float>("shift", 0.0)),
power(OperatorBase::GetSingleArg<float>("power", 1.0)) { power(OperatorBase::Arg<float>("power", 1.0)) {
power_scale = power * scale; power_scale = power * scale;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -38,11 +38,11 @@ class PowOp: public Operator<Context> { ...@@ -38,11 +38,11 @@ class PowOp: public Operator<Context> {
template <class Context> template <class Context>
class PowGradientOp final : public Operator<Context> { class PowGradientOp final : public Operator<Context> {
public: public:
PowGradientOp(const OperatorDef& op_def, Workspace* ws) PowGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
scale(OperatorBase::GetSingleArg<float>("scale", 1.0)), scale(OperatorBase::Arg<float>("scale", 1.0)),
shift(OperatorBase::GetSingleArg<float>("shift", 0.0)), shift(OperatorBase::Arg<float>("shift", 0.0)),
power(OperatorBase::GetSingleArg<float>("power", 1.0)) { power(OperatorBase::Arg<float>("power", 1.0)) {
power_scale = power * scale; power_scale = power * scale;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class CompareOp final : public Operator<Context> { class CompareOp final : public Operator<Context> {
public: public:
CompareOp(const OperatorDef& op_def, Workspace* ws) CompareOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {} operation(OperatorBase::Arg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,16 +19,16 @@ namespace dragon { ...@@ -19,16 +19,16 @@ namespace dragon {
template <class Context> template <class Context>
class ScanOp final: public Operator<Context> { class ScanOp final: public Operator<Context> {
public: public:
ScanOp(const OperatorDef& op_def, Workspace *ws) ScanOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::Arg<int>("axis", 0)),
nsteps(OperatorBase::GetSingleArg<int>("nsteps", 0)), nsteps(OperatorBase::Arg<int>("nsteps", 0)),
step_type(OperatorBase::GetSingleArg<string>("step_type", "Static")), step_type(OperatorBase::Arg<string>("step_type", "Static")),
step_tensor(OperatorBase::GetSingleArg<string>("step_tensor", "")), step_tensor(OperatorBase::Arg<string>("step_tensor", "")),
nseqs(OperatorBase::GetSingleArg<int>("nseqs", 0)), nseqs(OperatorBase::Arg<int>("nseqs", 0)),
default_outputs(OperatorBase::GetRepeatedArg<string>("default_outputs")), default_outputs(OperatorBase::Args<string>("default_outputs")),
nout((int)default_outputs.size()), nout((int)default_outputs.size()),
debug_mode(OperatorBase::GetSingleArg<bool>("debug_mode", false)) { debug_mode(OperatorBase::Arg<bool>("debug_mode", false)) {
InitTemplate(); InitTemplate();
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -52,14 +52,14 @@ class ScanOp final: public Operator<Context> { ...@@ -52,14 +52,14 @@ class ScanOp final: public Operator<Context> {
template <class Context> template <class Context>
class ScanGradientOp final: public Operator<Context> { class ScanGradientOp final: public Operator<Context> {
public: public:
ScanGradientOp(const OperatorDef& op_def, Workspace* ws) ScanGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::Arg<int>("axis", 0)),
nsteps(OperatorBase::GetSingleArg<int>("nsteps", 0)), nsteps(OperatorBase::Arg<int>("nsteps", 0)),
step_type(OperatorBase::GetSingleArg<string>("step_type", "Static")), step_type(OperatorBase::Arg<string>("step_type", "Static")),
step_tensor(OperatorBase::GetSingleArg<string>("step_tensor", "")), step_tensor(OperatorBase::Arg<string>("step_tensor", "")),
forward_inputs(OperatorBase::GetRepeatedArg<string>("inputs_name")), forward_inputs(OperatorBase::Args<string>("inputs_name")),
forward_outputs(OperatorBase::GetRepeatedArg<string>("outputs_name")) { forward_outputs(OperatorBase::Args<string>("outputs_name")) {
// handle GO(x) // handle GO(x)
for (int i = 0; i < forward_outputs.size(); i++) for (int i = 0; i < forward_outputs.size(); i++)
terms[forward_outputs[i] + "_grad"] = Input(i + (int)OutputSize()).name(); terms[forward_outputs[i] + "_grad"] = Input(i + (int)OutputSize()).name();
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class CTCLossOp final : public Operator<Context> {
public:
CTCLossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
LOG(FATAL) << "CTCLoss requires CuDNN support.";
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {}
};
template <class Context>
class CTCLossGradientOp final : public Operator<Context> {
public:
CTCLossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
};
#ifdef WITH_CUDNN
#if CUDNN_VERSION_MIN(7, 0, 0)
#include "utils/cudnn_device.h"
template <class Context>
class CuDNNCTCLossOp final : public Operator<Context> {
public:
CuDNNCTCLossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
blank_first(OperatorBase::Arg<bool>("blank_first", true)),
padding_mask(OperatorBase::Arg<int>("padding_mask", -1)) {
CUDNN_CHECK(cudnnCreateCTCLossDescriptor(&ctc_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&prob_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&grad_desc));
ctc_algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
}
USE_OPERATOR_FUNCTIONS;
~CuDNNCTCLossOp() {
CUDNN_CHECK(cudnnDestroyCTCLossDescriptor(ctc_desc));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(prob_desc));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(grad_desc));
}
void RunOnDevice() override;
template <typename T> void RunWithType();
void WrapIO();
protected:
bool blank_first;
TIndex padding_mask;
cudnnCTCLossAlgo_t ctc_algo;
cudnnCTCLossDescriptor_t ctc_desc;
cudnnTensorDescriptor_t prob_desc, grad_desc;
size_t workspace_size;
vector<int> packed_labels, label_lengths, input_lengths;
};
#endif
#endif // WITH_CUDNN
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
\ No newline at end of file
...@@ -17,11 +17,12 @@ ...@@ -17,11 +17,12 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class L1LossOp : public Operator<Context> { class L1LossOp final : public Operator<Context> {
public: public:
L1LossOp(const OperatorDef& op_def, Workspace* ws) L1LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -35,9 +36,10 @@ class L1LossOp : public Operator<Context> { ...@@ -35,9 +36,10 @@ class L1LossOp : public Operator<Context> {
template <class Context> template <class Context>
class L1LossGradientOp final : public Operator<Context> { class L1LossGradientOp final : public Operator<Context> {
public: public:
L1LossGradientOp(const OperatorDef& op_def, Workspace* ws) L1LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -17,11 +17,12 @@ ...@@ -17,11 +17,12 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class L2LossOp : public Operator<Context> { class L2LossOp final : public Operator<Context> {
public: public:
L2LossOp(const OperatorDef& op_def, Workspace* ws) L2LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -35,9 +36,10 @@ class L2LossOp : public Operator<Context> { ...@@ -35,9 +36,10 @@ class L2LossOp : public Operator<Context> {
template <class Context> template <class Context>
class L2LossGradientOp final : public Operator<Context> { class L2LossGradientOp final : public Operator<Context> {
public: public:
L2LossGradientOp(const OperatorDef& op_def, Workspace* ws) L2LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,9 +19,10 @@ namespace dragon { ...@@ -19,9 +19,10 @@ namespace dragon {
template <class Context> template <class Context>
class SigmoidCrossEntropyOp final : public Operator<Context> { class SigmoidCrossEntropyOp final : public Operator<Context> {
public: public:
SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws) SigmoidCrossEntropyOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {} normalization(OperatorBase::Arg<string>(
"normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -35,9 +36,10 @@ class SigmoidCrossEntropyOp final : public Operator<Context> { ...@@ -35,9 +36,10 @@ class SigmoidCrossEntropyOp final : public Operator<Context> {
template <class Context> template <class Context>
class SigmoidCrossEntropyGradientOp final : public Operator<Context> { class SigmoidCrossEntropyGradientOp final : public Operator<Context> {
public: public:
SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws) SigmoidCrossEntropyGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {} normalization(OperatorBase::Arg<string>(
"normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,10 +19,11 @@ namespace dragon { ...@@ -19,10 +19,11 @@ namespace dragon {
template <class Context> template <class Context>
class SmoothL1LossOp final : public Operator<Context> { class SmoothL1LossOp final : public Operator<Context> {
public: public:
SmoothL1LossOp(const OperatorDef& op_def, Workspace* ws) SmoothL1LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
beta(OperatorBase::GetSingleArg<float>("beta", 1.0)), beta(OperatorBase::Arg<float>("beta", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -37,10 +38,11 @@ class SmoothL1LossOp final : public Operator<Context> { ...@@ -37,10 +38,11 @@ class SmoothL1LossOp final : public Operator<Context> {
template <class Context> template <class Context>
class SmoothL1LossGradientOp final : public Operator<Context> { class SmoothL1LossGradientOp final : public Operator<Context> {
public: public:
SmoothL1LossGradientOp(const OperatorDef& op_def, Workspace* ws) SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
beta(OperatorBase::GetSingleArg<float>("beta", 1.0)), beta(OperatorBase::Arg<float>("beta", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,11 +19,11 @@ namespace dragon { ...@@ -19,11 +19,11 @@ namespace dragon {
template <class Context> template <class Context>
class SoftmaxCrossEntropyOp final : public Operator<Context> { class SoftmaxCrossEntropyOp final : public Operator<Context> {
public: public:
SoftmaxCrossEntropyOp(const OperatorDef& op_def, Workspace* ws) SoftmaxCrossEntropyOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) { normalization(OperatorBase::Arg<string>(
} "normalization", "FULL")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void SoftmaxRun(); void SoftmaxRun();
...@@ -41,10 +41,11 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> { ...@@ -41,10 +41,11 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> {
template <class Context> template <class Context>
class SoftmaxCrossEntropyGradientOp final : public Operator<Context> { class SoftmaxCrossEntropyGradientOp final : public Operator<Context> {
public: public:
SoftmaxCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws) SoftmaxCrossEntropyGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {} normalization(OperatorBase::Arg<string>(
"normalization", "FULL")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,15 +19,16 @@ namespace dragon { ...@@ -19,15 +19,16 @@ namespace dragon {
template <class Context> template <class Context>
class SparseSoftmaxCrossEntropyOp : public Operator<Context> { class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
public: public:
SparseSoftmaxCrossEntropyOp(const OperatorDef& op_def, Workspace* ws) SparseSoftmaxCrossEntropyOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) { normalization(OperatorBase::Arg<string>(
vector<int> args = OperatorBase::GetRepeatedArg<int>("ignore_labels"); "normalization", "VALID")) {
if (args.size()) { vector<int> ignores = OperatorBase::Args<int>("ignore_labels");
ignore.Reshape(vector<TIndex>(1, args.size())); if (ignores.size()) {
int* ignore_data = ignore.mutable_data<int, CPUContext>(); ignore.Reshape({ (TIndex)ignores.size() });
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i]; auto* Idata = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < ignores.size(); i++) Idata[i] = ignores[i];
} }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -49,15 +50,16 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> { ...@@ -49,15 +50,16 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
template <class Context> template <class Context>
class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> { class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
public: public:
SparseSoftmaxCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws) SparseSoftmaxCrossEntropyGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) { normalization(OperatorBase::Arg<string>(
vector<int> args = OperatorBase::GetRepeatedArg<int>("ignore_labels"); "normalization", "VALID")) {
if (args.size()) { vector<int> ignores = OperatorBase::Args<int>("ignore_labels");
ignore.Reshape(vector<TIndex>(1, args.size())); if (ignores.size()) {
int* ignore_data = ignore.mutable_data<int, CPUContext>(); ignore.Reshape({ (TIndex)ignores.size() });
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i]; auto* Idata = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < ignores.size(); i++) Idata[i] = ignores[i];
} }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
...@@ -19,13 +19,14 @@ namespace dragon { ...@@ -19,13 +19,14 @@ namespace dragon {
template <class Context> template <class Context>
class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Context> { class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Context> {
public: public:
SparseSoftmaxFocalLossOp(const OperatorDef& op_def, Workspace* ws) SparseSoftmaxFocalLossOp(const OperatorDef& def, Workspace* ws)
: SparseSoftmaxCrossEntropyOp<Context>(op_def, ws), : SparseSoftmaxCrossEntropyOp<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")), normalization(OperatorBase::Arg<string>(
alpha(OperatorBase::GetSingleArg<float>("alpha", 0.5)), "normalization", "VALID")),
gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)), alpha(OperatorBase::Arg<float>("alpha", 0.5)),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) { gamma(OperatorBase::Arg<float>("gamma", 0.0)),
neg_id(OperatorBase::Arg<int>("neg_id", -1)) {
pos_alpha = alpha * 2.0; pos_alpha = alpha * 2.0;
neg_alpha = (1 - alpha) * 2.0; neg_alpha = (1 - alpha) * 2.0;
} }
...@@ -46,13 +47,14 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex ...@@ -46,13 +47,14 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
template <class Context> template <class Context>
class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyGradientOp<Context> { class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyGradientOp<Context> {
public: public:
SparseSoftmaxFocalLossGradientOp(const OperatorDef& op_def, Workspace* ws) SparseSoftmaxFocalLossGradientOp(const OperatorDef& def, Workspace* ws)
: SparseSoftmaxCrossEntropyGradientOp<Context>(op_def, ws), : SparseSoftmaxCrossEntropyGradientOp<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")), normalization(OperatorBase::Arg<string>(
gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)), "normalization", "VALID")),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))), gamma(OperatorBase::Arg<float>("gamma", 0.0)),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {} eps(OperatorBase::Arg<float>("eps", float(1e-10))),
neg_id(OperatorBase::Arg<int>("neg_id", -1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -17,17 +17,17 @@ ...@@ -17,17 +17,17 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class AccuracyOp final: public Operator<Context> { class AccuracyOp final : public Operator<Context> {
public: public:
AccuracyOp(const OperatorDef& op_def, Workspace* ws) AccuracyOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)), top_k(OperatorBase::Arg<int>("top_k", 1)),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) { axis(OperatorBase::Arg<int>("axis", 1)) {
vector<int> args = OperatorBase::GetRepeatedArg<int>("ignore_labels"); vector<int> ignores = OperatorBase::Args<int>("ignore_labels");
if (args.size()) { if (ignores.size()) {
ignore_labels.Reshape(vector<TIndex>(1, args.size())); ignore.Reshape({ (TIndex)ignores.size() });
int* ignore_data = ignore_labels.mutable_data<int, CPUContext>(); auto* Idata = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i]; for (int i = 0; i < ignores.size(); i++) Idata[i] = ignores[i];
} }
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -37,7 +37,7 @@ class AccuracyOp final: public Operator<Context> { ...@@ -37,7 +37,7 @@ class AccuracyOp final: public Operator<Context> {
protected: protected:
TIndex top_k, axis, outer_dim, inner_dim, num_classes; TIndex top_k, axis, outer_dim, inner_dim, num_classes;
Tensor ignore_labels; Tensor ignore;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -19,10 +19,10 @@ namespace dragon { ...@@ -19,10 +19,10 @@ namespace dragon {
template <class Context> template <class Context>
class AsTypeOp final : public Operator<Context> { class AsTypeOp final : public Operator<Context> {
public: public:
AsTypeOp(const OperatorDef& op_def, Workspace* ws) AsTypeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
dtype(OperatorBase::GetSingleArg<string>("dtype", "float32")), dtype(OperatorBase::Arg<string>("dtype", "float32")),
inplace(OperatorBase::GetSingleArg<bool>("inplace", false)) {} inplace(OperatorBase::Arg<bool>("inplace", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class GradientGenerateOp final: public Operator<Context> { class GradientGenerateOp final: public Operator<Context> {
public: public:
GradientGenerateOp(const OperatorDef& op_def, Workspace* ws) GradientGenerateOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
defaults(OperatorBase::GetRepeatedArg<float>("defaults")) { defaults(OperatorBase::Args<float>("defaults")) {
CHECK_EQ(InputSize(), OutputSize()); CHECK_EQ(InputSize(), OutputSize());
CHECK_EQ(defaults.size(), OutputSize()); CHECK_EQ(defaults.size(), OutputSize());
} }
...@@ -37,8 +37,8 @@ class GradientGenerateOp final: public Operator<Context> { ...@@ -37,8 +37,8 @@ class GradientGenerateOp final: public Operator<Context> {
template <class Context> template <class Context>
class GradientGatherOp final : public Operator<Context> { class GradientGatherOp final : public Operator<Context> {
public: public:
GradientGatherOp(const OperatorDef& op_def, Workspace* ws) GradientGatherOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws) { : Operator<Context>(def, ws) {
for (int i = 0; i < InputSize(); i++) for (int i = 0; i < InputSize(); i++)
if (Input(i).name() != "ignore") indices.push_back(i); if (Input(i).name() != "ignore") indices.push_back(i);
} }
......
...@@ -19,23 +19,23 @@ namespace dragon { ...@@ -19,23 +19,23 @@ namespace dragon {
template <class Context> template <class Context>
class ImageDataOp final : public Operator<Context> { class ImageDataOp final : public Operator<Context> {
public: public:
ImageDataOp(const OperatorDef& op_def, Workspace* ws) ImageDataOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")), dtype(OperatorBase::Arg<string>("dtype", "FLOAT32")),
mean_values(OperatorBase::GetRepeatedArg<float>("mean_values")), mean_values(OperatorBase::Args<float>("mean_values")),
std_values(OperatorBase::GetRepeatedArg<float>("std_values")), std_values(OperatorBase::Args<float>("std_values")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) { data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {
if (mean_values.size() > 0) { if (mean_values.size() > 0) {
CHECK_EQ((int)mean_values.size(), 3) CHECK_EQ((int)mean_values.size(), 3)
<< "The mean values should be a list with length 3."; << "The mean values should be a list with length 3.";
mean.Reshape(vector<TIndex>(1, 3)); mean.Reshape({ 3 });
for (int i = 0; i < 3; i++) for (int i = 0; i < 3; i++)
mean.mutable_data<float, CPUContext>()[i] = mean_values[i]; mean.mutable_data<float, CPUContext>()[i] = mean_values[i];
} }
if (std_values.size() > 0) { if (std_values.size() > 0) {
CHECK_EQ((int)std_values.size(), 3) CHECK_EQ((int)std_values.size(), 3)
<< "The std values should be a list with length 3."; << "The std values should be a list with length 3.";
std.Reshape(vector<TIndex>(1, 3)); std.Reshape({ 3 });
for (int i = 0; i < 3; i++) for (int i = 0; i < 3; i++)
std.mutable_data<float, CPUContext>()[i] = std_values[i]; std.mutable_data<float, CPUContext>()[i] = std_values[i];
} }
......
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class InitializeOp: public Operator<Context> { class InitializeOp : public Operator<Context> {
public: public:
InitializeOp(const OperatorDef& op_def, Workspace* ws) InitializeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) { shape_desc(OperatorBase::Arg<string>("shape", "")) {
GET_ARGUMENTS_WITH_DESC(int, dims); GET_ARGUMENTS_WITH_DESC(int, dims);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -38,11 +38,11 @@ class InitializeOp: public Operator<Context> { ...@@ -38,11 +38,11 @@ class InitializeOp: public Operator<Context> {
template <class Context> template <class Context>
class FillOp final : public InitializeOp<Context> { class FillOp final : public InitializeOp<Context> {
public: public:
FillOp(const OperatorDef& op_def, Workspace* ws) FillOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) { : InitializeOp<Context>(def, ws) {
this->filler.set_type("constant"); this->filler.set_type("constant");
this->filler.set_value(OperatorBase::GetSingleArg<float>("value", 0.0)); this->filler.set_value(OperatorBase::Arg<float>("value", 0.0));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
...@@ -50,11 +50,11 @@ public: ...@@ -50,11 +50,11 @@ public:
template <class Context> template <class Context>
class RandomUniformOp final : public InitializeOp<Context> { class RandomUniformOp final : public InitializeOp<Context> {
public: public:
RandomUniformOp(const OperatorDef& op_def, Workspace* ws) RandomUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) { : InitializeOp<Context>(def, ws) {
this->filler.set_type("uniform"); this->filler.set_type("uniform");
this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -1.0)); this->filler.set_low(OperatorBase::Arg<float>("low", -1.0));
this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 1.0)); this->filler.set_high(OperatorBase::Arg<float>("high", 1.0));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
...@@ -62,11 +62,11 @@ public: ...@@ -62,11 +62,11 @@ public:
template <class Context> template <class Context>
class RandomNormalOp final : public InitializeOp<Context> { class RandomNormalOp final : public InitializeOp<Context> {
public: public:
RandomNormalOp(const OperatorDef& op_def, Workspace* ws) RandomNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) { : InitializeOp<Context>(def, ws) {
this->filler.set_type("normal"); this->filler.set_type("normal");
this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0)); this->filler.set_mean(OperatorBase::Arg<float>("mean", 0.0));
this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0)); this->filler.set_std(OperatorBase::Arg<float>("std", 1.0));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
...@@ -74,11 +74,11 @@ public: ...@@ -74,11 +74,11 @@ public:
template <class Context> template <class Context>
class TruncatedNormalOp final : public InitializeOp<Context> { class TruncatedNormalOp final : public InitializeOp<Context> {
public: public:
TruncatedNormalOp(const OperatorDef& op_def, Workspace* ws) TruncatedNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) { : InitializeOp<Context>(def, ws) {
this->filler.set_type("truncated_normal"); this->filler.set_type("truncated_normal");
float mu = OperatorBase::GetSingleArg<float>("mean", 0.0); float mu = OperatorBase::Arg<float>("mean", 0.0);
float sigma = OperatorBase::GetSingleArg<float>("std", 1.0); float sigma = OperatorBase::Arg<float>("std", 1.0);
this->filler.set_mean(mu); this->filler.set_mean(mu);
this->filler.set_std(sigma); this->filler.set_std(sigma);
this->filler.set_low(mu - 2 * sigma); this->filler.set_low(mu - 2 * sigma);
...@@ -90,10 +90,10 @@ public: ...@@ -90,10 +90,10 @@ public:
template <class Context> template <class Context>
class GlorotUniformOp final : public InitializeOp<Context> { class GlorotUniformOp final : public InitializeOp<Context> {
public: public:
GlorotUniformOp(const OperatorDef& op_def, Workspace* ws) GlorotUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) { : InitializeOp<Context>(def, ws) {
string mode = OperatorBase::GetSingleArg<string>("mode", "fan_in"); string mode = OperatorBase::Arg<string>("mode", "fan_in");
float scale = OperatorBase::GetSingleArg<float>("scale", 3.0); float scale = OperatorBase::Arg<float>("scale", 3.0);
this->filler.set_type("xavier"); this->filler.set_type("xavier");
if (mode == "fan_avg") { if (mode == "fan_avg") {
...@@ -111,10 +111,10 @@ public: ...@@ -111,10 +111,10 @@ public:
template <class Context> template <class Context>
class GlorotNormalOp final : public InitializeOp<Context> { class GlorotNormalOp final : public InitializeOp<Context> {
public: public:
GlorotNormalOp(const OperatorDef& op_def, Workspace* ws) GlorotNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) { : InitializeOp<Context>(def, ws) {
string mode = OperatorBase::GetSingleArg<string>("mode", "fan_in"); string mode = OperatorBase::Arg<string>("mode", "fan_in");
float scale = OperatorBase::GetSingleArg<float>("scale", 2.0); float scale = OperatorBase::Arg<float>("scale", 2.0);
this->filler.set_type("msra"); this->filler.set_type("msra");
if (mode == "fan_avg") { if (mode == "fan_avg") {
......
...@@ -23,7 +23,7 @@ namespace dragon { ...@@ -23,7 +23,7 @@ namespace dragon {
template <class Context> template <class Context>
class RunOp : public Operator<Context> { class RunOp : public Operator<Context> {
public: public:
RunOp(const OperatorDef& op_def, Workspace* ws); RunOp(const OperatorDef& def, Workspace* ws);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -36,16 +36,16 @@ class RunOp : public Operator<Context> { ...@@ -36,16 +36,16 @@ class RunOp : public Operator<Context> {
template <class Context> template <class Context>
class TemplateOp : public RunOp<Context> { class TemplateOp : public RunOp<Context> {
public: public:
TemplateOp(const OperatorDef& op_def, Workspace* ws) TemplateOp(const OperatorDef& def, Workspace* ws)
: RunOp<Context>(op_def, ws) {} : RunOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
class TemplateGradientOp : public TemplateOp<Context> { class TemplateGradientOp final : public TemplateOp<Context> {
public: public:
TemplateGradientOp(const OperatorDef& op_def, Workspace* ws) TemplateGradientOp(const OperatorDef& def, Workspace* ws)
: TemplateOp<Context>(op_def, ws) {} : TemplateOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -22,11 +22,11 @@ namespace dragon { ...@@ -22,11 +22,11 @@ namespace dragon {
template <class Context> template <class Context>
class ModelMPIBase : public Operator<Context> { class ModelMPIBase : public Operator<Context> {
public: public:
ModelMPIBase(const OperatorDef& op_def, Workspace* ws) ModelMPIBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
comm((MPI_Comm)OperatorBase::GetSingleArg<int64_t>("comm", 0)), comm((MPI_Comm)OperatorBase::Arg<int64_t>("comm", 0)),
group((MPI_Group)OperatorBase::GetSingleArg<int64_t>("group", 0)), group((MPI_Group)OperatorBase::Arg<int64_t>("group", 0)),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")) { dtype(OperatorBase::Arg<string>("dtype", "FLOAT32")) {
if (comm == MPI_COMM_NULL) return; if (comm == MPI_COMM_NULL) return;
MPI_Comm_size(MPI_COMM_WORLD, &world_size); MPI_Comm_size(MPI_COMM_WORLD, &world_size);
...@@ -36,7 +36,7 @@ class ModelMPIBase : public Operator<Context> { ...@@ -36,7 +36,7 @@ class ModelMPIBase : public Operator<Context> {
MPI_Group world_group; MPI_Group world_group;
MPI_Comm_group(MPI_COMM_WORLD, &world_group); MPI_Comm_group(MPI_COMM_WORLD, &world_group);
int world_root = OperatorBase::GetSingleArg<int>("root", 0); int world_root = OperatorBase::Arg<int>("root", 0);
MPI_Group_translate_ranks(world_group, 1, &world_root, group, &comm_root); MPI_Group_translate_ranks(world_group, 1, &world_root, group, &comm_root);
CHECK(comm_root != MPI_UNDEFINED) << "MPI root is not included in layer group."; CHECK(comm_root != MPI_UNDEFINED) << "MPI root is not included in layer group.";
......
...@@ -21,8 +21,8 @@ namespace dragon { ...@@ -21,8 +21,8 @@ namespace dragon {
template <class Context> template <class Context>
class MPIBroadcastOp final : public ModelMPIBase<Context> { class MPIBroadcastOp final : public ModelMPIBase<Context> {
public: public:
MPIBroadcastOp(const OperatorDef& op_def, Workspace* ws) MPIBroadcastOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MPIMODEL_FUNCTIONS(Context);
...@@ -33,8 +33,8 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> { ...@@ -33,8 +33,8 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> {
template <class Context> template <class Context>
class MPIBroadcastGradientOp final : public ModelMPIBase<Context> { class MPIBroadcastGradientOp final : public ModelMPIBase<Context> {
public: public:
MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws) MPIBroadcastGradientOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MPIMODEL_FUNCTIONS(Context);
......
...@@ -21,8 +21,8 @@ namespace dragon { ...@@ -21,8 +21,8 @@ namespace dragon {
template <class Context> template <class Context>
class MPIGatherOp final : public ModelMPIBase<Context> { class MPIGatherOp final : public ModelMPIBase<Context> {
public: public:
MPIGatherOp(const OperatorDef& op_def, Workspace *ws) MPIGatherOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MPIMODEL_FUNCTIONS(Context);
...@@ -33,8 +33,8 @@ class MPIGatherOp final : public ModelMPIBase<Context> { ...@@ -33,8 +33,8 @@ class MPIGatherOp final : public ModelMPIBase<Context> {
template <class Context> template <class Context>
class MPIGatherGradientOp final : public ModelMPIBase<Context> { class MPIGatherGradientOp final : public ModelMPIBase<Context> {
public: public:
MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws) MPIGatherGradientOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MPIMODEL_FUNCTIONS(Context);
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class ArangeOp final : public Operator<Context> { class ArangeOp final : public Operator<Context> {
public: public:
ArangeOp(const OperatorDef& op_def, Workspace* ws) ArangeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")) { dtype(OperatorBase::Arg<string>("dtype", "FLOAT32")) {
GET_ARGUMENT_WITH_DESC(int, start, 0); GET_ARGUMENT_WITH_DESC(int, start, 0);
GET_ARGUMENT_WITH_DESC(int, stop, 0); GET_ARGUMENT_WITH_DESC(int, stop, 0);
GET_ARGUMENT_WITH_DESC(int, step, 1); GET_ARGUMENT_WITH_DESC(int, step, 1);
......
...@@ -19,12 +19,12 @@ namespace dragon { ...@@ -19,12 +19,12 @@ namespace dragon {
template <class Context> template <class Context>
class ArgReduceOp final : public Operator<Context> { class ArgReduceOp final : public Operator<Context> {
public: public:
ArgReduceOp(const OperatorDef& op_def, Workspace* ws) ArgReduceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")), operation(OperatorBase::Arg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)), keep_dims(OperatorBase::Arg<bool>("keep_dims", false)),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {} top_k(OperatorBase::Arg<int>("top_k", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,16 +19,17 @@ namespace dragon { ...@@ -19,16 +19,17 @@ namespace dragon {
template <class Context> template <class Context>
class ConcatOp : public Operator<Context> { class ConcatOp : public Operator<Context> {
public: public:
ConcatOp(const OperatorDef& op_def, Workspace* ws) ConcatOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
TIndex axis, outer_dim, inner_dim, x_concat_dim, y_concat_dim; TIndex axis, outer_dim, inner_dim;
TIndex x_concat_dim, y_concat_dim;
TIndex x_offset, y_offset, concat_offset; TIndex x_offset, y_offset, concat_offset;
vector<TIndex> concat_dims; vector<TIndex> concat_dims;
}; };
...@@ -36,16 +37,17 @@ class ConcatOp : public Operator<Context> { ...@@ -36,16 +37,17 @@ class ConcatOp : public Operator<Context> {
template <class Context> template <class Context>
class ConcatGradientOp : public Operator<Context> { class ConcatGradientOp : public Operator<Context> {
public: public:
ConcatGradientOp(const OperatorDef& op_def, Workspace* ws) ConcatGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
TIndex axis, outer_dim, inner_dim, x_concat_dim, y_concat_dim; TIndex axis, outer_dim, inner_dim;
TIndex x_concat_dim, y_concat_dim;
TIndex x_offset, y_offset, concat_offset; TIndex x_offset, y_offset, concat_offset;
vector<TIndex> concat_dims; vector<TIndex> concat_dims;
}; };
......
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class CropOp: public Operator<Context> { class CropOp final : public Operator<Context> {
public: public:
CropOp(const OperatorDef& op_def, Workspace* ws) CropOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)), start_axis(OperatorBase::Arg<int>("start_axis", -1)),
offsets(OperatorBase::GetRepeatedArg<int>("offsets")), offsets(OperatorBase::Args<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")), shape(OperatorBase::Args<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) { shape_like(OperatorBase::Arg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, starts); GET_ARGUMENTS_WITH_DESC(int, starts);
GET_ARGUMENTS_WITH_DESC(int, ends); GET_ARGUMENTS_WITH_DESC(int, ends);
} }
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class ExpandDimsOp final : public Operator<Context> { class ExpandDimsOp final : public Operator<Context> {
public: public:
ExpandDimsOp(const OperatorDef& op_def, Workspace* ws) ExpandDimsOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {} axis(OperatorBase::Arg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,11 +19,11 @@ namespace dragon { ...@@ -19,11 +19,11 @@ namespace dragon {
template <class Context> template <class Context>
class FlattenOp final : public Operator<Context> { class FlattenOp final : public Operator<Context> {
public: public:
FlattenOp(const OperatorDef& op_def, Workspace* ws) FlattenOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::Arg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)), num_axes(OperatorBase::Arg<int>("num_axes", -1)),
keep_axes(OperatorBase::GetSingleArg<int>("keep_axes", INT_MAX)) {} keep_axes(OperatorBase::Arg<int>("keep_axes", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class GatherOp final : public Operator<Context> { class GatherOp final : public Operator<Context> {
public: public:
GatherOp(const OperatorDef& op_def, Workspace* ws) GatherOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {} axis(OperatorBase::Arg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -35,10 +35,10 @@ class GatherOp final : public Operator<Context> { ...@@ -35,10 +35,10 @@ class GatherOp final : public Operator<Context> {
template <class Context> template <class Context>
class GatherGradientOp final : public Operator<Context> { class GatherGradientOp final : public Operator<Context> {
public: public:
GatherGradientOp(const OperatorDef& op_def, Workspace* ws) GatherGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::Arg<int>("axis", 0)),
acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {} acc_grad(OperatorBase::Arg<bool>("acc_gradient", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,11 +19,11 @@ namespace dragon { ...@@ -19,11 +19,11 @@ namespace dragon {
template <class Context> template <class Context>
class OneHotOp final : public Operator < Context > { class OneHotOp final : public Operator < Context > {
public: public:
OneHotOp(const OperatorDef& op_def, Workspace* ws) OneHotOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
depth(OperatorBase::GetSingleArg<int>("depth", -1)), depth(OperatorBase::Arg<int>("depth", -1)),
on_value(OperatorBase::GetSingleArg<int>("on_value", 1)), on_value(OperatorBase::Arg<int>("on_value", 1)),
off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {} off_value(OperatorBase::Arg<int>("off_value", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,12 +19,12 @@ namespace dragon { ...@@ -19,12 +19,12 @@ namespace dragon {
template <class Context> template <class Context>
class PadOp final : public Operator<Context> { class PadOp final : public Operator<Context> {
public: public:
PadOp(const OperatorDef& op_def, Workspace* ws) PadOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
pad_l(OperatorBase::GetRepeatedArg<int>("pad_l")), pad_l(OperatorBase::Args<int>("pad_l")),
pad_r(OperatorBase::GetRepeatedArg<int>("pad_r")), pad_r(OperatorBase::Args<int>("pad_r")),
mode(OperatorBase::GetSingleArg<string>("mode", "CONSTANT")), mode(OperatorBase::Arg<string>("mode", "CONSTANT")),
value(OperatorBase::GetSingleArg<float>("value", 0.0f)) { value(OperatorBase::Arg<float>("value", 0.0f)) {
if (pad_r.size() == 0) pad_r = pad_l; if (pad_r.size() == 0) pad_r = pad_l;
else CHECK_EQ(pad_l.size(), pad_r.size()) else CHECK_EQ(pad_l.size(), pad_r.size())
<< "The pad_l and pad_r should have the same length."; << "The pad_l and pad_r should have the same length.";
...@@ -54,11 +54,11 @@ class PadOp final : public Operator<Context> { ...@@ -54,11 +54,11 @@ class PadOp final : public Operator<Context> {
template <class Context> template <class Context>
class PadGradientOp final : public Operator<Context> { class PadGradientOp final : public Operator<Context> {
public: public:
PadGradientOp(const OperatorDef& op_def, Workspace* ws) PadGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
pad_l(OperatorBase::GetRepeatedArg<int>("pad_l")), pad_l(OperatorBase::Args<int>("pad_l")),
pad_r(OperatorBase::GetRepeatedArg<int>("pad_r")), pad_r(OperatorBase::Args<int>("pad_r")),
mode(OperatorBase::GetSingleArg<string>("mode", "CONSTANT")) { mode(OperatorBase::Arg<string>("mode", "CONSTANT")) {
if (pad_r.size() == 0) pad_r = pad_l; if (pad_r.size() == 0) pad_r = pad_l;
else CHECK_EQ(pad_l.size(), pad_r.size()) else CHECK_EQ(pad_l.size(), pad_r.size())
<< "The pad_l and pad_r should have the same length."; << "The pad_l and pad_r should have the same length.";
......
...@@ -17,20 +17,20 @@ ...@@ -17,20 +17,20 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class RandomPickOp : public Operator<Context> { class RandomPickOp final : public Operator<Context> {
public: public:
RandomPickOp(const OperatorDef& op_def, Workspace* ws) : RandomPickOp(const OperatorDef& def, Workspace* ws) :
Operator<Context>(op_def, ws), Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::Arg<int>("axis", 0)),
max_samples(OperatorBase::GetSingleArg<int>("max_samples", 1)) {} max_samples(OperatorBase::Arg<int>("max_samples", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
TIndex axis, max_samples; TIndex axis, outer_dim, inner_dim, max_samples;
TIndex outer_dim, inner_dim, x_slice_dim, y_slice_dim; TIndex x_slice_dim, y_slice_dim;
vector<TIndex> output_dims; vector<TIndex> output_dims;
Tensor* pick_indices; Tensor* pick_indices;
}; };
...@@ -38,17 +38,17 @@ class RandomPickOp : public Operator<Context> { ...@@ -38,17 +38,17 @@ class RandomPickOp : public Operator<Context> {
template <class Context> template <class Context>
class RandomPickGradientOp final : public Operator<Context> { class RandomPickGradientOp final : public Operator<Context> {
public: public:
RandomPickGradientOp(const OperatorDef& op_def, Workspace* ws) RandomPickGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {} axis(OperatorBase::Arg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
TIndex axis; TIndex axis, outer_dim, inner_dim;
TIndex outer_dim, inner_dim, x_slice_dim, y_slice_dim; TIndex x_slice_dim, y_slice_dim;
Tensor* pick_indices; Tensor* pick_indices;
}; };
......
...@@ -19,11 +19,11 @@ namespace dragon { ...@@ -19,11 +19,11 @@ namespace dragon {
template <class Context> template <class Context>
class ReduceOp final : public Operator<Context> { class ReduceOp final : public Operator<Context> {
public: public:
ReduceOp(const OperatorDef& op_def, Workspace* ws) ReduceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")), operation(OperatorBase::Arg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)) {} keep_dims(OperatorBase::Arg<bool>("keep_dims", false)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -38,10 +38,10 @@ class ReduceOp final : public Operator<Context> { ...@@ -38,10 +38,10 @@ class ReduceOp final : public Operator<Context> {
template <class Context> template <class Context>
class ReduceGradientOp final : public Operator<Context> { class ReduceGradientOp final : public Operator<Context> {
public: public:
ReduceGradientOp(const OperatorDef& op_def, Workspace* ws) ReduceGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {} operation(OperatorBase::Arg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class RepeatOp : public Operator<Context> { class RepeatOp final : public Operator<Context> {
public: public:
RepeatOp(const OperatorDef& op_def, Workspace* ws) RepeatOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) { axis(OperatorBase::Arg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1); GET_ARGUMENT_WITH_DESC(int, repeats, 1);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -35,11 +35,11 @@ class RepeatOp : public Operator<Context> { ...@@ -35,11 +35,11 @@ class RepeatOp : public Operator<Context> {
}; };
template <class Context> template <class Context>
class RepeatGradientOp : public Operator<Context> { class RepeatGradientOp final : public Operator<Context> {
public: public:
RepeatGradientOp(const OperatorDef& op_def, Workspace* ws) RepeatGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) { axis(OperatorBase::Arg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1); GET_ARGUMENT_WITH_DESC(int, repeats, 1);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class ReshapeOp final : public Operator<Context> { class ReshapeOp final : public Operator<Context> {
public: public:
ReshapeOp(const OperatorDef& op_def, Workspace* ws) ReshapeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) { shape_like_desc(OperatorBase::Arg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, shape); GET_ARGUMENTS_WITH_DESC(int, shape);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SliceOp : public Operator<Context> { class SliceOp final : public Operator<Context> {
public: public:
SliceOp(const OperatorDef& op_def, Workspace* ws): SliceOp(const OperatorDef& def, Workspace* ws)
Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {} nout(OperatorBase::Arg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -38,10 +38,10 @@ class SliceOp : public Operator<Context> { ...@@ -38,10 +38,10 @@ class SliceOp : public Operator<Context> {
template <class Context> template <class Context>
class SliceGradientOp final : public Operator<Context> { class SliceGradientOp final : public Operator<Context> {
public: public:
SliceGradientOp(const OperatorDef& op_def, Workspace* ws): SliceGradientOp(const OperatorDef& def, Workspace* ws)
Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::Arg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {} nout(OperatorBase::Arg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -17,37 +17,37 @@ ...@@ -17,37 +17,37 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class StackOp : public Operator<Context> { class StackOp final : public Operator<Context> {
public: public:
StackOp(const OperatorDef& op_def, Workspace* ws) StackOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::Arg<int>("axis", 0)) {}
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
TIndex axis, nin, outer_dim, inner_dim, x_concat_dim, y_concat_dim; TIndex axis, outer_dim, inner_dim;
TIndex x_concat_dim, y_concat_dim;
TIndex x_offset, y_offset, concat_offset; TIndex x_offset, y_offset, concat_offset;
vector<TIndex> stack_dims, concat_dims; vector<TIndex> stack_dims, concat_dims;
}; };
template <class Context> template <class Context>
class StackGradientOp : public Operator<Context> { class StackGradientOp final : public Operator<Context> {
public: public:
StackGradientOp(const OperatorDef& op_def, Workspace* ws) StackGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::Arg<int>("axis", 0)) {}
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
TIndex axis, nin, outer_dim, inner_dim, x_concat_dim, y_concat_dim; TIndex axis, outer_dim, inner_dim;
TIndex x_concat_dim, y_concat_dim;
TIndex x_offset, y_offset, concat_offset; TIndex x_offset, y_offset, concat_offset;
vector<TIndex> concat_dims; vector<TIndex> concat_dims;
}; };
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class TileOp : public Operator<Context> { class TileOp final : public Operator<Context> {
public: public:
TileOp(const OperatorDef& op_def, Workspace* ws) TileOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws) { : Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples); GET_ARGUMENTS_WITH_DESC(int, multiples);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -35,10 +35,10 @@ class TileOp : public Operator<Context> { ...@@ -35,10 +35,10 @@ class TileOp : public Operator<Context> {
}; };
template <class Context> template <class Context>
class TileGradientOp : public Operator<Context> { class TileGradientOp final : public Operator<Context> {
public: public:
TileGradientOp(const OperatorDef& op_def, Workspace* ws) TileGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws) { : Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples); GET_ARGUMENTS_WITH_DESC(int, multiples);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
...@@ -19,11 +19,9 @@ namespace dragon { ...@@ -19,11 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class TransposeOp final: public Operator<Context> { class TransposeOp final: public Operator<Context> {
public: public:
TransposeOp(const OperatorDef& op_def, Workspace* ws) TransposeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws) {
perms(OperatorBase::GetRepeatedArg<int>("perms")) { GET_ARGUMENTS_WITH_DESC(int, perms);
if (perms.size() > 0) reverse_dims = false;
else reverse_dims = true;
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -31,17 +29,17 @@ class TransposeOp final: public Operator<Context> { ...@@ -31,17 +29,17 @@ class TransposeOp final: public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
vector<int> perms; DECLARE_ARGUMENTS_WITH_DESC(int, perms);
bool reverse_dims;
Tensor* order, *old_steps, *new_steps; Tensor* order, *old_steps, *new_steps;
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, TransposeOp, perms);
template <class Context> template <class Context>
class TransposeGradientOp final : public Operator<Context> { class TransposeGradientOp final : public Operator<Context> {
public: public:
TransposeGradientOp(const OperatorDef& op_def, Workspace* ws) TransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,15 +19,15 @@ ...@@ -19,15 +19,15 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class BatchNormOp : public Operator<Context> { class BatchNormOp final : public Operator<Context> {
public: public:
BatchNormOp(const OperatorDef& op_def, Workspace* ws) BatchNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", 0.9f)), momentum(OperatorBase::Arg<float>("momentum", 0.9f)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)), eps(OperatorBase::Arg<float>("eps", 1e-3f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)), use_stats(OperatorBase::Arg<int>("use_stats", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) { mode(OperatorBase::Arg<string>("mode", "DEFAULT")) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
...@@ -51,10 +51,10 @@ class BatchNormOp : public Operator<Context> { ...@@ -51,10 +51,10 @@ class BatchNormOp : public Operator<Context> {
template <class Context> template <class Context>
class BatchNormGradientOp final : public Operator<Context> { class BatchNormGradientOp final : public Operator<Context> {
public: public:
BatchNormGradientOp(const OperatorDef& op_def, Workspace *ws) BatchNormGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { use_stats(OperatorBase::Arg<int>("use_stats", -1)) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
...@@ -77,12 +77,12 @@ class BatchNormGradientOp final : public Operator<Context> { ...@@ -77,12 +77,12 @@ class BatchNormGradientOp final : public Operator<Context> {
template <class Context> template <class Context>
class FusedBatchNormOp : public Operator<Context> { class FusedBatchNormOp : public Operator<Context> {
public: public:
FusedBatchNormOp(const OperatorDef& op_def, Workspace* ws) FusedBatchNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", 0.9f)), momentum(OperatorBase::Arg<float>("momentum", 0.9f)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)), eps(OperatorBase::Arg<float>("eps", 1e-3f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::Arg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -102,11 +102,11 @@ class FusedBatchNormOp : public Operator<Context> { ...@@ -102,11 +102,11 @@ class FusedBatchNormOp : public Operator<Context> {
template <class Context> template <class Context>
class FusedBatchNormGradientOp : public Operator<Context> { class FusedBatchNormGradientOp : public Operator<Context> {
public: public:
FusedBatchNormGradientOp(const OperatorDef& op_def, Workspace* ws) FusedBatchNormGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)), eps(OperatorBase::Arg<float>("eps", 1e-3f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::Arg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -132,9 +132,9 @@ class FusedBatchNormGradientOp : public Operator<Context> { ...@@ -132,9 +132,9 @@ class FusedBatchNormGradientOp : public Operator<Context> {
template <class Context> template <class Context>
class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> { class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
public: public:
CuDNNBatchNormOp(const OperatorDef& op_def, Workspace* ws) CuDNNBatchNormOp(const OperatorDef& def, Workspace* ws)
: FusedBatchNormOp<Context>(op_def, ws), : FusedBatchNormOp<Context>(def, ws),
eps64(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) { eps64(OperatorBase::Arg<float>("eps", 1e-3f)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
...@@ -169,9 +169,9 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> { ...@@ -169,9 +169,9 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
template <class Context> template <class Context>
class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context> { class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context> {
public: public:
CuDNNBatchNormGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNBatchNormGradientOp(const OperatorDef& def, Workspace* ws)
: FusedBatchNormGradientOp<Context>(op_def, ws), : FusedBatchNormGradientOp<Context>(def, ws),
eps64(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) { eps64(OperatorBase::Arg<float>("eps", 1e-3f)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
......
...@@ -17,19 +17,19 @@ ...@@ -17,19 +17,19 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class BatchRenormOp : public Operator<Context> { class BatchRenormOp final : public Operator<Context> {
public: public:
BatchRenormOp(const OperatorDef& op_def, Workspace* ws) BatchRenormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", 0.9f)), momentum(OperatorBase::Arg<float>("momentum", 0.9f)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)), eps(OperatorBase::Arg<float>("eps", 1e-3f)),
r_max(OperatorBase::GetSingleArg<float>("r_max", 3.f)), r_max(OperatorBase::Arg<float>("r_max", 3.f)),
d_max(OperatorBase::GetSingleArg<float>("d_max", 5.f)), d_max(OperatorBase::Arg<float>("d_max", 5.f)),
t_delta(OperatorBase::GetSingleArg<float>("t_delta", 1.f)), t_delta(OperatorBase::Arg<float>("t_delta", 1.f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)), use_stats(OperatorBase::Arg<int>("use_stats", -1)),
t_r_max(1.f), t_d_max(0.f), t_val(0.f), t_r_max(1.f), t_d_max(0.f), t_val(0.f),
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) { mode(OperatorBase::Arg<string>("mode", "DEFAULT")) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
...@@ -55,10 +55,10 @@ class BatchRenormOp : public Operator<Context> { ...@@ -55,10 +55,10 @@ class BatchRenormOp : public Operator<Context> {
template <class Context> template <class Context>
class BatchRenormGradientOp final : public Operator<Context> { class BatchRenormGradientOp final : public Operator<Context> {
public: public:
BatchRenormGradientOp(const OperatorDef& op_def, Workspace *ws) BatchRenormGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { use_stats(OperatorBase::Arg<int>("use_stats", -1)) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class GroupNormOp : public Operator<Context> { class GroupNormOp final : public Operator<Context> {
public: public:
GroupNormOp(const OperatorDef& op_def, Workspace* ws) GroupNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)), group(OperatorBase::Arg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) { eps(OperatorBase::Arg<float>("eps", 1e-3f)) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
...@@ -45,10 +45,10 @@ class GroupNormOp : public Operator<Context> { ...@@ -45,10 +45,10 @@ class GroupNormOp : public Operator<Context> {
template <class Context> template <class Context>
class GroupNormGradientOp final : public Operator<Context> { class GroupNormGradientOp final : public Operator<Context> {
public: public:
GroupNormGradientOp(const OperatorDef& op_def, Workspace *ws) GroupNormGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)), group(OperatorBase::Arg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) { axis(OperatorBase::Arg<int>("axis", -1)) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
...@@ -67,13 +67,13 @@ class GroupNormGradientOp final : public Operator<Context> { ...@@ -67,13 +67,13 @@ class GroupNormGradientOp final : public Operator<Context> {
}; };
template <class Context> template <class Context>
class FusedGroupNormOp : public Operator<Context> { class FusedGroupNormOp final : public Operator<Context> {
public: public:
FusedGroupNormOp(const OperatorDef& op_def, Workspace* ws) FusedGroupNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)), group(OperatorBase::Arg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {} eps(OperatorBase::Arg<float>("eps", 1e-3f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -89,12 +89,12 @@ class FusedGroupNormOp : public Operator<Context> { ...@@ -89,12 +89,12 @@ class FusedGroupNormOp : public Operator<Context> {
}; };
template <class Context> template <class Context>
class FusedGroupNormGradientOp : public Operator<Context> { class FusedGroupNormGradientOp final : public Operator<Context> {
public: public:
FusedGroupNormGradientOp(const OperatorDef& op_def, Workspace* ws) FusedGroupNormGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)), group(OperatorBase::Arg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {} axis(OperatorBase::Arg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class InstanceNormOp : public Operator<Context> { class InstanceNormOp final : public Operator<Context> {
public: public:
InstanceNormOp(const OperatorDef& op_def, Workspace* ws) InstanceNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::Arg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) { eps(OperatorBase::Arg<float>("eps", 1e-3f)) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) << "\nThe axis can only be set to 1."; CHECK_EQ(axis, 1) << "\nThe axis can only be set to 1.";
} }
...@@ -43,9 +43,9 @@ class InstanceNormOp : public Operator<Context> { ...@@ -43,9 +43,9 @@ class InstanceNormOp : public Operator<Context> {
template <class Context> template <class Context>
class InstanceNormGradientOp final : public Operator<Context> { class InstanceNormGradientOp final : public Operator<Context> {
public: public:
InstanceNormGradientOp(const OperatorDef& op_def, Workspace *ws) InstanceNormGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) { axis(OperatorBase::Arg<int>("axis", -1)) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) << "\nThe axis can only be set to 1."; CHECK_EQ(axis, 1) << "\nThe axis can only be set to 1.";
} }
......
...@@ -19,12 +19,12 @@ namespace dragon { ...@@ -19,12 +19,12 @@ namespace dragon {
template <class Context> template <class Context>
class L2NormOp final : public Operator<Context> { class L2NormOp final : public Operator<Context> {
public: public:
L2NormOp(const OperatorDef& op_def, Workspace* ws) L2NormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::Arg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)), num_axes(OperatorBase::Arg<int>("num_axes", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-5f)), eps(OperatorBase::Arg<float>("eps", 1e-5f)),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {} mode(OperatorBase::Arg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -42,11 +42,11 @@ class L2NormOp final : public Operator<Context> { ...@@ -42,11 +42,11 @@ class L2NormOp final : public Operator<Context> {
template <class Context> template <class Context>
class L2NormGradientOp final : public Operator<Context> { class L2NormGradientOp final : public Operator<Context> {
public: public:
L2NormGradientOp(const OperatorDef& op_def, Workspace* ws) L2NormGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::Arg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)), num_axes(OperatorBase::Arg<int>("num_axes", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {} mode(OperatorBase::Arg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -49,29 +49,29 @@ class cudnnTensorDescriptors { ...@@ -49,29 +49,29 @@ class cudnnTensorDescriptors {
template <class Context> template <class Context>
class CuDNNRecurrentOpBase : public Operator<Context> { class CuDNNRecurrentOpBase : public Operator<Context> {
public: public:
CuDNNRecurrentOpBase(const OperatorDef& op_def, Workspace* ws) CuDNNRecurrentOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), states_initialized(false), : Operator<Context>(def, ws), states_initialized(false),
hidden_size(OperatorBase::GetSingleArg<int>("hidden_size", 0)), hidden_size(OperatorBase::Arg<int>("hidden_size", 0)),
num_layers(OperatorBase::GetSingleArg<int>("num_layers", 1)), num_layers(OperatorBase::Arg<int>("num_layers", 1)),
bidirectional(OperatorBase::GetSingleArg<bool>("bidirectional", false)), bidirectional(OperatorBase::Arg<bool>("bidirectional", false)),
dropout_ratio(OperatorBase::GetSingleArg<float>("dropout_ratio", 1.0)), dropout_ratio(OperatorBase::Arg<float>("dropout_ratio", 1.0)),
random_seed(op_def.device_option().random_seed()) { random_seed(def.device_option().random_seed()) {
// determine the rnn direction // determine the rnn direction
rnn_direction = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; rnn_direction = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
// determine the rnn mode // determine the rnn mode
const string mode = OperatorBase::GetSingleArg<string>("rnn_mode", ""); const string mode = OperatorBase::Arg<string>("rnn_mode", "");
if (mode == "rnn_tanh") rnn_mode = CUDNN_RNN_TANH; if (mode == "rnn_tanh") rnn_mode = CUDNN_RNN_TANH;
else if (mode == "rnn_relu") rnn_mode = CUDNN_RNN_RELU; else if (mode == "rnn_relu") rnn_mode = CUDNN_RNN_RELU;
else if (mode == "lstm") rnn_mode = CUDNN_LSTM; else if (mode == "lstm") rnn_mode = CUDNN_LSTM;
else if (mode == "gru") rnn_mode = CUDNN_GRU; else if (mode == "gru") rnn_mode = CUDNN_GRU;
else LOG(FATAL) << "Unsupported rnn mode: " << mode; else LOG(FATAL) << "Unsupported rnn mode: " << mode;
// determine the rnn input mode // determine the rnn input mode
const string input_mode = OperatorBase::GetSingleArg<string>("rnn_input_mode", "linear"); const string input_mode = OperatorBase::Arg<string>("rnn_input_mode", "linear");
if (input_mode == "skip") rnn_input_mode = CUDNN_SKIP_INPUT; if (input_mode == "skip") rnn_input_mode = CUDNN_SKIP_INPUT;
else if (input_mode == "linear") rnn_input_mode = CUDNN_LINEAR_INPUT; else if (input_mode == "linear") rnn_input_mode = CUDNN_LINEAR_INPUT;
else LOG(FATAL) << "Unsupported rnn input mode: " << input_mode; else LOG(FATAL) << "Unsupported rnn input mode: " << input_mode;
// override the running phase // override the running phase
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", "")); SwitchToPhase(OperatorBase::Arg<string>("phase", ""));
CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc)); CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc));
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc)); CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
CUDNN_CHECK(cudnnCreateFilterDescriptor(&w_desc)); CUDNN_CHECK(cudnnCreateFilterDescriptor(&w_desc));
...@@ -92,8 +92,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> { ...@@ -92,8 +92,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc));
} }
template <typename T> void ResetDesc(Tensor* X, Tensor* Hx, Tensor* Cx, template <typename T> void ResetDesc();
Tensor* Y, Tensor* Hy, Tensor* Cy);
public: public:
TIndex hidden_size, num_layers; TIndex hidden_size, num_layers;
...@@ -109,7 +108,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> { ...@@ -109,7 +108,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
cudnnFilterDescriptor_t w_desc; cudnnFilterDescriptor_t w_desc;
cudnnTensorDescriptor_t hx_desc, cx_desc; cudnnTensorDescriptor_t hx_desc, cx_desc;
cudnnTensorDescriptor_t hy_desc, cy_desc; cudnnTensorDescriptor_t hy_desc, cy_desc;
vector<TIndex> input_dims; vector<TIndex> input_dims, output_dims, hidden_dims;
size_t workspace_size, reserve_size, states_size; size_t workspace_size, reserve_size, states_size;
std::unique_ptr<cudnnTensorDescriptors> xs_desc; std::unique_ptr<cudnnTensorDescriptors> xs_desc;
...@@ -128,14 +127,16 @@ class CuDNNRecurrentOpBase : public Operator<Context> { ...@@ -128,14 +127,16 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
using CuDNNRecurrentOpBase<Context>::xs_desc; \ using CuDNNRecurrentOpBase<Context>::xs_desc; \
using CuDNNRecurrentOpBase<Context>::ys_desc; \ using CuDNNRecurrentOpBase<Context>::ys_desc; \
using CuDNNRecurrentOpBase<Context>::input_dims; \ using CuDNNRecurrentOpBase<Context>::input_dims; \
using CuDNNRecurrentOpBase<Context>::output_dims; \
using CuDNNRecurrentOpBase<Context>::hidden_dims; \
using CuDNNRecurrentOpBase<Context>::workspace_size; \ using CuDNNRecurrentOpBase<Context>::workspace_size; \
using CuDNNRecurrentOpBase<Context>::reserve_size using CuDNNRecurrentOpBase<Context>::reserve_size
template <class Context> template <class Context>
class CuDNNRecurrentOp : public CuDNNRecurrentOpBase<Context> { class CuDNNRecurrentOp final : public CuDNNRecurrentOpBase<Context> {
public: public:
CuDNNRecurrentOp(const OperatorDef& op_def, Workspace* ws) CuDNNRecurrentOp(const OperatorDef& def, Workspace* ws)
: CuDNNRecurrentOpBase<Context>(op_def, ws) {} : CuDNNRecurrentOpBase<Context>(def, ws) {}
USE_CUDNN_RECURRENT_FUNCTIONS; USE_CUDNN_RECURRENT_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -143,10 +144,10 @@ class CuDNNRecurrentOp : public CuDNNRecurrentOpBase<Context> { ...@@ -143,10 +144,10 @@ class CuDNNRecurrentOp : public CuDNNRecurrentOpBase<Context> {
}; };
template <class Context> template <class Context>
class CuDNNRecurrentGradientOp : public CuDNNRecurrentOpBase<Context> { class CuDNNRecurrentGradientOp final : public CuDNNRecurrentOpBase<Context> {
public: public:
CuDNNRecurrentGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNRecurrentGradientOp(const OperatorDef& def, Workspace* ws)
: CuDNNRecurrentOpBase<Context>(op_def, ws) {} : CuDNNRecurrentOpBase<Context>(def, ws) {}
USE_CUDNN_RECURRENT_FUNCTIONS; USE_CUDNN_RECURRENT_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class LSTMCellOp : public Operator<Context> { class LSTMCellOp final : public Operator<Context> {
public: public:
LSTMCellOp(const OperatorDef& op_def, Workspace* ws) LSTMCellOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -29,10 +29,10 @@ class LSTMCellOp : public Operator<Context> { ...@@ -29,10 +29,10 @@ class LSTMCellOp : public Operator<Context> {
}; };
template <class Context> template <class Context>
class LSTMCellGradientOp : public Operator<Context> { class LSTMCellGradientOp final : public Operator<Context> {
public: public:
LSTMCellGradientOp(const OperatorDef& op_def, Workspace* ws) LSTMCellGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -19,8 +19,8 @@ namespace dragon { ...@@ -19,8 +19,8 @@ namespace dragon {
template <class Context> template <class Context>
class RecurrentOp : public Operator<Context> { class RecurrentOp : public Operator<Context> {
public: public:
RecurrentOp(const OperatorDef& op_def, Workspace* ws) RecurrentOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws) { : Operator<Context>(def, ws) {
LOG(FATAL) << "RNN Operators require CuDNN support."; LOG(FATAL) << "RNN Operators require CuDNN support.";
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -31,8 +31,8 @@ class RecurrentOp : public Operator<Context> { ...@@ -31,8 +31,8 @@ class RecurrentOp : public Operator<Context> {
template <class Context> template <class Context>
class RecurrentGradientOp : public Operator<Context> { class RecurrentGradientOp : public Operator<Context> {
public: public:
RecurrentGradientOp(const OperatorDef& op_def, Workspace* ws) RecurrentGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws) { : Operator<Context>(def, ws) {
LOG(FATAL) << "RNN Operators require CuDNN support."; LOG(FATAL) << "RNN Operators require CuDNN support.";
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
...@@ -17,18 +17,18 @@ ...@@ -17,18 +17,18 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class RNNParamSetOp : public Operator<Context> { class RNNParamSetOp final : public Operator<Context> {
public: public:
RNNParamSetOp(const OperatorDef& op_def, Workspace* ws) RNNParamSetOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
param_type(OperatorBase::GetSingleArg<string>("param_type", "matrix")), param_type(OperatorBase::Arg<string>("param_type", "matrix")),
rnn_mode(OperatorBase::GetSingleArg<string>("rnn_mode", "rnn_tanh")), rnn_mode(OperatorBase::Arg<string>("rnn_mode", "rnn_tanh")),
num_layers(OperatorBase::GetSingleArg<int>("num_layers", 1)), num_layers(OperatorBase::Arg<int>("num_layers", 1)),
num_directions(OperatorBase::GetSingleArg<int>("num_directions", 1)), num_directions(OperatorBase::Arg<int>("num_directions", 1)),
input_size(OperatorBase::GetSingleArg<int>("input_size", 0)), input_size(OperatorBase::Arg<int>("input_size", 0)),
hidden_size(OperatorBase::GetSingleArg<int>("hidden_size", 0)), hidden_size(OperatorBase::Arg<int>("hidden_size", 0)),
layer_id(OperatorBase::GetSingleArg<int>("layer_id", 0)), layer_id(OperatorBase::Arg<int>("layer_id", 0)),
param_id(OperatorBase::GetSingleArg<int>("param_id", 0)) { param_id(OperatorBase::Arg<int>("param_id", 0)) {
if (rnn_mode == "rnn_tanh") { num_params = 2; spliter = 1; } if (rnn_mode == "rnn_tanh") { num_params = 2; spliter = 1; }
else if (rnn_mode == "rnn_relu") { num_params = 2; spliter = 1; } else if (rnn_mode == "rnn_relu") { num_params = 2; spliter = 1; }
else if (rnn_mode == "lstm") { num_params = 8; spliter = 4; } else if (rnn_mode == "lstm") { num_params = 8; spliter = 4; }
......
...@@ -19,8 +19,8 @@ namespace dragon { ...@@ -19,8 +19,8 @@ namespace dragon {
template <class Context> template <class Context>
class AdamUpdateOp final : public UpdateOpBase<Context> { class AdamUpdateOp final : public UpdateOpBase<Context> {
public: public:
AdamUpdateOp(const OperatorDef& op_def, Workspace* ws) AdamUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws), t(0) {} : UpdateOpBase<Context>(def, ws), t(0) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
......
...@@ -19,16 +19,26 @@ namespace dragon { ...@@ -19,16 +19,26 @@ namespace dragon {
#ifdef WITH_MPI #ifdef WITH_MPI
template <class Context> template <class Context>
class CollectiveUpdateOp : public Operator<Context> { class CollectiveUpdateOp final : public Operator<Context> {
public: public:
CollectiveUpdateOp(const OperatorDef& op_def, Workspace* ws) CollectiveUpdateOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
mode(OperatorBase::GetSingleArg<string>("mode", "UNKNOWN")) { mode(OperatorBase::Arg<string>("mode", "UNKNOWN")) {
InitMPI(); InitMPI();
if (mode.find("NCCL") != string::npos) InitNCCL(); if (mode.find("NCCL") != string::npos) InitNCCL();
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
~CollectiveUpdateOp() {
/* TODO(PhyscalX): Temporarily disable it,
to avoid a unhandled error. */
#ifdef WITH_MPI_NCCL
if (mode.find("NCCL") != string::npos) {
/* ncclCommDestroy(nccl_comm); */
}
#endif
}
void InitMPI(); void InitMPI();
void InitNCCL(); void InitNCCL();
...@@ -48,7 +58,7 @@ class CollectiveUpdateOp : public Operator<Context> { ...@@ -48,7 +58,7 @@ class CollectiveUpdateOp : public Operator<Context> {
#ifdef WITH_MPI_NCCL #ifdef WITH_MPI_NCCL
ncclComm_t nccl_comm; ncclComm_t nccl_comm;
cudaStream_t stream; CUDAClosure<Context> closure;
#endif #endif
}; };
......
...@@ -19,9 +19,9 @@ namespace dragon { ...@@ -19,9 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class MovingAverageOp final : public Operator<Context> { class MovingAverageOp final : public Operator<Context> {
public: public:
MovingAverageOp(const OperatorDef& op_def, Workspace* ws) MovingAverageOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
decay(OperatorBase::GetSingleArg<float>("decay", 1.0)) {} decay(OperatorBase::Arg<float>("decay", 1.0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -29,7 +29,6 @@ class MovingAverageOp final : public Operator<Context> { ...@@ -29,7 +29,6 @@ class MovingAverageOp final : public Operator<Context> {
protected: protected:
float decay; float decay;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -19,8 +19,8 @@ namespace dragon { ...@@ -19,8 +19,8 @@ namespace dragon {
template <class Context> template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> { class NesterovUpdateOp final : public UpdateOpBase<Context> {
public: public:
NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws) NesterovUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws) {} : UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
......
...@@ -19,8 +19,8 @@ namespace dragon { ...@@ -19,8 +19,8 @@ namespace dragon {
template <class Context> template <class Context>
class RMSPropUpdateOp final : public UpdateOpBase<Context> { class RMSPropUpdateOp final : public UpdateOpBase<Context> {
public: public:
RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws) RMSPropUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws) {} : UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
......
...@@ -19,8 +19,8 @@ namespace dragon { ...@@ -19,8 +19,8 @@ namespace dragon {
template <class Context> template <class Context>
class SGDUpdateOp final : public UpdateOpBase<Context> { class SGDUpdateOp final : public UpdateOpBase<Context> {
public: public:
SGDUpdateOp(const OperatorDef& op_def, Workspace* ws) SGDUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws), : UpdateOpBase<Context>(def, ws),
old_lr(-1.f), correction(1.f) {} old_lr(-1.f), correction(1.f) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
......
...@@ -19,12 +19,12 @@ namespace dragon { ...@@ -19,12 +19,12 @@ namespace dragon {
template <class Context> template <class Context>
class UpdateOpBase : public Operator<Context> { class UpdateOpBase : public Operator<Context> {
public: public:
UpdateOpBase(const OperatorDef& op_def, Workspace* ws) UpdateOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
lr_mult(OperatorBase::GetSingleArg<float>("lr_mult", 1.0)), lr_mult(OperatorBase::Arg<float>("lr_mult", 1.0)),
decay_mult(OperatorBase::GetSingleArg<float>("decay_mult", 1.0)), decay_mult(OperatorBase::Arg<float>("decay_mult", 1.0)),
slot(OperatorBase::GetSingleArg<string>("slot", "")), slot(OperatorBase::Arg<string>("slot", "")),
zero_grad(OperatorBase::GetSingleArg<bool>("zero_grad", true)) { zero_grad(OperatorBase::Arg<bool>("zero_grad", true)) {
CHECK(!slot.empty()) << "\nRequired a non-empty slot"; CHECK(!slot.empty()) << "\nRequired a non-empty slot";
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -34,8 +34,13 @@ class UpdateOpBase : public Operator<Context> { ...@@ -34,8 +34,13 @@ class UpdateOpBase : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void PreprocessRunWithType(); template <typename T> void PreprocessRunWithType();
virtual void ComputeRunWithFloat() = 0; virtual void ComputeRunWithFloat() = 0;
virtual void ComputeRunWithFloat16() { LOG(FATAL) << "This Updater does not support FP16."; }
virtual void ComputeRunWithFloat16() {
LOG(FATAL) << "This Updater does not support FP16.";
}
template <typename T> void UpdateRunWithType(); template <typename T> void UpdateRunWithType();
protected: protected:
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class BiasAddOp : public Operator<Context> { class BiasAddOp final : public Operator<Context> {
public: public:
BiasAddOp(const OperatorDef& op_def, Workspace* ws) BiasAddOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -35,9 +35,9 @@ class BiasAddOp : public Operator<Context> { ...@@ -35,9 +35,9 @@ class BiasAddOp : public Operator<Context> {
template <class Context> template <class Context>
class BiasAddGradientOp final : public Operator<Context> { class BiasAddGradientOp final : public Operator<Context> {
public: public:
BiasAddGradientOp(const OperatorDef& op_def, Workspace* ws) BiasAddGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class BilinearResizeOp : public Operator<Context> { class BilinearResizeOp final : public Operator<Context> {
public: public:
BilinearResizeOp(const OperatorDef& op_def, Workspace* ws) BilinearResizeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)), fy(OperatorBase::Arg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)), fx(OperatorBase::Arg<float>("fx", -1.0)),
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")), shape_like_desc(OperatorBase::Arg<string>("shape_like", "")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) { data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {
GET_ARGUMENTS_WITH_DESC(int, dsize); GET_ARGUMENTS_WITH_DESC(int, dsize);
if (data_format == "NCHW") spatial_axis = 2; if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
...@@ -43,11 +43,11 @@ class BilinearResizeOp : public Operator<Context> { ...@@ -43,11 +43,11 @@ class BilinearResizeOp : public Operator<Context> {
}; };
template <class Context> template <class Context>
class BilinearResizeGradientOp : public Operator<Context> { class BilinearResizeGradientOp final : public Operator<Context> {
public: public:
BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws) BilinearResizeGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -53,10 +53,11 @@ class Conv2dGradientOp : public Conv2dOp<Context> { ...@@ -53,10 +53,11 @@ class Conv2dGradientOp : public Conv2dOp<Context> {
#include "utils/cudnn_device.h" #include "utils/cudnn_device.h"
template <class Context> template <class Context>
class CuDNNConv2dOp : public Conv2dOp<Context> { class CuDNNConv2dOp final : public Conv2dOp<Context> {
public: public:
CuDNNConv2dOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws), enable_tensor_core(true) { : Conv2dOp<Context>(def, ws),
enable_tensor_core(true) {
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_group = 1; cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE(); enable_tensor_core &= TENSOR_CORE_AVAILABLE();
...@@ -64,14 +65,6 @@ class CuDNNConv2dOp : public Conv2dOp<Context> { ...@@ -64,14 +65,6 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
cudnn_group = group; cudnn_group = group;
enable_tensor_core = false; enable_tensor_core = false;
#endif #endif
handle = new cudnnHandle_t[cudnn_group];
stream = new cudaStream_t[cudnn_group];
ctx().SwitchToDevice();
for (int g = 0; g < cudnn_group; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
}
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc)); CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
...@@ -90,10 +83,6 @@ class CuDNNConv2dOp : public Conv2dOp<Context> { ...@@ -90,10 +83,6 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc)); CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc)); if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
for (int g = 0; g < cudnn_group; g++) {
cudaStreamDestroy(stream[g]);
CUDNN_CHECK(cudnnDestroy(handle[g]));
}
} }
void RunOnDevice() override; void RunOnDevice() override;
...@@ -101,8 +90,6 @@ class CuDNNConv2dOp : public Conv2dOp<Context> { ...@@ -101,8 +90,6 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnDataType_t compute_type; cudnnDataType_t compute_type;
cudnnTensorFormat_t format; cudnnTensorFormat_t format;
cudnnConvolutionFwdAlgo_t fwd_algo; cudnnConvolutionFwdAlgo_t fwd_algo;
...@@ -116,10 +103,11 @@ class CuDNNConv2dOp : public Conv2dOp<Context> { ...@@ -116,10 +103,11 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
}; };
template <class Context> template <class Context>
class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> { class CuDNNConv2dGradientOp final : public Conv2dGradientOp<Context> {
public: public:
CuDNNConv2dGradientOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dGradientOp<Context>(def, ws), enable_tensor_core(true) { : Conv2dGradientOp<Context>(def, ws),
enable_tensor_core(true) {
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_group = 1; cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE(); enable_tensor_core &= TENSOR_CORE_AVAILABLE();
...@@ -127,13 +115,6 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> { ...@@ -127,13 +115,6 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
cudnn_group = group; cudnn_group = group;
enable_tensor_core = false; enable_tensor_core = false;
#endif #endif
handle = new cudnnHandle_t[cudnn_group * 3];
stream = new cudaStream_t[cudnn_group * 3];
for (int g = 0; g < cudnn_group * 3; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
}
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc)); CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
...@@ -152,10 +133,6 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> { ...@@ -152,10 +133,6 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc)); CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc)); if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
for (int g = 0; g < cudnn_group * 3; g++) {
cudaStreamDestroy(stream[g]);
CUDNN_CHECK(cudnnDestroy(handle[g]));
}
} }
void RunOnDevice() override; void RunOnDevice() override;
...@@ -163,8 +140,6 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> { ...@@ -163,8 +140,6 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnDataType_t compute_type; cudnnDataType_t compute_type;
cudnnTensorFormat_t format; cudnnTensorFormat_t format;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo; cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
......
...@@ -21,14 +21,14 @@ namespace dragon { ...@@ -21,14 +21,14 @@ namespace dragon {
template <class Context> template <class Context>
class ConvOpBase : public Operator<Context> { class ConvOpBase : public Operator<Context> {
public: public:
ConvOpBase(const OperatorDef& op_def, Workspace* ws) ConvOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")), data_format(OperatorBase::Arg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")), padding(OperatorBase::Arg<string>("padding", "VALID")),
num_output(OperatorBase::GetSingleArg<int>("num_output", 1)), num_output(OperatorBase::Arg<int>("num_output", 1)),
group(OperatorBase::GetSingleArg<int>("group", 1)) { group(OperatorBase::Arg<int>("group", 1)) {
output_dims_value = OperatorBase::GetRepeatedArg<int>("output_shape"); output_dims_value = OperatorBase::Args<int>("output_shape");
output_dims_desc = OperatorBase::GetRepeatedArg<string>("output_shape_desc"); output_dims_desc = OperatorBase::Args<string>("output_shape_desc");
if (data_format == "NCHW") spatial_axis = 2; if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class Conv2dTransposeOp: public ConvOpBase<Context> { class Conv2dTransposeOp : public ConvOpBase<Context> {
public: public:
Conv2dTransposeOp(const OperatorDef& def, Workspace* ws) Conv2dTransposeOp(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) { : ConvOpBase<Context>(def, ws) {
...@@ -57,7 +57,7 @@ class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> { ...@@ -57,7 +57,7 @@ class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> {
#include "utils/cudnn_device.h" #include "utils/cudnn_device.h"
template <class Context> template <class Context>
class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { class CuDNNConv2dTransposeOp final : public Conv2dTransposeOp<Context> {
public: public:
CuDNNConv2dTransposeOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dTransposeOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeOp<Context>(def, ws), enable_tensor_core(true) { : Conv2dTransposeOp<Context>(def, ws), enable_tensor_core(true) {
...@@ -68,13 +68,6 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { ...@@ -68,13 +68,6 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
cudnn_group = group; cudnn_group = group;
enable_tensor_core = false; enable_tensor_core = false;
#endif #endif
handle = new cudnnHandle_t[cudnn_group];
stream = new cudaStream_t[cudnn_group];
for (int g = 0; g < cudnn_group; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
}
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc)); CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
...@@ -93,10 +86,6 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { ...@@ -93,10 +86,6 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc)); CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc)); if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
for (int g = 0; g < cudnn_group; g++) {
cudaStreamDestroy(stream[g]);
CUDNN_CHECK(cudnnDestroy(handle[g]));
}
} }
void RunOnDevice() override; void RunOnDevice() override;
...@@ -104,8 +93,6 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { ...@@ -104,8 +93,6 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnDataType_t compute_type; cudnnDataType_t compute_type;
cudnnTensorFormat_t format; cudnnTensorFormat_t format;
cudnnConvolutionBwdDataAlgo_t fwd_algo; cudnnConvolutionBwdDataAlgo_t fwd_algo;
...@@ -119,7 +106,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { ...@@ -119,7 +106,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
}; };
template <class Context> template <class Context>
class CuDNNConv2dTransposeGradientOp : public Conv2dTransposeGradientOp<Context> { class CuDNNConv2dTransposeGradientOp final : public Conv2dTransposeGradientOp<Context> {
public: public:
CuDNNConv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeGradientOp<Context>(def, ws), enable_tensor_core(true) { : Conv2dTransposeGradientOp<Context>(def, ws), enable_tensor_core(true) {
...@@ -130,13 +117,6 @@ public: ...@@ -130,13 +117,6 @@ public:
cudnn_group = group; cudnn_group = group;
enable_tensor_core = false; enable_tensor_core = false;
#endif #endif
handle = new cudnnHandle_t[cudnn_group * 3];
stream = new cudaStream_t[cudnn_group * 3];
for (int g = 0; g < cudnn_group * 3; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
}
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc)); CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
...@@ -155,10 +135,6 @@ public: ...@@ -155,10 +135,6 @@ public:
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc)); CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc)); if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
for (int g = 0; g < cudnn_group * 3; g++) {
cudaStreamDestroy(stream[g]);
CUDNN_CHECK(cudnnDestroy(handle[g]));
}
} }
void RunOnDevice() override; void RunOnDevice() override;
...@@ -166,8 +142,6 @@ public: ...@@ -166,8 +142,6 @@ public:
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnDataType_t compute_type; cudnnDataType_t compute_type;
cudnnTensorFormat_t format; cudnnTensorFormat_t format;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo; cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
......
...@@ -19,17 +19,17 @@ namespace dragon { ...@@ -19,17 +19,17 @@ namespace dragon {
template <class Context> template <class Context>
class DenseConcatOp final : public ConcatOp<Context> { class DenseConcatOp final : public ConcatOp<Context> {
public: public:
DenseConcatOp(const OperatorDef& op_def, Workspace* ws) DenseConcatOp(const OperatorDef& def, Workspace* ws)
: ConcatOp<Context>(op_def, ws) {} : ConcatOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
class DenseConcatGradientOp : public ConcatGradientOp<Context> { class DenseConcatGradientOp final : public ConcatGradientOp<Context> {
public: public:
DenseConcatGradientOp(const OperatorDef& op_def, Workspace* ws) DenseConcatGradientOp(const OperatorDef& def, Workspace* ws)
: ConcatGradientOp<Context>(op_def, ws), : ConcatGradientOp<Context>(def, ws),
growth_rate(OperatorBase::GetSingleArg<int>("growth_rate", 0)) {} growth_rate(OperatorBase::Arg<int>("growth_rate", 0)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void ElimateCorruption() override; void ElimateCorruption() override;
......
...@@ -21,14 +21,14 @@ enum LRNMode { ACROSS_CHANNELS, WITHIN_CHANNEL }; ...@@ -21,14 +21,14 @@ enum LRNMode { ACROSS_CHANNELS, WITHIN_CHANNEL };
template <class Context> template <class Context>
class LRNOp : public Operator<Context> { class LRNOp : public Operator<Context> {
public: public:
LRNOp(const OperatorDef& op_def, Workspace* ws) LRNOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
local_size(OperatorBase::GetSingleArg<int>("local_size", 5)), local_size(OperatorBase::Arg<int>("local_size", 5)),
alpha(OperatorBase::GetSingleArg<float>("alpha", float(0.0001))), alpha(OperatorBase::Arg<float>("alpha", float(0.0001))),
beta(OperatorBase::GetSingleArg<float>("beta", float(0.75))), beta(OperatorBase::Arg<float>("beta", float(0.75))),
k(OperatorBase::GetSingleArg<float>("k", float(2.0))), k(OperatorBase::Arg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")), mode(OperatorBase::Arg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -52,14 +52,14 @@ class LRNOp : public Operator<Context> { ...@@ -52,14 +52,14 @@ class LRNOp : public Operator<Context> {
template <class Context> template <class Context>
class LRNGradientOp : public Operator<Context> { class LRNGradientOp : public Operator<Context> {
public: public:
LRNGradientOp(const OperatorDef& op_def, Workspace* ws) LRNGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
local_size(OperatorBase::GetSingleArg<int>("local_size", 5)), local_size(OperatorBase::Arg<int>("local_size", 5)),
alpha(OperatorBase::GetSingleArg<float>("alpha", float(0.0001))), alpha(OperatorBase::Arg<float>("alpha", float(0.0001))),
beta(OperatorBase::GetSingleArg<float>("beta", float(0.75))), beta(OperatorBase::Arg<float>("beta", float(0.75))),
k(OperatorBase::GetSingleArg<float>("k", float(2.0))), k(OperatorBase::Arg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")), mode(OperatorBase::Arg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -85,17 +85,15 @@ class LRNGradientOp : public Operator<Context> { ...@@ -85,17 +85,15 @@ class LRNGradientOp : public Operator<Context> {
#include "utils/cudnn_device.h" #include "utils/cudnn_device.h"
template <class Context> template <class Context>
class CuDNNLRNOp : public LRNOp<Context> { class CuDNNLRNOp final : public LRNOp<Context> {
public: public:
CuDNNLRNOp(const OperatorDef& op_def, Workspace* ws) CuDNNLRNOp(const OperatorDef& def, Workspace* ws)
: LRNOp<Context>(op_def, ws) { : LRNOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc)); CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc));
CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc, this->local_size, CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc,
this->alpha, this->local_size, this->alpha, this->beta, this->k));
this->beta,
this->k));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -114,17 +112,15 @@ class CuDNNLRNOp : public LRNOp<Context> { ...@@ -114,17 +112,15 @@ class CuDNNLRNOp : public LRNOp<Context> {
}; };
template <class Context> template <class Context>
class CuDNNLRNGradientOp : public LRNGradientOp<Context > { class CuDNNLRNGradientOp final : public LRNGradientOp<Context > {
public: public:
CuDNNLRNGradientOp(const OperatorDef& op_def, Workspace* ws) : CuDNNLRNGradientOp(const OperatorDef& def, Workspace* ws)
LRNGradientOp<Context>(op_def, ws) { : LRNGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc)); CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc));
CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc, this->local_size, CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc,
this->alpha, this->local_size, this->alpha, this->beta, this->k));
this->beta,
this->k));
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
......
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class NNResizeOp : public Operator<Context> { class NNResizeOp final : public Operator<Context> {
public: public:
NNResizeOp(const OperatorDef& op_def, Workspace* ws) NNResizeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)), fy(OperatorBase::Arg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)), fx(OperatorBase::Arg<float>("fx", -1.0)),
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")), shape_like_desc(OperatorBase::Arg<string>("shape_like", "")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) { data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {
GET_ARGUMENTS_WITH_DESC(int, dsize); GET_ARGUMENTS_WITH_DESC(int, dsize);
if (data_format == "NCHW") spatial_axis = 2; if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
...@@ -43,11 +43,11 @@ class NNResizeOp : public Operator<Context> { ...@@ -43,11 +43,11 @@ class NNResizeOp : public Operator<Context> {
}; };
template <class Context> template <class Context>
class NNResizeGradientOp : public Operator<Context> { class NNResizeGradientOp final : public Operator<Context> {
public: public:
NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws) NNResizeGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -17,18 +17,18 @@ ...@@ -17,18 +17,18 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class Pooling2dOp: public Operator <Context> { class Pooling2dOp : public Operator<Context> {
public: public:
Pooling2dOp(const OperatorDef& op_def, Workspace* ws) Pooling2dOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
mode(OperatorBase::GetSingleArg<string>("mode", "MAX")), mode(OperatorBase::Arg<string>("mode", "MAX")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")), data_format(OperatorBase::Arg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")), padding(OperatorBase::Arg<string>("padding", "VALID")),
global_pooling(OperatorBase::GetSingleArg<bool>("global_pooling", false)), global_pooling(OperatorBase::Arg<bool>("global_pooling", false)),
ceil_mode(OperatorBase::GetSingleArg<bool>("ceil", true)) { ceil_mode(OperatorBase::Arg<bool>("ceil", true)) {
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size"); vector<int> ks = OperatorBase::Args<int>("kernel_size");
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride"); vector<int> s = OperatorBase::Args<int>("stride");
vector<int> p = OperatorBase::GetRepeatedArg<int>("pad"); vector<int> p = OperatorBase::Args<int>("pad");
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
if (global_pooling) { if (global_pooling) {
kernel_size.push_back(-1); kernel_size.push_back(-1);
...@@ -57,18 +57,18 @@ class Pooling2dOp: public Operator <Context> { ...@@ -57,18 +57,18 @@ class Pooling2dOp: public Operator <Context> {
}; };
template <class Context> template <class Context>
class Pooling2dGradientOp: public Operator<Context> { class Pooling2dGradientOp : public Operator<Context> {
public: public:
Pooling2dGradientOp(const OperatorDef& op_def, Workspace* ws) Pooling2dGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
mode(OperatorBase::GetSingleArg<string>("mode", "MAX")), mode(OperatorBase::Arg<string>("mode", "MAX")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")), data_format(OperatorBase::Arg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")), padding(OperatorBase::Arg<string>("padding", "VALID")),
global_pooling(OperatorBase::GetSingleArg<bool>("global_pooling", false)), global_pooling(OperatorBase::Arg<bool>("global_pooling", false)),
ceil_mode(OperatorBase::GetSingleArg<bool>("ceil", true)) { ceil_mode(OperatorBase::Arg<bool>("ceil", true)) {
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size"); vector<int> ks = OperatorBase::Args<int>("kernel_size");
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride"); vector<int> s = OperatorBase::Args<int>("stride");
vector<int> p = OperatorBase::GetRepeatedArg<int>("pad"); vector<int> p = OperatorBase::Args<int>("pad");
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
if (global_pooling) { if (global_pooling) {
kernel_size.push_back(-1); kernel_size.push_back(-1);
...@@ -101,8 +101,8 @@ class Pooling2dGradientOp: public Operator<Context> { ...@@ -101,8 +101,8 @@ class Pooling2dGradientOp: public Operator<Context> {
template <class Context> template <class Context>
class CuDNNPooling2dOp final : public Pooling2dOp<Context> { class CuDNNPooling2dOp final : public Pooling2dOp<Context> {
public: public:
CuDNNPooling2dOp(const OperatorDef& op_def, Workspace* ws) CuDNNPooling2dOp(const OperatorDef& def, Workspace* ws)
: Pooling2dOp<Context>(op_def, ws) { : Pooling2dOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc)); CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc));
...@@ -136,8 +136,8 @@ class CuDNNPooling2dOp final : public Pooling2dOp<Context> { ...@@ -136,8 +136,8 @@ class CuDNNPooling2dOp final : public Pooling2dOp<Context> {
template <class Context> template <class Context>
class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> { class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> {
public: public:
CuDNNPooling2dGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNPooling2dGradientOp(const OperatorDef& def, Workspace* ws)
: Pooling2dGradientOp<Context>(op_def, ws) { : Pooling2dGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc)); CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc));
......
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class ROIAlignOp : public Operator<Context> { class ROIAlignOp final : public Operator<Context> {
public: public:
ROIAlignOp(const OperatorDef& op_def, Workspace *ws) ROIAlignOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)), pool_h(OperatorBase::Arg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)), pool_w(OperatorBase::Arg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)), spatial_scale(OperatorBase::Arg<float>("spatial_scale", 1.0)),
sampling_ratio(OperatorBase::GetSingleArg<int>("sampling_ratio", 2)) { sampling_ratio(OperatorBase::Arg<int>("sampling_ratio", 2)) {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
...@@ -39,14 +39,14 @@ class ROIAlignOp : public Operator<Context> { ...@@ -39,14 +39,14 @@ class ROIAlignOp : public Operator<Context> {
}; };
template <class Context> template <class Context>
class ROIAlignGradientOp : public Operator<Context> { class ROIAlignGradientOp final : public Operator<Context> {
public: public:
ROIAlignGradientOp(const OperatorDef& op_def, Workspace *ws) ROIAlignGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)), pool_h(OperatorBase::Arg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)), pool_w(OperatorBase::Arg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)), spatial_scale(OperatorBase::Arg<float>("spatial_scale", 1.f)),
sampling_ratio(OperatorBase::GetSingleArg<int>("sampling_ratio", 2)) { sampling_ratio(OperatorBase::Arg<int>("sampling_ratio", 2)) {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class ROIPoolingOp : public Operator<Context> { class ROIPoolingOp final : public Operator<Context> {
public: public:
ROIPoolingOp(const OperatorDef& op_def, Workspace *ws) ROIPoolingOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)), pool_h(OperatorBase::Arg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)), pool_w(OperatorBase::Arg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) { spatial_scale(OperatorBase::Arg<float>("spatial_scale", 1.0)) {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
...@@ -40,11 +40,11 @@ class ROIPoolingOp : public Operator<Context> { ...@@ -40,11 +40,11 @@ class ROIPoolingOp : public Operator<Context> {
template <class Context> template <class Context>
class ROIPoolingGradientOp final : public Operator<Context> { class ROIPoolingGradientOp final : public Operator<Context> {
public: public:
ROIPoolingGradientOp(const OperatorDef& op_def, Workspace* ws) ROIPoolingGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)), pool_h(OperatorBase::Arg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)), pool_w(OperatorBase::Arg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {} spatial_scale(OperatorBase::Arg<float>("spatial_scale", 1.f)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -33,19 +33,26 @@ using google::protobuf::io::ZeroCopyInputStream; ...@@ -33,19 +33,26 @@ using google::protobuf::io::ZeroCopyInputStream;
using google::protobuf::io::CodedInputStream; using google::protobuf::io::CodedInputStream;
using google::protobuf::io::FileInputStream; using google::protobuf::io::FileInputStream;
inline void WriteProtoToBinaryFile(const Message& proto, const char* filename) { inline void WriteProtoToBinaryFile(
std::fstream output(filename, std::ios::out | std::ios::trunc | std::ios::binary); const Message& proto,
const char* filename) {
std::fstream output(filename,
std::ios::out |
std::ios::trunc |
std::ios::binary);
proto.SerializeToOstream(&output); proto.SerializeToOstream(&output);
} }
inline bool ReadProtoFromBinaryFile(const char* filename, Message* proto) { inline bool ReadProtoFromBinaryFile(
const char* filename,
Message* proto) {
#ifdef _MSC_VER #ifdef _MSC_VER
int fd = _open(filename, O_RDONLY | O_BINARY); int fd = _open(filename, O_RDONLY | O_BINARY);
#else #else
int fd = open(filename, O_RDONLY); int fd = open(filename, O_RDONLY);
#endif #endif
ZeroCopyInputStream *raw_input = new FileInputStream(fd); ZeroCopyInputStream* raw_input = new FileInputStream(fd);
CodedInputStream *coded_input = new CodedInputStream(raw_input); CodedInputStream* coded_input = new CodedInputStream(raw_input);
coded_input->SetTotalBytesLimit(INT_MAX, -1); coded_input->SetTotalBytesLimit(INT_MAX, -1);
bool success = proto->ParseFromCodedStream(coded_input); bool success = proto->ParseFromCodedStream(coded_input);
delete raw_input; delete raw_input;
...@@ -54,7 +61,9 @@ inline bool ReadProtoFromBinaryFile(const char* filename, Message* proto) { ...@@ -54,7 +61,9 @@ inline bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
return success; return success;
} }
inline void LoadCaffeModel(string file, Workspace* ws) { inline void LoadCaffeModel(
string file,
Workspace* ws) {
NetParameter net_param; NetParameter net_param;
ReadProtoFromBinaryFile(file.c_str(), &net_param); ReadProtoFromBinaryFile(file.c_str(), &net_param);
LOG(INFO) << "Restore From Model @: " << file << "......"; LOG(INFO) << "Restore From Model @: " << file << "......";
...@@ -73,36 +82,40 @@ inline void LoadCaffeModel(string file, Workspace* ws) { ...@@ -73,36 +82,40 @@ inline void LoadCaffeModel(string file, Workspace* ws) {
vector<TIndex> dims; vector<TIndex> dims;
for (auto dim : blob.shape().dim()) dims.push_back(dim); for (auto dim : blob.shape().dim()) dims.push_back(dim);
Tensor* tensor = ws->GetTensor(tensor_name); Tensor* tensor = ws->GetTensor(tensor_name);
std::stringstream dim_string; std::stringstream DimString;
if (dims.size() > 0) { if (dims.size() > 0) {
tensor->Reshape(dims); tensor->Reshape(dims);
CHECK_EQ(tensor->count(), blob.data_size()) CHECK_EQ(tensor->count(), blob.data_size())
<< "\nTensor(" << tensor_name << ") " << "\nTensor(" << tensor_name << ") "
<< "failed to load, except size: " << "failed to load, except size: "
<< tensor->count() << ", loaded: " << blob.data_size(); << tensor->count()
dim_string << tensor->dim_string(); << ", loaded: " << blob.data_size();
DimString << tensor->DimString();
} else { } else {
tensor->Reshape(vector<TIndex>(1, blob.data_size())); tensor->Reshape({ blob.data_size() });
dim_string << "(missing)"; DimString << "(missing)";
} }
float* Xdata = tensor->mutable_data<float, CPUContext>(); float* Xdata = tensor->mutable_data<float, CPUContext>();
for (int idx = 0; idx < blob.data_size(); idx++) for (int idx = 0; idx < blob.data_size(); idx++)
Xdata[idx] = blob.data(idx); Xdata[idx] = blob.data(idx);
LOG(INFO) << "Tensor(" << tensor_name << ") " LOG(INFO) << "Tensor(" << tensor_name << ") "
<< "loaded, shape: " << dim_string.str() << "loaded, shape: " << DimString.str()
<< ", size: " << blob.data_size(); << ", size: " << blob.data_size();
} }
} }
} }
} }
inline void SavaCaffeModel(string file, const vector<Tensor*>& tensors) { inline void SavaCaffeModel(
string file,
const vector<Tensor*>& tensors) {
NetParameter net_param; NetParameter net_param;
Map<string, int> layer_hash; Map<string, int> layer_hash;
int layer_idx = -1; int layer_idx = -1;
for (int i = 0; i < tensors.size(); i++) { for (int i = 0; i < tensors.size(); i++) {
if (tensors[i]->count() <= 0) continue; if (tensors[i]->count() <= 0) continue;
vector<string> splits = SplitString(tensors[i]->name(), "/param:"); vector<string> splits = SplitString(
tensors[i]->name(), "/param:");
if (layer_hash.count(splits[0]) == 0) { if (layer_hash.count(splits[0]) == 0) {
layer_hash[splits[0]] = ++layer_idx; layer_hash[splits[0]] = ++layer_idx;
LayerParameter* layer = net_param.add_layer(); LayerParameter* layer = net_param.add_layer();
...@@ -110,11 +123,14 @@ inline void SavaCaffeModel(string file, const vector<Tensor*>& tensors) { ...@@ -110,11 +123,14 @@ inline void SavaCaffeModel(string file, const vector<Tensor*>& tensors) {
} }
BlobProto* blob = net_param.mutable_layer(layer_idx)->add_blobs(); BlobProto* blob = net_param.mutable_layer(layer_idx)->add_blobs();
for (auto dim : tensors[i]->dims()) blob->mutable_shape()->add_dim(dim); for (auto dim : tensors[i]->dims()) blob->mutable_shape()->add_dim(dim);
const float* Xdata = tensors[i]->data < float, CPUContext >(); const float* Xdata = tensors[i]->data<float, CPUContext>();
for (int id = 0; id < tensors[i]->count(); id++) for (int id = 0; id < tensors[i]->count(); id++)
blob->mutable_data()->Add(Xdata[id]); blob->mutable_data()->Add(Xdata[id]);
} }
std::fstream output(file, std::ios::out | std::ios::trunc | std::ios::binary); std::fstream output(file,
std::ios::out |
std::ios::trunc |
std::ios::binary);
CHECK(net_param.SerializeToOstream(&output)); CHECK(net_param.SerializeToOstream(&output));
LOG(INFO) << "Save the model @: " << file << "......"; LOG(INFO) << "Save the model @: " << file << "......";
LOG(INFO) << "Model format: caffemodel"; LOG(INFO) << "Model format: caffemodel";
......
...@@ -26,9 +26,13 @@ template<> inline int dragon_cast<int, float>(float val) { ...@@ -26,9 +26,13 @@ template<> inline int dragon_cast<int, float>(float val) {
return static_cast<int>(val); return static_cast<int>(val);
} }
template<> inline float dragon_cast<float, float>(float val) { return val; } template<> inline float dragon_cast<float, float>(float val) {
return val;
}
template<> inline float16 dragon_cast<float16, float16>(float16 val) { return val; } template<> inline float16 dragon_cast<float16, float16>(float16 val) {
return val;
}
template<> inline float16 dragon_cast<float16, float>(float val) { template<> inline float16 dragon_cast<float16, float>(float val) {
float16 ret; float16 ret;
...@@ -67,7 +71,8 @@ template<> inline float16 dragon_cast<float16, float>(float val) { ...@@ -67,7 +71,8 @@ template<> inline float16 dragon_cast<float16, float>(float val) {
// Round to nearest even. // Round to nearest even.
remainder = (mantissa & lsb_m1); remainder = (mantissa & lsb_m1);
mantissa >>= shift; mantissa >>= shift;
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { if (remainder > lsb_s1 ||
(remainder == lsb_s1 && (mantissa & 0x1))) {
++mantissa; ++mantissa;
if (!(mantissa & 0x3ff)) { if (!(mantissa & 0x3ff)) {
++exponent; ++exponent;
......
...@@ -29,9 +29,9 @@ namespace dragon { ...@@ -29,9 +29,9 @@ namespace dragon {
#ifdef WITH_CUDA #ifdef WITH_CUDA
static const int CUDA_NUM_THREADS = 1024; static const int CUDA_THREADS = 1024;
// We do have a server with 10 GPUs :-) // We do have a server with 10 GPUs :-)
#define MAX_GPUS 10 #define CUDA_MAX_DEVICES 10
#define CUDA_VERSION_MIN(major, minor, patch) \ #define CUDA_VERSION_MIN(major, minor, patch) \
(CUDA_VERSION >= (major * 1000 + minor * 100 + patch)) (CUDA_VERSION >= (major * 1000 + minor * 100 + patch))
...@@ -42,7 +42,8 @@ static const int CUDA_NUM_THREADS = 1024; ...@@ -42,7 +42,8 @@ static const int CUDA_NUM_THREADS = 1024;
#define CUDA_CHECK(condition) \ #define CUDA_CHECK(condition) \
do { \ do { \
cudaError_t error = condition; \ cudaError_t error = condition; \
CHECK_EQ(error, cudaSuccess) << "\n" << cudaGetErrorString(error); \ CHECK_EQ(error, cudaSuccess) \
<< "\n" << cudaGetErrorString(error); \
} while (0) } while (0)
#define CUBLAS_CHECK(condition) \ #define CUBLAS_CHECK(condition) \
...@@ -61,7 +62,8 @@ static const int CUDA_NUM_THREADS = 1024; ...@@ -61,7 +62,8 @@ static const int CUDA_NUM_THREADS = 1024;
#define NCCL_CHECK(condition) \ #define NCCL_CHECK(condition) \
do { \ do { \
ncclResult_t status = condition; \ ncclResult_t status = condition; \
CHECK_EQ(status, ncclSuccess) << "\n" << ncclGetErrorString(status); \ CHECK_EQ(status, ncclSuccess) \
<< "\n" << ncclGetErrorString(status); \
} while (0) } while (0)
#endif // WITH_MPI_NCCL #endif // WITH_MPI_NCCL
...@@ -69,12 +71,10 @@ static const int CUDA_NUM_THREADS = 1024; ...@@ -69,12 +71,10 @@ static const int CUDA_NUM_THREADS = 1024;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < n; i += blockDim.x * gridDim.x) i < n; i += blockDim.x * gridDim.x)
inline int GET_BLOCKS(const int N) { inline int CUDA_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; return (N + CUDA_THREADS - 1) / CUDA_THREADS;
} }
#define CUDA_POST_KERNEL_CHECK CUDA_CHECK(cudaPeekAtLastError())
#if CUDA_VERSION_MAX(9, 0, 0) #if CUDA_VERSION_MAX(9, 0, 0)
#define __hdiv hdiv #define __hdiv hdiv
#endif #endif
...@@ -83,18 +83,19 @@ inline int CUDA_NUM_DEVICES() { ...@@ -83,18 +83,19 @@ inline int CUDA_NUM_DEVICES() {
static int count = -1; static int count = -1;
if (count < 0) { if (count < 0) {
auto err = cudaGetDeviceCount(&count); auto err = cudaGetDeviceCount(&count);
if (err == cudaErrorNoDevice || err == cudaErrorInsufficientDriver) count = 0; if (err == cudaErrorNoDevice ||
err == cudaErrorInsufficientDriver) count = 0;
} }
return count; return count;
} }
inline int CUDA_CURRENT_DEVICE() { inline int CUDA_DEVICE() {
int gpu_id; int gpu_id;
cudaGetDevice(&gpu_id); cudaGetDevice(&gpu_id);
return gpu_id; return gpu_id;
} }
inline int CUDA_POINTER_DEVICE(const void* ptr) { inline int CUDA_DEVICE(const void* ptr) {
cudaPointerAttributes attr; cudaPointerAttributes attr;
CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr)); CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
return attr.device; return attr.device;
...@@ -108,16 +109,18 @@ struct CUDADeviceProps { ...@@ -108,16 +109,18 @@ struct CUDADeviceProps {
vector<cudaDeviceProp> props; vector<cudaDeviceProp> props;
}; };
inline const cudaDeviceProp& GetDeviceProperty(const int device_id) { inline const cudaDeviceProp& GetDeviceProperty(
const int device_id) {
static CUDADeviceProps props; static CUDADeviceProps props;
CHECK_LT(device_id, (int)props.props.size()) CHECK_LT(device_id, (int)props.props.size())
<< "Invalid device id: " << device_id << "Invalid device id: " << device_id
<< "\nDetected " << props.props.size() << " eligible cuda devices."; << "\nDetected " << props.props.size()
<< " eligible cuda devices.";
return props.props[device_id]; return props.props[device_id];
} }
inline bool CUDA_TRUE_FP16_AVAILABLE() { inline bool CUDA_TRUE_FP16_AVAILABLE() {
int device = CUDA_CURRENT_DEVICE(); int device = CUDA_DEVICE();
auto& prop = GetDeviceProperty(device); auto& prop = GetDeviceProperty(device);
return prop.major >= 6; return prop.major >= 6;
} }
...@@ -126,7 +129,7 @@ inline bool TENSOR_CORE_AVAILABLE() { ...@@ -126,7 +129,7 @@ inline bool TENSOR_CORE_AVAILABLE() {
#if CUDA_VERSION < 9000 #if CUDA_VERSION < 9000
return false; return false;
#else #else
int device = CUDA_CURRENT_DEVICE(); int device = CUDA_DEVICE();
auto& prop = GetDeviceProperty(device); auto& prop = GetDeviceProperty(device);
return prop.major >= 7; return prop.major >= 7;
#endif #endif
...@@ -134,11 +137,15 @@ inline bool TENSOR_CORE_AVAILABLE() { ...@@ -134,11 +137,15 @@ inline bool TENSOR_CORE_AVAILABLE() {
class DeviceGuard { class DeviceGuard {
public: public:
DeviceGuard(int newDevice) : previous_(CUDA_CURRENT_DEVICE()) { DeviceGuard(int newDevice)
: previous_(CUDA_DEVICE()) {
if (previous_ != newDevice) if (previous_ != newDevice)
CUDA_CHECK(cudaSetDevice(newDevice)); CUDA_CHECK(cudaSetDevice(newDevice));
} }
~DeviceGuard() { CUDA_CHECK(cudaSetDevice(previous_)); }
~DeviceGuard() {
CUDA_CHECK(cudaSetDevice(previous_));
}
private: private:
int previous_; int previous_;
......
...@@ -21,8 +21,9 @@ namespace dragon { ...@@ -21,8 +21,9 @@ namespace dragon {
template <typename T, class Context> template <typename T, class Context>
class Filler { class Filler {
public: public:
Filler(const TensorFiller& filler): filler_(filler) {} Filler(const TensorFiller& filler) : filler_(filler) {}
virtual void Fill(Tensor* tensor) = 0;
virtual void Fill(Tensor* tensor, Context* ctx) = 0;
inline TensorFiller& filler() { return filler_; } inline TensorFiller& filler() { return filler_; }
...@@ -30,101 +31,125 @@ class Filler { ...@@ -30,101 +31,125 @@ class Filler {
TensorFiller filler_; TensorFiller filler_;
}; };
template <typename T, class Context> template <typename T, class Context>
class ConstantFiller final : public Filler<T, Context> { class ConstantFiller final : public Filler<T, Context> {
public: public:
ConstantFiller(const TensorFiller& filler): Filler<T, Context>(filler) {} ConstantFiller(const TensorFiller& filler)
void Fill(Tensor* tensor) override { : Filler<T, Context>(filler) {}
void Fill(Tensor* tensor, Context* ctx) override {
math::Set<T, Context>(tensor->count(), math::Set<T, Context>(tensor->count(),
dragon_cast<T, float>(this->filler().value()), dragon_cast<T, float>(filler().value()),
tensor->mutable_data<T, Context>()); tensor->mutable_data<T, Context>());
} }
protected:
using Filler<T, Context>::filler;
}; };
template <typename T, class Context> template <typename T, class Context>
class NormalFiller final : public Filler<T, Context> { class NormalFiller final : public Filler<T, Context> {
public: public:
NormalFiller(const TensorFiller& filler): Filler<T, Context>(filler) {} NormalFiller(const TensorFiller& filler)
void Fill(Tensor* tensor) override { : Filler<T, Context>(filler) {}
void Fill(Tensor* tensor, Context* ctx) override {
math::RandomNormal<T, Context>(tensor->count(), math::RandomNormal<T, Context>(tensor->count(),
this->filler().mean(), filler().mean(), filler().std(),
this->filler().std(), tensor->mutable_data<T, Context>(), ctx);
tensor->mutable_data<T, Context>());
} }
protected:
using Filler<T, Context>::filler;
}; };
template <typename T, class Context> template <typename T, class Context>
class TruncatedNormalFiller final : public Filler < T, Context > { class TruncatedNormalFiller final : public Filler<T, Context> {
public: public:
TruncatedNormalFiller(const TensorFiller& filler): Filler<T, Context>(filler) {} TruncatedNormalFiller(const TensorFiller& filler)
void Fill(Tensor* tensor) override { : Filler<T, Context>(filler) {}
void Fill(Tensor* tensor, Context* ctx) override {
// implement it on gpu is difficult // implement it on gpu is difficult
static CPUContext cpu_ctx;
math::RandomTruncatedNormal<T, CPUContext>(tensor->count(), math::RandomTruncatedNormal<T, CPUContext>(tensor->count(),
this->filler().mean(), filler().mean(), filler().std(),
this->filler().std(), filler().low(), filler().high(),
this->filler().low(), tensor->mutable_data<T, CPUContext>(), &cpu_ctx);
this->filler().high(),
tensor->mutable_data<T, CPUContext>());
} }
protected:
using Filler<T, Context>::filler;
}; };
template <typename T, class Context> template <typename T, class Context>
class UniformFiller final : public Filler<T, Context> { class UniformFiller final : public Filler<T, Context> {
public: public:
UniformFiller(const TensorFiller& filler) : Filler<T, Context>(filler) {} UniformFiller(const TensorFiller& filler)
void Fill(Tensor* tensor) override { : Filler<T, Context>(filler) {}
void Fill(Tensor* tensor, Context* ctx) override {
math::RandomUniform<T, Context>(tensor->count(), math::RandomUniform<T, Context>(tensor->count(),
this->filler().low(), filler().low(), filler().high(),
this->filler().high(), tensor->mutable_data<T, Context>(), ctx);
tensor->mutable_data<T, Context>());
} }
protected:
using Filler<T, Context>::filler;
}; };
template <typename T, class Context> template <typename T, class Context>
class XavierFiller final : public Filler<T, Context> { class XavierFiller final : public Filler<T, Context> {
public: public:
XavierFiller(const TensorFiller& filler) : Filler<T, Context>(filler) {} XavierFiller(const TensorFiller& filler)
using Filler<T, Context>::filler; : Filler<T, Context>(filler) {}
void Fill(Tensor* tensor) override {
void Fill(Tensor* tensor, Context* ctx) override {
int fan_in = tensor->count() / tensor->dim(0); int fan_in = tensor->count() / tensor->dim(0);
int fan_out = tensor->count() / tensor->dim(1); int fan_out = tensor->count() / tensor->dim(1);
float n = fan_in, scale = 3.0; float n = fan_in, scale = 3.0;
if (filler().has_scale()) scale = filler().scale(); if (filler().has_scale()) scale = filler().scale();
if (filler().variance_norm() == TensorFiller_VarianceNorm_FAN_AVG) { if (filler().variance_norm() ==
TensorFiller_VarianceNorm_FAN_AVG) {
n = (fan_in + fan_out) / float(2); n = (fan_in + fan_out) / float(2);
} else if (filler().variance_norm() == TensorFiller_VarianceNorm_FAN_OUT) { } else if (filler().variance_norm() ==
TensorFiller_VarianceNorm_FAN_OUT) {
n = fan_out; n = fan_out;
} }
float limit = std::sqrt(scale / n); float limit = std::sqrt(scale / n);
math::RandomUniform<T, Context>(tensor->count(), math::RandomUniform<T, Context>(tensor->count(),
-limit, -limit, limit, tensor->mutable_data<T, Context>(), ctx);
limit,
tensor->mutable_data<T, Context>());
} }
protected:
using Filler<T, Context>::filler;
}; };
template <typename T, class Context> template <typename T, class Context>
class MSRAFiller final : public Filler <T, Context> { class MSRAFiller final : public Filler <T, Context> {
public: public:
MSRAFiller(const TensorFiller& filler) : Filler<T, Context>(filler) {} MSRAFiller(const TensorFiller& filler)
using Filler<T, Context>::filler; : Filler<T, Context>(filler) {}
void Fill(Tensor* tensor) override {
void Fill(Tensor* tensor, Context* ctx) override {
int fan_in = tensor->count() / tensor->dim(0); int fan_in = tensor->count() / tensor->dim(0);
int fan_out = tensor->count() / tensor->dim(1); int fan_out = tensor->count() / tensor->dim(1);
float n = fan_in, scale = 2.0; float n = fan_in, scale = 2.0;
if (filler().has_scale()) scale = filler().scale(); if (filler().has_scale()) scale = filler().scale();
if (filler().variance_norm() == TensorFiller_VarianceNorm_FAN_AVG) { if (filler().variance_norm() ==
TensorFiller_VarianceNorm_FAN_AVG) {
n = (fan_in + fan_out) / float(2); n = (fan_in + fan_out) / float(2);
} else if (filler().variance_norm() == TensorFiller_VarianceNorm_FAN_OUT) { } else if (filler().variance_norm() ==
TensorFiller_VarianceNorm_FAN_OUT) {
n = fan_out; n = fan_out;
} }
float std = std::sqrt(scale / n); float std = std::sqrt(scale / n);
math::RandomNormal<T, Context>(tensor->count(), math::RandomNormal<T, Context>(tensor->count(),
float(0), 0.f, std, tensor->mutable_data<T, Context>(), ctx);
std,
tensor->mutable_data<T, Context>());
} }
protected:
using Filler<T, Context>::filler;
}; };
......
...@@ -43,14 +43,16 @@ void RandomUniform( ...@@ -43,14 +43,16 @@ void RandomUniform(
const int n, const int n,
const float low, const float low,
const float high, const float high,
T* x); T* x,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void RandomNormal( void RandomNormal(
const int n, const int n,
const float mu, const float mu,
const float sigma, const float sigma,
T* x); T* x,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void RandomTruncatedNormal( void RandomTruncatedNormal(
...@@ -59,13 +61,15 @@ void RandomTruncatedNormal( ...@@ -59,13 +61,15 @@ void RandomTruncatedNormal(
const float sigma, const float sigma,
const float low, const float low,
const float high, const float high,
T* x); T* x,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void RandomBernoulli( void RandomBernoulli(
const int n, const int n,
const float p, const float p,
uint32_t* x); uint32_t* x,
Context* ctx);
/******************** Level-1 ********************/ /******************** Level-1 ********************/
...@@ -148,14 +152,16 @@ template <typename T, class Context> ...@@ -148,14 +152,16 @@ template <typename T, class Context>
void Scal( void Scal(
const int n, const int n,
const float alpha, const float alpha,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void Scale( void Scale(
const int n, const int n,
const float alpha, const float alpha,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
T StridedDot( T StridedDot(
...@@ -163,13 +169,15 @@ T StridedDot( ...@@ -163,13 +169,15 @@ T StridedDot(
const T* a, const T* a,
const int incx, const int incx,
const T* b, const T* b,
const int incy); const int incy,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
float Dot( float Dot(
const int n, const int n,
const T* a, const T* a,
const T* b); const T* b,
Context* ctx);
template<typename T, class Context> template<typename T, class Context>
float ASum( float ASum(
...@@ -193,7 +201,8 @@ void Axpy( ...@@ -193,7 +201,8 @@ void Axpy(
const int n, const int n,
float alpha, float alpha,
const T* x, const T* x,
T* y); T* y,
Context* ctx);
template<typename T, class Context> template<typename T, class Context>
void Axpby( void Axpby(
...@@ -201,7 +210,8 @@ void Axpby( ...@@ -201,7 +210,8 @@ void Axpby(
float alpha, float alpha,
const T* x, const T* x,
float beta, float beta,
T* y); T* y,
Context* ctx);
/******************** Level-3 ********************/ /******************** Level-3 ********************/
...@@ -217,6 +227,7 @@ void Gemm( ...@@ -217,6 +227,7 @@ void Gemm(
const T* B, const T* B,
const float beta, const float beta,
T* C, T* C,
Context* ctx,
TensorProto_DataType math_type = TensorProto_DataType_FLOAT); TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
template<typename T, class Context> template<typename T, class Context>
...@@ -229,6 +240,7 @@ void Gemv( ...@@ -229,6 +240,7 @@ void Gemv(
const T* x, const T* x,
const float beta, const float beta,
T* y, T* y,
Context* ctx,
TensorProto_DataType math_type = TensorProto_DataType_FLOAT); TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
} // namespace math } // namespace math
......
...@@ -20,9 +20,6 @@ namespace kernel { ...@@ -20,9 +20,6 @@ namespace kernel {
typedef int64_t TIndex; typedef int64_t TIndex;
template <typename T, class Context>
void Empty();
/******************** activation.dropout ********************/ /******************** activation.dropout ********************/
template <typename T, class Context> template <typename T, class Context>
...@@ -32,7 +29,8 @@ void Dropout( ...@@ -32,7 +29,8 @@ void Dropout(
T scale, T scale,
const T* x, const T* x,
uint32_t* mask, uint32_t* mask,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void DropoutGrad( void DropoutGrad(
...@@ -41,7 +39,8 @@ void DropoutGrad( ...@@ -41,7 +39,8 @@ void DropoutGrad(
T scale, T scale,
const T* dy, const T* dy,
const uint32_t* mask, const uint32_t* mask,
T* dx); T* dx,
Context* ctx);
/******************** activation.elu ********************/ /******************** activation.elu ********************/
...@@ -97,7 +96,8 @@ void PReluWGrad( ...@@ -97,7 +96,8 @@ void PReluWGrad(
const T* x, const T* x,
const T* multiplier, const T* multiplier,
T* bcast_dw, T* bcast_dw,
T* dw); T* dw,
Context* ctx);
/******************** activation.relu ********************/ /******************** activation.relu ********************/
...@@ -157,7 +157,8 @@ void Softmax( ...@@ -157,7 +157,8 @@ void Softmax(
const T* sum_multiplier, const T* sum_multiplier,
const T* x, const T* x,
T* scale, T* scale,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void SoftmaxGrad( void SoftmaxGrad(
...@@ -169,7 +170,8 @@ void SoftmaxGrad( ...@@ -169,7 +170,8 @@ void SoftmaxGrad(
const T* dy, const T* dy,
const T* y, const T* y,
T* scale, T* scale,
T* dx); T* dx,
Context* ctx);
/******************** activation.tanh ********************/ /******************** activation.tanh ********************/
...@@ -198,7 +200,8 @@ void Affine( ...@@ -198,7 +200,8 @@ void Affine(
const T* alpha, const T* alpha,
const T* beta, const T* beta,
const T* beta_multiplier, const T* beta_multiplier,
T* y); T* y,
Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void AffineGrad( void AffineGrad(
...@@ -208,7 +211,8 @@ void AffineGrad( ...@@ -208,7 +211,8 @@ void AffineGrad(
const int inner_dim, const int inner_dim,
const T* dy, const T* dy,
const T* alpha, const T* alpha,
T* dx); T* dx,
Context* ctx);
/******************** arithmetic.clip ********************/ /******************** arithmetic.clip ********************/
...@@ -293,7 +297,8 @@ void SparseSoftmaxCrossEntropy( ...@@ -293,7 +297,8 @@ void SparseSoftmaxCrossEntropy(
const Ty* labels, const Ty* labels,
Tx* loss, Tx* loss,
Tx* valid, Tx* valid,
Tensor* ignore); Tensor* ignore,
Context* ctx);
template <typename Tx, typename Ty, class Context> template <typename Tx, typename Ty, class Context>
void SparseSoftmaxCrossEntropyGrad( void SparseSoftmaxCrossEntropyGrad(
...@@ -305,7 +310,8 @@ void SparseSoftmaxCrossEntropyGrad( ...@@ -305,7 +310,8 @@ void SparseSoftmaxCrossEntropyGrad(
const Ty* labels, const Ty* labels,
Tx* valid, Tx* valid,
Tensor* ignore, Tensor* ignore,
Tx* dx); Tx* dx,
Context* ctx);
/******************** loss.sparse_softmax_focal_loss ********************/ /******************** loss.sparse_softmax_focal_loss ********************/
...@@ -585,7 +591,8 @@ void RepeatGrad( ...@@ -585,7 +591,8 @@ void RepeatGrad(
const int inner_dim, const int inner_dim,
const int repeats, const int repeats,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** ndarray.slice ********************/ /******************** ndarray.slice ********************/
...@@ -629,7 +636,8 @@ void TileGrad( ...@@ -629,7 +636,8 @@ void TileGrad(
const int ex_inner_dim, const int ex_inner_dim,
const int multiple, const int multiple,
const T* dy, const T* dy,
T* dx); T* dx,
Context* ctx);
/******************** ndarray.transpose ********************/ /******************** ndarray.transpose ********************/
...@@ -733,7 +741,8 @@ void BiasAdd( ...@@ -733,7 +741,8 @@ void BiasAdd(
const string& data_format, const string& data_format,
const T* bias, const T* bias,
const T* bias_multiplier, const T* bias_multiplier,
T* y); T* y,
Context* ctx);
/******************** vision.bilinear_resize ********************/ /******************** vision.bilinear_resize ********************/
......
...@@ -22,8 +22,11 @@ namespace dragon { ...@@ -22,8 +22,11 @@ namespace dragon {
using google::protobuf::Message; using google::protobuf::Message;
template <class IterableInputs,class IterableOutputs,class IterableArgs> template <class IterableInputs,
inline OperatorDef MakeOperatorDef(const string& type, class IterableOutputs,
class IterableArgs>
inline OperatorDef MakeOperatorDef(
const string& type,
const string& name, const string& name,
const IterableInputs& inputs, const IterableInputs& inputs,
const IterableOutputs& outputs, const IterableOutputs& outputs,
...@@ -36,29 +39,42 @@ inline OperatorDef MakeOperatorDef(const string& type, ...@@ -36,29 +39,42 @@ inline OperatorDef MakeOperatorDef(const string& type,
for (const string& in : inputs) def.add_input(in); for (const string& in : inputs) def.add_input(in);
for (const string& out : outputs) def.add_output(out); for (const string& out : outputs) def.add_output(out);
for (const Argument& arg : args) def.add_arg()->CopyFrom(arg); for (const Argument& arg : args) def.add_arg()->CopyFrom(arg);
if (device_option.has_device_type()) def.mutable_device_option()->CopyFrom(device_option); if (device_option.has_device_type())
def.mutable_device_option()->CopyFrom(device_option);
return def; return def;
} }
template <class IterableInputs, class IterableOutputs, class IterableArgs> template <class IterableInputs,
inline OperatorDef MakeOperatorDef(const string& type, class IterableOutputs,
class IterableArgs>
inline OperatorDef MakeOperatorDef(
const string& type,
const string& name, const string& name,
const IterableInputs& inputs, const IterableInputs& inputs,
const IterableOutputs& outputs, const IterableOutputs& outputs,
const IterableArgs& args) { const IterableArgs& args) {
return MakeOperatorDef(type, name, inputs, outputs, args, DeviceOption(), ""); return MakeOperatorDef(
type, name, inputs, outputs, args,
DeviceOption(), "");
} }
template <class IterableInputs, class IterableOutputs> template <class IterableInputs,
inline OperatorDef MakeOperatorDef(const string& type, class IterableOutputs>
inline OperatorDef MakeOperatorDef(
const string& type,
const string& name, const string& name,
const IterableInputs& inputs, const IterableInputs& inputs,
const IterableOutputs& outputs) { const IterableOutputs& outputs) {
return MakeOperatorDef(type, name, inputs, outputs, vector<Argument>(), DeviceOption(), ""); return MakeOperatorDef(
type, name, inputs, outputs,
vector<Argument>(), DeviceOption(), "");
} }
inline void ParseProtoFromText(string text, Message* proto) { inline void ParseProtoFromText(
google::protobuf::TextFormat::ParseFromString(text, proto); string text,
Message* proto) {
google::protobuf::TextFormat
::ParseFromString(text, proto);
} }
} // namespace dragon } // namespace dragon
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!