Commit 5dea1524 by Ting PAN

Refactor Context/Stream & Add CTC Loss

1 parent 081258b9
Showing with 1399 additions and 1045 deletions
......@@ -12,6 +12,7 @@
#ifndef DRAGON_CORE_COMMON_H_
#define DRAGON_CORE_COMMON_H_
#include <ctime>
#include <climits>
#include <memory>
#include <string>
......@@ -48,7 +49,26 @@ using Map = std::unordered_map<Key, Value>;
template <typename Value>
using Set = std::unordered_set<Value>;
#define DRAGON_VERSION 2205
/*
* Define the Kernel version.
*
* | Major(2) | Minor(2) | Patch(06) |
*/
#define DRAGON_VERSION 2206
/*
* Define the default random seed.
*/
#define DEFAULT_RNG_SEED 3
/*
* Define the common macros.
*/
#ifdef _MSC_VER
#if _MSC_VER < 1900
#define thread_local __declspec(thread)
#endif
#endif
#define CONCATENATE_IMPL(s1, s2) s1##s2
#define CONCATENATE(s1, s2) CONCATENATE_IMPL(s1,s2)
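These are standard two-level pasting macros; a minimal standalone sketch (independent of Dragon) showing why the CONCATENATE_IMPL indirection is needed when pasting with macros such as __LINE__:

#include <iostream>

#define CONCATENATE_IMPL(s1, s2) s1##s2
#define CONCATENATE(s1, s2) CONCATENATE_IMPL(s1, s2)

// The indirection lets arguments expand before pasting:
// CONCATENATE(reg_, __LINE__) yields e.g. reg_13, whereas calling
// CONCATENATE_IMPL directly would paste the literal token reg___LINE__.
#define UNIQUE_NAME(prefix) CONCATENATE(prefix, __LINE__)

int main() {
    int CONCATENATE(tm, p) = 3;     // declares a variable named "tmp"
    int UNIQUE_NAME(reg_) = 0;      // declares reg_<current line number>
    std::cout << tmp << std::endl;  // prints 3
    return 0;
}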
......
......@@ -23,21 +23,18 @@
namespace dragon {
class CPUObject {
public:
unique_ptr<std::mt19937> rand_generator;
};
class CPUContext {
public:
CPUContext(): random_seed_(3) { generator(); }
CPUContext(unsigned int random_seed): random_seed_(random_seed) { generator(); }
CPUContext(const DeviceOption& option): random_seed_(option.has_random_seed() ?
option.random_seed() : 3) { generator(); }
CPUContext(): random_seed_(DEFAULT_RNG_SEED) {}
CPUContext(unsigned int random_seed)
: random_seed_(random_seed) {}
CPUContext(const DeviceOption& option)
: random_seed_(option.has_random_seed() ?
option.random_seed() : DEFAULT_RNG_SEED) {}
virtual ~CPUContext() {}
inline void SwitchToDevice() {}
inline void static FinishDeviceCompution() { return; }
inline void FinishDeviceCompution() {}
inline static void* New(size_t nbytes) {
void* data;
......@@ -50,40 +47,54 @@ class CPUContext {
return data;
}
inline static void Memset(size_t nbytes, void* ptr) { memset(ptr, 0, nbytes); }
inline static void Memset(size_t nbytes, void* ptr) {
memset(ptr, 0, nbytes);
}
template<class DstContext, class SrcContext>
inline static void Memcpy(size_t nbytes, void* dst, const void* src) { memcpy(dst, src, nbytes); }
inline static void Memcpy(
size_t nbytes,
void* dst,
const void* src) {
memcpy(dst, src, nbytes);
}
inline static void Delete(void* data) { free(data); }
template<class DstContext, class SrcContext>
inline static void MemcpyAsync(size_t nbytes, void* dst, const void* src) { NOT_IMPLEMENTED; }
inline static void MemcpyAsync(
size_t nbytes,
void* dst,
const void* src) {
NOT_IMPLEMENTED;
}
template<typename T, class DstContext, class SrcContext>
inline static void Copy(int n, T* dst, const T* src) {
inline static void Copy(
int n,
T* dst,
const T* src) {
if (dst == src) return;
// only fundamental types (e.g. int/float) can be copied with memcpy
if (std::is_fundamental<T>::value)
Memcpy<DstContext, SrcContext>(n * sizeof(T), (void*)dst, (const void*)src);
Memcpy<DstContext, SrcContext>(
n * sizeof(T), (void*)dst, (const void*)src);
else for (int i = 0; i < n; i++) dst[i] = src[i];
}
inline std::mt19937* generator() {
auto& generator = cpu_object_.rand_generator;
if (!generator.get())
generator.reset(new std::mt19937(random_seed_));
return generator.get();
}
inline int device_id() const { return 0; }
static CPUObject cpu_object_;
inline std::mt19937* rand_generator() {
if (!rand_generator_.get())
rand_generator_.reset(new std::mt19937(random_seed_));
return rand_generator_.get();
}
private:
unsigned int random_seed_;
unique_ptr<std::mt19937> rand_generator_;
};
static inline std::mt19937* rand_generator() {
return CPUContext::cpu_object_.rand_generator.get();
}
#define CPU_FP16_NOT_SUPPORTED \
LOG(FATAL) << "FP16 is unsupported for CPUContext.";
......
......@@ -26,11 +26,18 @@ class GraphBase {
string op_type;
};
GraphBase(const GraphDef& meta_graph, Workspace* ws);
GraphBase(
const GraphDef& meta_graph,
Workspace* ws);
virtual ~GraphBase() {}
virtual bool Create(const GraphDef& optimized_graph, Workspace* ws) = 0;
virtual bool Run(const string& include, const string& exclude) = 0;
virtual bool Create(
const GraphDef& optimized_graph,
Workspace* ws) = 0;
virtual bool Run(
const string& include,
const string& exclude) = 0;
inline string name() const { return name_; }
......@@ -45,21 +52,31 @@ class Graph final : public GraphBase {
Graph(const GraphDef& meta_graph, Workspace* ws);
~Graph() { for (auto* op : ops_) delete op; }
bool Create(const GraphDef& optimized_graph, Workspace* ws) override;
bool Run(const string& include, const string& exclude) override;
bool Create(
const GraphDef& optimized_graph,
Workspace* ws) override;
bool Run(
const string& include,
const string& exclude) override;
GraphDef Prune(const GraphDef& meta_graph);
GraphDef MakeUpdate(const GraphDef& meta_graph);
GraphDef Share(const GraphDef& optimized_graph);
void ShareGrads(GraphDef& optimized_graph);
void RecomputingAware(const GraphDef& optimized_graph, Workspace* ws);
void RecomputingAware(
const GraphDef& optimized_graph,
Workspace* ws);
inline Workspace* ws() const { return ws_; }
private:
void ForwardShareDyeing(string u, string ancestor);
void ForwardPruneDyeing(string u, string leaf, vector<string> path);
void ForwardPruneDyeing(
string u,
string leaf,
vector<string> path);
void BackwardPruneDyeing(string v);
vector<OperatorBase*> ops_;
......@@ -69,8 +86,15 @@ class Graph final : public GraphBase {
Set<string> targets_;
};
GraphBase* NewGraph(const GraphDef& meta_graph, Workspace* ws);
DECLARE_REGISTRY(GraphRegistry, GraphBase, const GraphDef&, Workspace*);
GraphBase* NewGraph(
const GraphDef& meta_graph,
Workspace* ws);
DECLARE_REGISTRY(
GraphRegistry,
GraphBase,
const GraphDef&,
Workspace*);
} // namespace dragon
......
......@@ -18,24 +18,32 @@ namespace dragon {
class GraphGradientMaker {
public:
GraphGradientMaker() : cur_op_idx_(0) {}
GraphGradientMaker(): cur_op_idx_(0) {}
void Make(const GraphDef& forward_def,
void Make(
const GraphDef& forward_def,
const vector<string>& targets,
GraphDef& new_def);
void Share(const string& grads_prefix, GraphDef& graph);
inline void SetTerms(const Map<string, string>& terms) { terms_ = terms; }
inline void SetOperatorPrefix(const string& prefix) { op_prefix_ = prefix; }
inline void SetOperatorSuffix(const string& suffix) { op_suffix_ = suffix; }
inline void AddExternalGrad(const string& name) { external_grads_.insert(name); }
inline void AddIgnoreGrad(const string& name) { ignore_grads_.insert(name); }
inline void SetTerms(
const Map<string, string>& terms) { terms_ = terms; }
inline void SetOperatorPrefix(
const string& prefix) { op_prefix_ = prefix; }
inline void SetOperatorSuffix(
const string& suffix) { op_suffix_ = suffix; }
inline void AddExternalGrad(
const string& name) { external_grads_.insert(name); }
inline void AddIgnoreGrad(
const string& name) { ignore_grads_.insert(name); }
private:
bool CheckGrad(const OperatorDef& forward_op,
bool CheckGrad(
const OperatorDef& forward_op,
const Set<string>& targets,
vector< pair<string, int> >& gen_grads);
string GetOperatorName();
Map<string, string> terms_, inputs_to_grads_;
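A hypothetical wiring of the maker using the setters above; the graph contents and the "loss"/"label" names are placeholders:

void MakeBackward(const dragon::GraphDef& forward_def) {
    dragon::GraphDef backward_def;
    dragon::GraphGradientMaker maker;
    maker.SetTerms({ { "loss", "loss_grad" } });  // objective -> external grad
    maker.SetOperatorPrefix("Grad/");
    maker.AddIgnoreGrad("label");  // no gradient flows into labels
    maker.Make(forward_def, { "loss" }, backward_def);
}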
......
......@@ -19,7 +19,13 @@ namespace dragon {
class MixedMemory {
public:
enum State { UNINITIALIZED, STATE_AT_CPU, STATE_AT_CUDA, SWITCHED, SYNCED };
enum State {
UNINITIALIZED,
STATE_AT_CPU,
STATE_AT_CUDA,
SWITCHED,
SYNCED
};
MixedMemory() : cpu_ptr_(nullptr), cuda_ptr_(nullptr) {}
MixedMemory(const TypeMeta& meta, const size_t nbytes)
: meta_(meta), nbytes_(nbytes),
......@@ -31,9 +37,6 @@ class MixedMemory {
void* mutable_cpu_data();
void* mutable_cuda_data();
void set_cpu_data(void* cpu_ptr, size_t nbytes);
#ifdef WITH_CUDA
void async_cuda_data(const cudaStream_t& stream);
#endif
void SwitchToDevice();
void SwitchToCUDADevice(int device_id);
......
......@@ -29,7 +29,7 @@ class Workspace;
class OperatorBase {
public:
OperatorBase(const OperatorDef& op_def, Workspace* ws);
OperatorBase(const OperatorDef& def, Workspace* ws);
virtual ~OperatorBase() {}
Tensor& Input(int idx);
......@@ -38,66 +38,73 @@ class OperatorBase {
inline size_t InputSize() { return inputs_.size(); }
inline size_t OutputSize() { return outputs_.size(); }
void MutableOp(const OperatorDef& op_def);
void MutableOp(const OperatorDef& def);
void MutableOp(const vector<string>& inputs,
const vector<string>& outputs,
const string& anchor);
inline void SwitchToPhase(const string& phase) { this->phase_ = phase; }
inline void SwitchToPhase(const string& phase) { phase_ = phase; }
virtual void Run() { NOT_IMPLEMENTED; }
inline const string& name() const { return op_def_.name(); }
inline const string& type() const { return op_def_.type(); }
inline const string& name() const { return def_.name(); }
inline const string& type() const { return def_.type(); }
inline const string& phase() const { return phase_; }
inline const string& anchor() { return anchor_; }
inline Workspace* ws() const { return ws_; }
template <typename T>
T GetSingleArg(const string& name, const T& default_value);
T Arg(const string& name, const T& default_value);
template <typename T>
vector<T> GetRepeatedArg(const string& name);
vector<T> Args(const string& name);
inline const Map<std::string, const Argument*>& args() { return args_; }
inline const Argument& arg(const string& name) { return *(args_[name]); }
typedef Map<string, vector<OperatorBase*> > RecomputeMap;
inline RecomputeMap& recompute_map() { return recompute_map_; }
void set_recompute_map(RecomputeMap recompute_map) { recompute_map_ = recompute_map; }
void set_recompute_map(RecomputeMap recompute_map) {
recompute_map_ = recompute_map;
}
inline const OperatorDef& op_def() const { return op_def_; }
inline string DebugString() const { return op_def_.DebugString(); }
string DTypeHelper(const Tensor& tensor, const Set<string>& dtypes) const;
string DTypeHelper(const string& dtype, const Set<string>& dtypes) const;
inline const OperatorDef& def() const { return def_; }
inline string DebugString() const { return def_.DebugString(); }
string DTypeHelper(
const Tensor& tensor,
const Set<string>& dtypes) const;
string DTypeHelper(
const string& dtype,
const Set<string>& dtypes) const;
protected:
string phase_, anchor_;
Map<std::string, const Argument*> args_;
Map<string, vector<OperatorBase*> > recompute_map_;
vector<Tensor*> inputs_, outputs_;
OperatorDef op_def_;
OperatorDef def_;
Workspace* ws_;
};
template <class Context>
class Operator : public OperatorBase {
public:
Operator(const OperatorDef& op_def, Workspace* ws)
: OperatorBase(op_def, ws), ctx_(op_def.device_option()),
do_synchronize_(Operator::GetSingleArg<bool>("do_synchronize", false)),
recomputing_aware_(Operator::GetSingleArg<bool>("recomputing_aware", false)) {
Operator(const OperatorDef& def, Workspace* ws)
: OperatorBase(def, ws), ctx_(def.device_option()),
recomputing_aware_(OperatorBase::Arg<bool>(
"recomputing_aware", false)) {
allow_run_ = true;
allow_run_ &= _MPICheck();
allow_run_ &= (!(OutputSize() == 1 && Output(0)->name() == "ignore"));
allow_run_ &= (!(OutputSize() == 1 &&
Output(0)->name() == "ignore"));
}
virtual void Run() final {
if (!allow_run_) return;
if (recomputing_aware_) MakeResource();
ctx_.SwitchToDevice();
ctx().SwitchToDevice();
MemorySwitch();
RunOnDevice();
if (do_synchronize_) ctx_.FinishDeviceCompution();
ctx().FinishDeviceCompution();
if (recomputing_aware_) CleanResource();
}
......@@ -106,8 +113,10 @@ class Operator : public OperatorBase {
virtual void CleanResource();
void MemorySwitch() {
for (auto* I : inputs_) if(I->name() != "ignore") I->SwitchToDevice();
for (auto* O : outputs_) if(O->name() != "ignore") O->SwitchToDevice();
for (auto* I : inputs_)
if (I->name() != "ignore") I->SwitchToDevice();
for (auto* O : outputs_)
if (O->name() != "ignore") O->SwitchToDevice();
}
virtual void RunOnDevice() = 0;
......@@ -117,14 +126,15 @@ class Operator : public OperatorBase {
protected:
Context ctx_;
bool allow_run_, recomputing_aware_, do_synchronize_;
bool allow_run_, recomputing_aware_;
private:
bool _MPICheck() {
#ifndef WITH_MPI
return true;
#else
vector<int> allow_ranks = Operator::GetRepeatedArg<int>("mpi_ranks");
vector<int> allow_ranks =
OperatorBase::Args<int>("mpi_ranks");
if (allow_ranks.empty()) return true;
int cur_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &cur_rank);
......@@ -135,11 +145,11 @@ class Operator : public OperatorBase {
}
};
OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
OperatorBase* CreateOperator(const OperatorDef& def, Workspace* ws);
#define USE_SIMPLE_CTOR_DTOR(name) \
name(const OperatorDef& op_def, Workspace* ws) \
: Operator<Context>(op_def, ws) {} \
name(const OperatorDef& def, Workspace* ws) \
: Operator<Context>(def, ws) {} \
virtual ~name() {}
#define USE_OPERATOR_BASE_FUNCTIONS \
......@@ -150,7 +160,7 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
using OperatorBase::type; \
using OperatorBase::phase; \
using OperatorBase::anchor; \
using OperatorBase::op_def; \
using OperatorBase::def; \
using OperatorBase::InputSize; \
using OperatorBase::OutputSize; \
using OperatorBase::DebugString; \
......@@ -162,9 +172,23 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
using Operator<Context>::ctx; \
using Operator<Context>::AllowRun
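Sketch of a minimal operator built on these macros. ReshapeLike and the typed data()/mutable_data() accessors are assumed from the Tensor API (not all of it appears in this diff):

template <class Context>
class IdentityOp final : public Operator<Context> {
 public:
    USE_SIMPLE_CTOR_DTOR(IdentityOp);
    USE_OPERATOR_FUNCTIONS;
    void RunOnDevice() override {
        Output(0)->ReshapeLike(Input(0));  // assumed Tensor helper
        auto* x = Input(0).template data<float, Context>();
        auto* y = Output(0)->template mutable_data<float, Context>();
        // Run() has already called ctx().SwitchToDevice() by this point.
        Context::template Copy<float, Context, Context>(
            Input(0).count(), y, x);
    }
};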
DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*);
DECLARE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
DECLARE_REGISTRY(
CPUOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
DECLARE_REGISTRY(
CUDAOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
DECLARE_REGISTRY(
CUDNNOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
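A hypothetical creation path through these registries. The CPUOperatorRegistry() accessor and REGISTER_CPU_OPERATOR are assumed to be generated by DEFINE_REGISTRY/REGISTER_* counterparts not shown in this diff; in practice the CreateOperator() declared above presumably dispatches here based on the device option:

// In some .cc file: REGISTER_CPU_OPERATOR(Relu, ReluOp<CPUContext>);
dragon::OperatorBase* CreateOnCPU(
    const dragon::OperatorDef& def, dragon::Workspace* ws) {
    return dragon::CPUOperatorRegistry()->Create(def.type(), def, ws);
}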
#define TENSOR_FILL(tensor, shape) \
if (tensor.count() == 0) { \
......@@ -174,7 +198,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
tensor.Reshape(shape); \
unique_ptr< Filler<T, Context> > filler( \
CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \
filler->Fill(&tensor); \
filler->Fill(&tensor, &ctx()); \
} else { \
TIndex count = 1; \
for(int i = 0; i < shape.size(); i++) count *= shape[i]; \
......@@ -189,7 +213,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
#define INIT_MULTIPLIER(ptr_tensor, size) { \
ptr_tensor = ws()->CreateTensor("/share/multiplier"); \
if (size > ptr_tensor->count()) { \
ptr_tensor->Reshape(vector<TIndex>(1, size)); \
ptr_tensor->Reshape({ size }); \
math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \
ptr_tensor->template mutable_data<T, Context>()); \
} \
......@@ -214,12 +238,12 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
type argument(int idx)
#define GET_ARGUMENT_WITH_DESC(type, argument, default_value) \
argument##_value = OperatorBase::GetSingleArg<type>(#argument, default_value); \
argument##_desc = OperatorBase::GetSingleArg<string>(string(#argument) + "_desc", "")
argument##_value = OperatorBase::Arg<type>(#argument, default_value); \
argument##_desc = OperatorBase::Arg<string>(string(#argument) + "_desc", "")
#define GET_ARGUMENTS_WITH_DESC(type, argument) \
argument##_value = OperatorBase::GetRepeatedArg<type>(#argument); \
argument##_desc = OperatorBase::GetRepeatedArg<string>(string(#argument) + "_desc")
argument##_value = OperatorBase::Args<type>(#argument); \
argument##_desc = OperatorBase::Args<string>(string(#argument) + "_desc")
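The GetSingleArg/GetRepeatedArg to Arg/Args rename flows through these macros. A standalone sketch of the expansion, with a stand-in Arg so it runs outside Dragon:

#include <iostream>
#include <string>
using std::string;

// Stand-in for OperatorBase::Arg<T>; always returns the default here.
template <typename T>
T Arg(const string& name, const T& default_value) { return default_value; }

#define GET_ARGUMENT_WITH_DESC(type, argument, default_value) \
    argument##_value = Arg<type>(#argument, default_value); \
    argument##_desc = Arg<string>(string(#argument) + "_desc", "")

struct FakeDropout {
    float prob_value; string prob_desc;
    // Expands to: prob_value = Arg<float>("prob", 0.5f);
    //             prob_desc  = Arg<string>("prob_desc", "");
    FakeDropout() { GET_ARGUMENT_WITH_DESC(float, prob, 0.5f); }
};

int main() {
    FakeDropout op;
    std::cout << op.prob_value << std::endl;  // prints 0.5
    return 0;
}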
#define DEFINE_ARGUMENT_WITH_DESC(type, classname, argument) \
template <class Context> \
......
......@@ -23,7 +23,8 @@ struct Gradient {
vector<OperatorDef> ops;
vector<string> g_inputs;
vector<float> defaults;
Gradient(const vector<OperatorDef>& ops,
Gradient(
const vector<OperatorDef>& ops,
const vector<string>& g_inputs,
const vector<float>& defaults)
: ops(ops), g_inputs(g_inputs), defaults(defaults) {}
......@@ -31,9 +32,11 @@ struct Gradient {
class GradientMakerBase {
public:
GradientMakerBase(const OperatorDef& def,
GradientMakerBase(
const OperatorDef& def,
const vector<string>& g_outputs)
: def(def), g_outputs_(g_outputs), g_inputs_(def.input_size()) {}
: def(def), g_outputs_(g_outputs),
g_inputs_(def.input_size()) {}
virtual ~GradientMakerBase() {}
inline virtual bool CopyDeviceOption() const { return true; }
......@@ -80,7 +83,9 @@ class GradientMakerBase {
};
// implemented in operator.cc
Gradient MakeGradientForOp(const OperatorDef& op_def, const vector<string>& g_outputs);
Gradient MakeGradientForOp(
const OperatorDef& op_def,
const vector<string>& g_outputs);
# define GRADIENT_MAKER_CTOR(name) \
name(const OperatorDef& def, const vector<string>& g_output) \
......@@ -94,12 +99,14 @@ class NoGradient : public GradientMakerBase {
}
};
DECLARE_REGISTRY(GradientRegistry,
DECLARE_REGISTRY(
GradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<string>&);
DECLARE_REGISTRY(NoGradientRegistry,
DECLARE_REGISTRY(
NoGradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<string>&);
......
......@@ -22,13 +22,24 @@ namespace dragon {
class OpSchema {
public:
OpSchema()
: op_type_("unknown"), file_("unknown"), line_(0) { Init(); }
OpSchema(const string& op_type, const string& file, const int line)
: op_type_(op_type), file_(file), line_(line) { Init(); }
: op_type_("unknown"), file_("unknown"), line_(0) {
Init();
}
OpSchema(
const string& op_type,
const string& file,
const int line)
: op_type_(op_type), file_(file), line_(line) {
Init();
}
bool Verify(const OperatorDef& def) const;
inline OpSchema& IgnoreVerify() { ignore_verify_ = true; return *this; }
inline OpSchema& IgnoreVerify() {
ignore_verify_ = true;
return *this;
}
OpSchema& Inplace(set<pair<int, int> > inplace);
std::function<bool(int, int)> CheckInplace;
......@@ -56,20 +67,26 @@ class OpSchema {
class OpSchemaRegistry {
public:
static OpSchema& NewSchema(const string& op_type, const string& file, const int line) {
static OpSchema& NewSchema(
const string& op_type,
const string& file,
const int line) {
auto& m = schema_map();
CHECK(!m.count(op_type))
<< "\nOpSchema(" << op_type << ") has registered before."
<< "\nOpSchema(" << op_type
<< ") has registered before."
<< "\nat file: " << file
<< "\n line: " << line;
m.emplace(std::make_pair(op_type, OpSchema(op_type, file, line)));
m.emplace(std::make_pair(op_type,
OpSchema(op_type, file, line)));
return m[op_type];
}
static const OpSchema* Schema(const string& op_type) {
auto& m = schema_map();
if (m.count(op_type)) return &m[op_type];
else LOG(FATAL) << "OpSchema(" << op_type << ") has not registered yet.";
else LOG(FATAL) << "OpSchema(" << op_type
<< ") has not registered yet.";
return nullptr;
}
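Sketch of driving the registry directly; real operators presumably register through an OPERATOR_SCHEMA convenience macro wrapping NewSchema, which this diff does not show:

void SchemaDemo() {
    auto& schema = dragon::OpSchemaRegistry::NewSchema(
        "MyOp", __FILE__, __LINE__);
    schema.Inplace({ { 0, 0 } });  // Output(0) may share Input(0)'s memory
    const dragon::OpSchema* s = dragon::OpSchemaRegistry::Schema("MyOp");
    // s->Verify(def) would now validate an OperatorDef of type "MyOp".
    (void)s;
}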
......
......@@ -20,35 +20,50 @@ namespace dragon {
template <class SrcType, class ObjType, class... Args>
class Registry {
public:
typedef std::function<ObjType*(Args ...)> Creator;
void Register(const SrcType& key, Creator creator) {
CHECK(!registry_.count(key)) << "\nKey(" << key << ") has already been registered.";
CHECK(!registry_.count(key))
<< "\nKey(" << key << ") has already been registered.";
registry_[key] = creator;
}
ObjType* Create(const SrcType& key, Args ... args) {
CHECK(registry_.count(key)) << "\nKey(" << key << ") has not been registered yet.";
CHECK(registry_.count(key))
<< "\nKey(" << key << ") has not been registered yet.";
return registry_[key](args...);
}
bool Has(const SrcType& key) { return (registry_.count(key)) != 0; }
bool Has(const SrcType& key) {
return (registry_.count(key)) != 0;
}
vector<SrcType> keys() {
vector<SrcType> ret;
for (const auto& it : registry_) ret.push_back(it.first);
for (const auto& it : registry_)
ret.push_back(it.first);
return ret;
}
private:
Map<SrcType, Creator> registry_;
};
template <class SrcType, class ObjType, class... Args>
class Registerer {
public:
Registerer(const SrcType& key, Registry<SrcType, ObjType, Args...>* registry,
typename Registry<SrcType, ObjType, Args...>::Creator creator, const string& help_msg = "") {
Registerer(
const SrcType& key,
Registry<SrcType, ObjType, Args...>* registry,
typename Registry<SrcType, ObjType, Args...>::Creator creator,
const string& help_msg = "") {
registry->Register(key, creator);
}
template <class DerivedType>
static ObjType* defaultCreator(Args ... args) {return new DerivedType(args...);}
static ObjType* defaultCreator(Args ... args) {
return new DerivedType(args...);
}
};
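A minimal sketch exercising the template directly, assuming Dragon's Map alias and CHECK macro are in scope; defaultCreator is the hook the REGISTER_* macros rely on:

struct Base { virtual ~Base() {} };
struct Foo final : public Base {};

void RegistryDemo() {
    dragon::Registry<std::string, Base> registry;
    registry.Register("Foo",
        dragon::Registerer<std::string, Base>::defaultCreator<Foo>);
    CHECK(registry.Has("Foo"));
    Base* obj = registry.Create("Foo");  // heap-allocates a Foo
    delete obj;
}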
// use in *.h files
......
......@@ -74,10 +74,15 @@ class Tensor {
return ret;
}
inline TIndex count() const { return size_; }
inline TIndex count(const TIndex start) const { return count(start, ndim()); }
inline TIndex count(const TIndex start) const {
return count(start, ndim());
}
inline TIndex offset(const TIndex n, const TIndex c = 0,
const TIndex h = 0, const TIndex w = 0) {
inline TIndex offset(
const TIndex n,
const TIndex c = 0,
const TIndex h = 0,
const TIndex w = 0) {
CHECK_LE(n, dim(0));
CHECK_LE(c, dim(1));
CHECK_LE(h, dim(2));
......@@ -95,7 +100,7 @@ class Tensor {
return offset;
}
inline string dim_string() const {
inline string DimString() const {
if (ndim() == 0) return "(0,)";
std::stringstream ss;
ss << "(";
......@@ -108,9 +113,18 @@ class Tensor {
inline bool is_corrupted() const { return is_corrupted_; }
inline void Corrupt() { is_corrupted_ = true; }
inline bool has_memory() const { return memory_ || ex_memory_ != nullptr; }
MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; }
void set_memory(MixedMemory* mem) { memory_.reset(mem); capacity_ = mem->nbytes(); }
inline bool has_memory() const {
return memory_ || ex_memory_ != nullptr;
}
MixedMemory* memory() const {
return own_mem_ ? memory_.get() : ex_memory_;
}
void set_memory(MixedMemory* mem) {
memory_.reset(mem); capacity_ = mem->nbytes();
}
MixedMemory::State memory_state() const {
MixedMemory* mem = memory();
CHECK(mem) << "\nMemory access before allowcating.";
......@@ -124,7 +138,8 @@ class Tensor {
const TypeMeta& meta() const { return meta_; }
void SetMeta(const TypeMeta& meta) { meta_ = meta; }
template <typename T> inline bool IsType() { return meta_.Match<T>(); }
template <typename T>
inline bool IsType() { return meta_.Match<T>(); }
template <class Context>
void mutable_data_ptr(void** data_ptr) {
......@@ -132,12 +147,15 @@ class Tensor {
if (!mem) {
*data_ptr = nullptr;
} else {
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) {
if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CPUContext>()) {
*data_ptr = mem->mutable_cpu_data();
} else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) {
} else if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CUDAContext>()) {
*data_ptr = mem->mutable_cuda_data();
} else {
LOG(FATAL) << "Unknown memory type. Only CPU or CUDA is supported.";
LOG(FATAL) << "Unknown memory type. "
<< "Only CPU or CUDA is supported.";
}
}
}
......@@ -146,12 +164,15 @@ class Tensor {
const void* const_data_ptr() const {
MixedMemory* mem = memory();
CHECK(mem) << "\nMemory access before allowcating.";
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) {
if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CPUContext>()) {
return mem->cpu_data();
} else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) {
} else if (TypeMeta::Id<Context>() ==
TypeMeta::Id<CUDAContext>()) {
return mem->cuda_data();
} else {
LOG(FATAL) << "Unknown memory type. Only CPU or CUDA are supported.";
LOG(FATAL) << "Unknown memory type. "
<< "Only CPU or CUDA are supported.";
return nullptr;
}
}
......@@ -164,10 +185,17 @@ class Tensor {
if (meta_ != meta && data_ptr && !own_mem_) delete ex_memory_;
meta_ = meta;
CHECK_GT(size_, 0);
if (own_mem_) memory_.reset(new MixedMemory(meta, size_* meta_.itemsize()));
else ex_memory_ = new MixedMemory(meta, size_* meta_.itemsize());
mutable_data_ptr<Context>(&data_ptr); // malloc memory
if (meta.ctor()) meta_.ctor()(data_ptr, size_); // call the constructor
if (own_mem_) {
memory_.reset(new MixedMemory(
meta, size_* meta_.itemsize()));
} else {
ex_memory_ = new MixedMemory(
meta, size_* meta_.itemsize());
}
// malloc memory
mutable_data_ptr<Context>(&data_ptr);
// call the constructors
if (meta.ctor()) meta_.ctor()(data_ptr, size_);
capacity_ = size_ * meta.itemsize();
return data_ptr;
}
......@@ -189,8 +217,10 @@ class Tensor {
T* mutable_data() {
void* data_ptr;
mutable_data_ptr<Context>(&data_ptr);
if (data_ptr && meta_ == TypeMeta::Make<T>()) return static_cast<T*>(data_ptr);
return static_cast<T*>(raw_mutable_data<Context>(TypeMeta::Make<T>()));
if (data_ptr && meta_ == TypeMeta::Make<T>())
return static_cast<T*>(data_ptr);
return static_cast<T*>(
raw_mutable_data<Context>(TypeMeta::Make<T>()));
}
template <typename T, class Context>
......@@ -208,9 +238,11 @@ class Tensor {
auto* src = other.template raw_data<SrcCTX>();
auto* dst = raw_mutable_data<DstCTX>(other.meta_);
if (dst == src) return;
if (TypeMeta::Id<DstCTX>() == TypeMeta::Id<CPUContext>()) {
if (TypeMeta::Id<DstCTX>() ==
TypeMeta::Id<CPUContext>()) {
CPUContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src);
} else if (TypeMeta::Id<DstCTX>() == TypeMeta::Id<CUDAContext>()) {
} else if (TypeMeta::Id<DstCTX>() ==
TypeMeta::Id<CUDAContext>()) {
CUDAContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src);
}
}
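Usage sketch reflecting the renamed pieces above: initializer-list Reshape (replacing vector<TIndex>(1, n)) and DimString (formerly dim_string). Default-constructibility of Tensor is assumed:

#include <iostream>

void TensorDemo() {
    dragon::Tensor t;
    t.Reshape({ 2, 3 });
    float* x = t.mutable_data<float, dragon::CPUContext>();
    for (int i = 0; i < t.count(); ++i) x[i] = 0.f;
    std::cout << t.DimString() << std::endl;  // prints the dims, e.g. "(2,3)"
}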
......
......@@ -82,38 +82,47 @@ class TypeMeta {
template <typename T>
static void Ctor(void* ptr, size_t n) {
T* typed_ptr = static_cast<T*>(ptr);
for (unsigned int i = 0; i < n; i++) new(typed_ptr + i) T;
for (unsigned int i = 0; i < n; i++)
new(typed_ptr + i) T;
}
template <typename T>
static void Copy(const void* src, void* dst, size_t n) {
const T* typed_src = static_cast<const T*>(src);
T* typed_dst = static_cast<T*>(dst);
for (unsigned int i = 0; i < n; i++) typed_dst[i] = typed_src[i];
for (unsigned int i = 0; i < n; i++)
typed_dst[i] = typed_src[i];
}
template <typename T>
static void Dtor(void* ptr, size_t n) {
T* typed_ptr = static_cast<T*>(ptr);
for (unsigned int i = 0; i < n; i++) typed_ptr[i].~T();
for (unsigned int i = 0; i < n; i++)
typed_ptr[i].~T();
}
#define FundMeta std::enable_if<std::is_fundamental<T>::value,TypeMeta>::type
#define StructMeta std::enable_if<!std::is_fundamental<T>::value && std::is_copy_assignable<T>::value, TypeMeta>::type
template <typename T>
static typename FundMeta Make() {
return TypeMeta(Id<T>(), Itemsize<T>(), nullptr, nullptr, nullptr);
return TypeMeta(Id<T>(), Itemsize<T>(),
nullptr, nullptr, nullptr);
}
#define StructMeta std::enable_if<!std::is_fundamental<T>::value && std::is_copy_assignable<T>::value, TypeMeta>::type
template<typename T>
static typename StructMeta Make() {
return TypeMeta(Id<T>(), Itemsize<T>(), Ctor<T>, Copy<T>, Dtor<T>);
return TypeMeta(Id<T>(), Itemsize<T>(),
Ctor<T>, Copy<T>, Dtor<T>);
}
private:
TypeMeta(TypeId id, size_t itemsize,
PlacementNew ctor, TypedCopy copy, TypedDestructor dtor)
TypeMeta(
TypeId id,
size_t itemsize,
PlacementNew ctor,
TypedCopy copy,
TypedDestructor dtor)
: id_(id), itemsize_(itemsize),
ctor_(ctor), copy_(copy), dtor_(dtor) {}
......
......@@ -40,8 +40,10 @@ typedef struct {
#endif
inline const TypeMeta& TypeStringToMeta(const std::string& str_type) {
static std::unordered_map<std::string, TypeMeta> s2m_type_map {
inline const TypeMeta& TypeStringToMeta(
const std::string& str_type) {
static std::unordered_map<std::string, TypeMeta>
s2m_type_map {
{ "float32", TypeMeta::Make<float>() },
{ "int32", TypeMeta::Make<int>() },
{ "int64", TypeMeta::Make<int64_t>() },
......@@ -50,11 +52,14 @@ inline const TypeMeta& TypeStringToMeta(const std::string& str_type) {
{ "uint8", TypeMeta::Make<uint8_t>() }
};
static TypeMeta unknown_type;
return s2m_type_map.count(str_type) ? s2m_type_map[str_type] : unknown_type;
return s2m_type_map.count(str_type) ?
s2m_type_map[str_type] : unknown_type;
}
inline const std::string TypeMetaToString(const TypeMeta& meta) {
static std::unordered_map<TypeId, std::string> m2s_type_map {
inline const std::string TypeMetaToString(
const TypeMeta& meta) {
static std::unordered_map<TypeId, std::string>
m2s_type_map {
{ TypeMeta::Id<float>(), "float32" },
{ TypeMeta::Id<int>(), "int32" },
{ TypeMeta::Id<int64_t>(), "int64" },
......@@ -62,7 +67,8 @@ inline const std::string TypeMetaToString(const TypeMeta& meta) {
{ TypeMeta::Id<float16>(), "float16" },
{ TypeMeta::Id<uint8_t>(), "uint8" }
};
return m2s_type_map.count(meta.id()) ? m2s_type_map[meta.id()] : "unknown";
return m2s_type_map.count(meta.id()) ?
m2s_type_map[meta.id()] : "unknown";
}
} // namespace dragon
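A short round-trip sketch of the two helpers above:

#include <iostream>

void TypeDemo() {
    const dragon::TypeMeta& meta = dragon::TypeStringToMeta("float32");
    std::cout << dragon::TypeMetaToString(meta) << std::endl;  // "float32"
    // Names missing from the map fall back to a default TypeMeta,
    // which TypeMetaToString reports as "unknown":
    std::cout << dragon::TypeMetaToString(
        dragon::TypeStringToMeta("bfloat16")) << std::endl;    // "unknown"
}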
......
......@@ -39,13 +39,16 @@ class Workspace {
inline void InitWorkspace() {
CreateTensor("ignore");
Tensor* head = CreateTensor("/opt/mirror_stage/head");
head->Reshape(vector<TIndex>(1, WORKSPACE_MAX_CORRUPTED_SIZE));
Tensor* recompute_flag = CreateTensor("/opt/mirror_stage/recompute_flag");
recompute_flag->Reshape(vector<TIndex>(1, 1));
Tensor* head = CreateTensor(
"/opt/mirror_stage/head");
head->Reshape({ WORKSPACE_MAX_CORRUPTED_SIZE });
Tensor* recompute_flag = CreateTensor(
"/opt/mirror_stage/recompute_flag");
recompute_flag->Reshape({ 1 });
recompute_flag->mutable_data<bool, CPUContext>()[0] = false;
for (int i = 0; i < WORKSPACE_MAX_CORRUPTED_SIZE; i++) {
string name = "/opt/mirror_stage/buffer_" + dragon_cast<string, int>(i);
string name = "/opt/mirror_stage/buffer_" +
dragon_cast<string, int>(i);
Tensor* buffer = CreateTensor(name);
head->mutable_data<string, CPUContext>()[i] = "";
}
......@@ -72,7 +75,9 @@ class Workspace {
} else { return name; }
}
inline Tensor* TryGetTensor(const string& name, bool use_remote=true) {
inline Tensor* TryGetTensor(
const string& name,
bool use_remote = true) {
string query = GetTensorName(name);
// search local workspace
if (tensor_map_.count(query) > 0)
......@@ -87,7 +92,9 @@ class Workspace {
return nullptr;
}
inline bool HasTensor(const string& name, bool use_remote=true) {
inline bool HasTensor(
const string& name,
bool use_remote = true) {
return TryGetTensor(name, use_remote) ? true : false;
}
......@@ -100,7 +107,9 @@ class Workspace {
return tensor;
}
inline Tensor* GetTensor(const string& name, bool use_remote=true) {
inline Tensor* GetTensor(
const string& name,
bool use_remote = true) {
Tensor* tensor = TryGetTensor(name, use_remote);
CHECK(tensor) << "\nTensor(" << name << ") does not exist "
<< "in current workspace or sub-workspace.";
......@@ -122,14 +131,17 @@ class Workspace {
// search remote workspace
for (auto& it : workspace_map_) {
vector<string> sub_names = it.second->GetTensors();
names.insert(names.end(), sub_names.begin(), sub_names.end());
names.insert(names.end(),
sub_names.begin(), sub_names.end());
}
return names;
}
/******************** Filler ********************/
inline bool HasFiller(const string& name, bool use_remote=true) {
inline bool HasFiller(
const string& name,
bool use_remote = true) {
// search local workspace
bool result = filler_map_.count(name) > 0;
if (!use_remote) return result;
......@@ -140,14 +152,16 @@ class Workspace {
return result;
}
inline void CreateFiller(const TensorFiller filler) {
inline void CreateFiller(
const TensorFiller filler) {
CHECK_GT(filler.tensor().size(), 0)
<< "Tensor without a valid name can not be filled.";
if (HasFiller(filler.tensor())) return;
filler_map_[filler.tensor()] = filler;
}
inline const TensorFiller* GetFiller(const string& name) {
inline const TensorFiller* GetFiller(
const string& name) {
// search local workspace
if (filler_map_.count(name) > 0)
return &filler_map_[name];
......@@ -163,11 +177,12 @@ class Workspace {
/******************** Cache ********************/
template <class Context>
inline vector<void*> caches(const vector<size_t>& segments) {
inline vector<void*> caches(
const vector<size_t>& segments) {
TIndex total_size = 0;
for (auto& segment : segments) total_size += (TIndex)segment;
Tensor* cacheT = CreateTensor("/share/cache");
cacheT->Reshape(vector<TIndex>(1, total_size));
cacheT->Reshape({ total_size });
vector<void*> caches(segments.size());
caches[0] = cacheT->template mutable_data<uint8_t, Context>();
for (int i = 1; i < segments.size(); i++)
......@@ -176,11 +191,12 @@ class Workspace {
}
template <typename T, class Context>
inline vector<T*> caches(const vector<TIndex>& segments) {
inline vector<T*> caches(
const vector<TIndex>& segments) {
TIndex total_count = 0;
for (auto& segment : segments) total_count += segment;
Tensor* cacheT = CreateTensor("/share/cache");
cacheT->Reshape(vector<TIndex>(1, total_count));
cacheT->Reshape({ total_count });
vector<T*> caches(segments.size());
caches[0] = cacheT->template mutable_data<T, Context>();
for (int i = 1; i < segments.size(); i++)
......@@ -190,12 +206,14 @@ class Workspace {
/******************** Operator ********************/
inline void CreatePersistentOp(const OperatorDef& meta_op) {
inline void CreatePersistentOp(
const OperatorDef& meta_op) {
string persistent_key;
for (auto& arg : meta_op.arg())
if (arg.name() == "persistent_key")
persistent_key = arg.s();
CHECK(persistent_key.size() > 0) << "\nGot empty persistent key.";
CHECK(persistent_key.size() > 0)
<< "\nGot empty persistent key.";
if (!op_map_.count(persistent_key)) {
for (auto& input : meta_op.input()) CreateTensor(input);
op_map_[persistent_key] = unique_ptr<OperatorBase>(
......@@ -203,7 +221,9 @@ class Workspace {
}
}
inline void RunPersistentOp(const string& key, const string& anchor,
inline void RunPersistentOp(
const string& key,
const string& anchor,
const vector<string>& inputs,
const vector<string>& outputs) {
CHECK(op_map_.count(key) > 0)
......@@ -236,11 +256,13 @@ class Workspace {
GraphBase* CreateGraph(const GraphDef& meta_graph);
void RunGraph(const string& graph_name,
void RunGraph(
const string& graph_name,
const string& include,
const string& exclude) {
if (!graph_map_.count(graph_name))
LOG(FATAL) << "Graph(" << graph_name << ") does not exist.";
LOG(FATAL) << "Graph(" << graph_name
<< ") does not exist.";
graph_map_[graph_name]->Run(include, exclude);
}
......@@ -252,7 +274,8 @@ class Workspace {
/******************** Utility ********************/
inline void CreateRename(const string& old_tensor,
inline void CreateRename(
const string& old_tensor,
const string& new_tensor) {
rename_map_[old_tensor] = new_tensor;
}
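A short usage sketch tying the pieces above together. Workspace construction details are omitted, and the rename direction is inferred from rename_map_ (lookups of old_tensor resolve to new_tensor):

void WorkspaceDemo(dragon::Workspace& ws) {
    dragon::Tensor* t = ws.CreateTensor("data");
    ws.CreateRename("data_alias", "data");
    CHECK(ws.GetTensor("data_alias") == t);  // alias resolves to "data"
    // Carve the shared "/share/cache" tensor into two scratch buffers:
    auto buffers = ws.caches<float, dragon::CPUContext>({ 1024, 256 });
    float* scratch0 = buffers[0];
    float* scratch1 = buffers[1];  // begins 1024 floats into the same store
    (void)scratch0; (void)scratch1;
}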
......
......@@ -20,11 +20,11 @@ namespace dragon {
template <class Context>
class DropoutOp final : public Operator<Context> {
public:
DropoutOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
DropoutOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
use_scale(OperatorBase::Arg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
SwitchToPhase(OperatorBase::Arg<string>("phase", ""));
}
USE_OPERATOR_FUNCTIONS;
......@@ -39,11 +39,11 @@ class DropoutOp final : public Operator<Context> {
template <class Context>
class DropoutGradientOp final : public Operator<Context> {
public:
DropoutGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
DropoutGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
use_scale(OperatorBase::Arg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
SwitchToPhase(OperatorBase::Arg<string>("phase", ""));
}
USE_OPERATOR_FUNCTIONS;
......@@ -66,12 +66,12 @@ DEFINE_ARGUMENT_WITH_DESC(float, DropoutGradientOp, prob);
template <class Context>
class CuDNNDropoutOp final : public Operator<Context> {
public:
CuDNNDropoutOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), states_initialized(false),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)),
random_seed(op_def.device_option().random_seed()) {
CuDNNDropoutOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), states_initialized(false),
use_scale(OperatorBase::Arg<bool>("scale", true)),
random_seed(DEFAULT_RNG_SEED) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
SwitchToPhase(OperatorBase::Arg<string>("phase", ""));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
}
......@@ -97,12 +97,12 @@ public:
template <class Context>
class CuDNNDropoutGradientOp final : public Operator<Context> {
public:
CuDNNDropoutGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), states_initialized(false),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)),
random_seed(op_def.device_option().random_seed()) {
CuDNNDropoutGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), states_initialized(false),
use_scale(OperatorBase::Arg<bool>("scale", true)),
random_seed(DEFAULT_RNG_SEED) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
SwitchToPhase(OperatorBase::Arg<string>("phase", ""));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
}
......
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class EluOp : public Operator<Context> {
public:
EluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {}
EluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -34,9 +34,9 @@ class EluOp : public Operator<Context> {
template <class Context>
class EluGradientOp : public Operator<Context> {
public:
EluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {}
EluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
alpha(OperatorBase::Arg<float>("alpha", 1.0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -53,8 +53,8 @@ class EluGradientOp : public Operator<Context> {
template <class Context>
class CuDNNEluOp final : public EluOp<Context> {
public:
CuDNNEluOp(const OperatorDef& op_def, Workspace* ws)
: EluOp<Context>(op_def, ws) {
CuDNNEluOp(const OperatorDef& def, Workspace* ws)
: EluOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......@@ -80,8 +80,8 @@ public:
template <class Context>
class CuDNNEluGradientOp final : public EluGradientOp<Context> {
public:
CuDNNEluGradientOp(const OperatorDef& op_def, Workspace* ws)
: EluGradientOp<Context>(op_def, ws) {
CuDNNEluGradientOp(const OperatorDef& def, Workspace* ws)
: EluGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......
......@@ -17,12 +17,12 @@
namespace dragon {
template <class Context>
class PReluOp : public Operator<Context> {
class PReluOp final : public Operator<Context> {
public:
PReluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
PReluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
channel_shared(OperatorBase::Arg<bool>("channel_shared", false)),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -34,12 +34,12 @@ class PReluOp : public Operator<Context> {
};
template <class Context>
class PReluGradientOp : public Operator<Context> {
class PReluGradientOp final : public Operator<Context> {
public:
PReluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
PReluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
channel_shared(OperatorBase::Arg<bool>("channel_shared", false)),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class ReluOp : public Operator<Context> {
public:
ReluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
ReluOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
slope(OperatorBase::Arg<float>("slope", 0.0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -34,9 +34,9 @@ class ReluOp : public Operator<Context> {
template <class Context>
class ReluGradientOp : public Operator<Context> {
public:
ReluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
ReluGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
slope(OperatorBase::Arg<float>("slope", 0.0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -51,8 +51,8 @@ class ReluGradientOp : public Operator<Context> {
template <class Context>
class CuDNNReluOp final : public ReluOp<Context> {
public:
CuDNNReluOp(const OperatorDef& op_def, Workspace* ws)
: ReluOp<Context>(op_def, ws) {
CuDNNReluOp(const OperatorDef& def, Workspace* ws)
: ReluOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......@@ -78,8 +78,8 @@ public:
template <class Context>
class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
public:
CuDNNReluGradientOp(const OperatorDef& op_def, Workspace* ws)
: ReluGradientOp<Context>(op_def, ws) {
CuDNNReluGradientOp(const OperatorDef& def, Workspace* ws)
: ReluGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......
......@@ -17,7 +17,7 @@
namespace dragon {
template <class Context>
class SEluOp : public Operator<Context> {
class SEluOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SEluOp);
USE_OPERATOR_FUNCTIONS;
......@@ -27,7 +27,7 @@ class SEluOp : public Operator<Context> {
};
template <class Context>
class SEluGradientOp : public Operator<Context> {
class SEluGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SEluGradientOp);
USE_OPERATOR_FUNCTIONS;
......
......@@ -41,8 +41,8 @@ class SigmoidGradientOp : public Operator<Context> {
template <class Context>
class CuDNNSigmoidOp final : public SigmoidOp<Context> {
public:
CuDNNSigmoidOp(const OperatorDef& op_def, Workspace* ws)
: SigmoidOp<Context>(op_def, ws) {
CuDNNSigmoidOp(const OperatorDef& def, Workspace* ws)
: SigmoidOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......@@ -68,8 +68,8 @@ public:
template <class Context>
class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
public:
CuDNNSigmoidGradientOp(const OperatorDef& op_def, Workspace* ws)
: SigmoidGradientOp<Context>(op_def, ws) {
CuDNNSigmoidGradientOp(const OperatorDef& def, Workspace* ws)
: SigmoidGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class SoftmaxOp final : public Operator<Context> {
public:
SoftmaxOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
SoftmaxOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -34,9 +34,9 @@ class SoftmaxOp final : public Operator<Context> {
template <class Context>
class SoftmaxGradientOp final : public Operator<Context> {
public:
SoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
SoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -53,9 +53,9 @@ class SoftmaxGradientOp final : public Operator<Context> {
template <class Context>
class CuDNNSoftmaxOp final : public Operator<Context> {
public:
CuDNNSoftmaxOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {
CuDNNSoftmaxOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
}
......@@ -78,9 +78,9 @@ class CuDNNSoftmaxOp final : public Operator<Context> {
template <class Context>
class CuDNNSoftmaxGradientOp final : public Operator<Context> {
public:
CuDNNSoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {
CuDNNSoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
}
......
......@@ -41,8 +41,8 @@ class TanhGradientOp : public Operator<Context> {
template <class Context>
class CuDNNTanhOp final : public TanhOp<Context> {
public:
CuDNNTanhOp(const OperatorDef& op_def, Workspace* ws)
: TanhOp<Context>(op_def, ws) {
CuDNNTanhOp(const OperatorDef& def, Workspace* ws)
: TanhOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......@@ -68,8 +68,8 @@ public:
template <class Context>
class CuDNNTanhGradientOp final : public TanhGradientOp<Context> {
public:
CuDNNTanhGradientOp(const OperatorDef& op_def, Workspace* ws)
: TanhGradientOp<Context>(op_def, ws) {
CuDNNTanhGradientOp(const OperatorDef& def, Workspace* ws)
: TanhGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
......
......@@ -17,12 +17,12 @@
namespace dragon {
template <class Context>
class AffineOp : public Operator<Context> {
class AffineOp final : public Operator<Context> {
public:
AffineOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", 1)) {}
AffineOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
num_axes(OperatorBase::Arg<int>("num_axes", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -36,10 +36,10 @@ class AffineOp : public Operator<Context> {
template <class Context>
class AffineGradientOp final : public Operator<Context> {
public:
AffineGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
AffineGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
num_axes(OperatorBase::Arg<int>("num_axes", -1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -60,10 +60,10 @@ class AffineGradientOp final : public Operator<Context> {
template <class Context>
class CuDNNAffineOpBase : public Operator<Context> {
public:
CuDNNAffineOpBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {
CuDNNAffineOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
num_axes(OperatorBase::Arg<int>("num_axes", -1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc));
CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&mul_desc));
......@@ -120,10 +120,10 @@ class CuDNNAffineOpBase : public Operator<Context> {
using CuDNNAffineOpBase<Context>::reduce_desc
template <class Context>
class CuDNNAffineOp : public CuDNNAffineOpBase<Context> {
class CuDNNAffineOp final : public CuDNNAffineOpBase<Context> {
public:
CuDNNAffineOp(const OperatorDef& op_def, Workspace* ws)
: CuDNNAffineOpBase<Context>(op_def, ws) {}
CuDNNAffineOp(const OperatorDef& def, Workspace* ws)
: CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -133,10 +133,10 @@ class CuDNNAffineOp : public CuDNNAffineOpBase<Context> {
};
template <class Context>
class CuDNNAffineGradientOp : public CuDNNAffineOpBase<Context> {
class CuDNNAffineGradientOp final : public CuDNNAffineOpBase<Context> {
public:
CuDNNAffineGradientOp(const OperatorDef& op_def, Workspace* ws)
: CuDNNAffineOpBase<Context>(op_def, ws) {}
CuDNNAffineGradientOp(const OperatorDef& def, Workspace* ws)
: CuDNNAffineOpBase<Context>(def, ws) {}
void RunOnDevice() override;
template <typename T> void ComputeScaleGradient(T* dYxX, T* dA);
......
......@@ -20,10 +20,10 @@ namespace dragon {
template <class Context>
class ClipOp final : public Operator<Context> {
public:
ClipOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
low(OperatorBase::GetSingleArg<float>("low", -FLT_MAX)),
high(OperatorBase::GetSingleArg<float>("high", FLT_MAX)) {}
ClipOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
low(OperatorBase::Arg<float>("low", -FLT_MAX)),
high(OperatorBase::Arg<float>("high", FLT_MAX)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,10 +19,10 @@ namespace dragon {
template <class Context>
class DotOp final : public Operator<Context> {
public:
DotOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
DotOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
TransA(OperatorBase::Arg<bool>("TransA", false)),
TransB(OperatorBase::Arg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -37,10 +37,10 @@ class DotOp final : public Operator<Context> {
template <class Context>
class DotGradientOp final : public Operator<Context> {
public:
DotGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
DotGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
TransA(OperatorBase::Arg<bool>("TransA", false)),
TransB(OperatorBase::Arg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,10 +19,10 @@ namespace dragon {
template <class Context>
class EltwiseOp final : public Operator<Context> {
public:
EltwiseOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "SUM")),
coeffs(OperatorBase::GetRepeatedArg<float>("coeffs")) {
EltwiseOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
operation(OperatorBase::Arg<string>("operation", "SUM")),
coeffs(OperatorBase::Args<float>("coeffs")) {
if (coeffs.size() > 0) {
CHECK_EQ(coeffs.size(), InputSize())
<< "\nOp has " << InputSize() << " inputs, "
......@@ -43,10 +43,10 @@ class EltwiseOp final : public Operator<Context> {
template <class Context>
class EltwiseGradientOp final : public Operator<Context> {
public:
EltwiseGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "SUM")),
coeffs(OperatorBase::GetRepeatedArg<float>("coeff")) {
EltwiseGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
operation(OperatorBase::Arg<string>("operation", "SUM")),
coeffs(OperatorBase::Args<float>("coeff")) {
if (coeffs.size() > 0) {
CHECK_EQ(coeffs.size(), InputSize())
<< "\nop has " << InputSize() << " inputs, "
......
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class GramMatrixOp final : public Operator<Context> {
public:
GramMatrixOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
GramMatrixOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -35,9 +35,9 @@ class GramMatrixOp final : public Operator<Context> {
template <class Context>
class GramMatrixGradientOp final : public Operator<Context> {
public:
GramMatrixGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
GramMatrixGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -17,13 +17,13 @@
namespace dragon {
template <class Context>
class InnerProductOp: public Operator<Context> {
class InnerProductOp final : public Operator<Context> {
public:
InnerProductOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
TransW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
InnerProductOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
num_output(OperatorBase::Arg<int>("num_output", 0)),
TransW(OperatorBase::Arg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice();
......@@ -37,11 +37,11 @@ class InnerProductOp: public Operator<Context> {
template <class Context>
class InnerProductGradientOp final : public Operator<Context> {
public:
InnerProductGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
TransW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
InnerProductGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
num_output(OperatorBase::Arg<int>("num_output", 0)),
TransW(OperatorBase::Arg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,10 +19,10 @@ namespace dragon {
template <class Context>
class MatmulOp final : public Operator<Context> {
public:
MatmulOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
MatmulOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
TransA(OperatorBase::Arg<bool>("TransA", false)),
TransB(OperatorBase::Arg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -36,10 +36,10 @@ class MatmulOp final : public Operator<Context> {
template <class Context>
class MatmulGradientOp final : public Operator<Context> {
public:
MatmulGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
MatmulGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
TransA(OperatorBase::Arg<bool>("TransA", false)),
TransB(OperatorBase::Arg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -17,13 +17,13 @@
namespace dragon {
template <class Context>
class PowOp: public Operator<Context> {
class PowOp final : public Operator<Context> {
public:
PowOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
scale(OperatorBase::GetSingleArg<float>("scale", 1.0)),
shift(OperatorBase::GetSingleArg<float>("shift", 0.0)),
power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
PowOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.0)),
shift(OperatorBase::Arg<float>("shift", 0.0)),
power(OperatorBase::Arg<float>("power", 1.0)) {
power_scale = power * scale;
}
USE_OPERATOR_FUNCTIONS;
......@@ -38,11 +38,11 @@ class PowOp: public Operator<Context> {
template <class Context>
class PowGradientOp final : public Operator<Context> {
public:
PowGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
scale(OperatorBase::GetSingleArg<float>("scale", 1.0)),
shift(OperatorBase::GetSingleArg<float>("shift", 0.0)),
power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
PowGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
scale(OperatorBase::Arg<float>("scale", 1.0)),
shift(OperatorBase::Arg<float>("shift", 0.0)),
power(OperatorBase::Arg<float>("power", 1.0)) {
power_scale = power * scale;
}
USE_OPERATOR_FUNCTIONS;
......
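Both Pow constructors cache power_scale = power * scale because, for y = (scale * x + shift)^power, the chain rule gives dy/dx = power * scale * (scale * x + shift)^(power - 1); precomputing the constant factor saves a multiply per element. A sketch of the presumed use in the gradient kernel, which is not part of this diff:

// Sketch, assuming elementwise buffers x, y, dy, dx of equal length:
// t     = scale * x[i] + shift;
// y[i]  = pow(t, power);
// dx[i] = power_scale * pow(t, power - 1) * dy[i];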
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class CompareOp final : public Operator<Context> {
public:
CompareOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
CompareOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
operation(OperatorBase::Arg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,16 +19,16 @@ namespace dragon {
template <class Context>
class ScanOp final : public Operator<Context> {
public:
ScanOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nsteps(OperatorBase::GetSingleArg<int>("nsteps", 0)),
step_type(OperatorBase::GetSingleArg<string>("step_type", "Static")),
step_tensor(OperatorBase::GetSingleArg<string>("step_tensor", "")),
nseqs(OperatorBase::GetSingleArg<int>("nseqs", 0)),
default_outputs(OperatorBase::GetRepeatedArg<string>("default_outputs")),
ScanOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)),
nsteps(OperatorBase::Arg<int>("nsteps", 0)),
step_type(OperatorBase::Arg<string>("step_type", "Static")),
step_tensor(OperatorBase::Arg<string>("step_tensor", "")),
nseqs(OperatorBase::Arg<int>("nseqs", 0)),
default_outputs(OperatorBase::Args<string>("default_outputs")),
nout((int)default_outputs.size()),
debug_mode(OperatorBase::GetSingleArg<bool>("debug_mode", false)) {
debug_mode(OperatorBase::Arg<bool>("debug_mode", false)) {
InitTemplate();
}
USE_OPERATOR_FUNCTIONS;
......@@ -52,14 +52,14 @@ class ScanOp final: public Operator<Context> {
template <class Context>
class ScanGradientOp final : public Operator<Context> {
public:
ScanGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nsteps(OperatorBase::GetSingleArg<int>("nsteps", 0)),
step_type(OperatorBase::GetSingleArg<string>("step_type", "Static")),
step_tensor(OperatorBase::GetSingleArg<string>("step_tensor", "")),
forward_inputs(OperatorBase::GetRepeatedArg<string>("inputs_name")),
forward_outputs(OperatorBase::GetRepeatedArg<string>("outputs_name")) {
ScanGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)),
nsteps(OperatorBase::Arg<int>("nsteps", 0)),
step_type(OperatorBase::Arg<string>("step_type", "Static")),
step_tensor(OperatorBase::Arg<string>("step_tensor", "")),
forward_inputs(OperatorBase::Args<string>("inputs_name")),
forward_outputs(OperatorBase::Args<string>("outputs_name")) {
// handle GO(x)
for (int i = 0; i < forward_outputs.size(); i++)
terms[forward_outputs[i] + "_grad"] = Input(i + (int)OutputSize()).name();
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, see,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class CTCLossOp final : public Operator<Context> {
public:
CTCLossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
LOG(FATAL) << "CTCLoss requires CuDNN support.";
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {}
};
template <class Context>
class CTCLossGradientOp final : public Operator<Context> {
public:
CTCLossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
};
#ifdef WITH_CUDNN
#if CUDNN_VERSION_MIN(7, 0, 0)
#include "utils/cudnn_device.h"
template <class Context>
class CuDNNCTCLossOp final : public Operator<Context> {
public:
CuDNNCTCLossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
blank_first(OperatorBase::Arg<bool>("blank_first", true)),
padding_mask(OperatorBase::Arg<int>("padding_mask", -1)) {
CUDNN_CHECK(cudnnCreateCTCLossDescriptor(&ctc_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&prob_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&grad_desc));
ctc_algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
}
USE_OPERATOR_FUNCTIONS;
~CuDNNCTCLossOp() {
CUDNN_CHECK(cudnnDestroyCTCLossDescriptor(ctc_desc));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(prob_desc));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(grad_desc));
}
void RunOnDevice() override;
template <typename T> void RunWithType();
void WrapIO();
protected:
bool blank_first;
TIndex padding_mask;
cudnnCTCLossAlgo_t ctc_algo;
cudnnCTCLossDescriptor_t ctc_desc;
cudnnTensorDescriptor_t prob_desc, grad_desc;
size_t workspace_size;
vector<int> packed_labels, label_lengths, input_lengths;
};
#endif
#endif // WITH_CUDNN
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
\ No newline at end of file
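A minimal sketch of how the descriptors declared above are typically driven through the cuDNN 7 CTC API. The cudnnHandle_t, the device buffers, and the workspace allocation come from context not shown in this diff, so treat the wiring below as an assumed shape of RunWithType, not its actual body:

// Sketch: assumes float data, a valid `handle`, device buffers
// `probs`, `grads`, `costs`, and a device workspace of at least
// `workspace_size` bytes.
CUDNN_CHECK(cudnnSetCTCLossDescriptor(ctc_desc, CUDNN_DATA_FLOAT));
CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize(
    handle, prob_desc, grad_desc,
    packed_labels.data(), label_lengths.data(), input_lengths.data(),
    ctc_algo, ctc_desc, &workspace_size));
CUDNN_CHECK(cudnnCTCLoss(
    handle, prob_desc, probs,
    packed_labels.data(), label_lengths.data(), input_lengths.data(),
    costs, grad_desc, grads,
    ctc_algo, ctc_desc, workspace, workspace_size));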
......@@ -17,11 +17,12 @@
namespace dragon {
template <class Context>
class L1LossOp : public Operator<Context> {
class L1LossOp final : public Operator<Context> {
public:
L1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
L1LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -35,9 +36,10 @@ class L1LossOp : public Operator<Context> {
template <class Context>
class L1LossGradientOp final : public Operator<Context> {
public:
L1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
L1LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -17,11 +17,12 @@
namespace dragon {
template <class Context>
class L2LossOp : public Operator<Context> {
class L2LossOp final : public Operator<Context> {
public:
L2LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
L2LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -35,9 +36,10 @@ class L2LossOp : public Operator<Context> {
template <class Context>
class L2LossGradientOp final : public Operator<Context> {
public:
L2LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
L2LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,9 +19,10 @@ namespace dragon {
template <class Context>
class SigmoidCrossEntropyOp final : public Operator<Context> {
public:
SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
SigmoidCrossEntropyOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
normalization(OperatorBase::Arg<string>(
"normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -35,9 +36,10 @@ class SigmoidCrossEntropyOp final : public Operator<Context> {
template <class Context>
class SigmoidCrossEntropyGradientOp final : public Operator<Context> {
public:
SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
SigmoidCrossEntropyGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
normalization(OperatorBase::Arg<string>(
"normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,10 +19,11 @@ namespace dragon {
template <class Context>
class SmoothL1LossOp final : public Operator<Context> {
public:
SmoothL1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
beta(OperatorBase::GetSingleArg<float>("beta", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
SmoothL1LossOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
beta(OperatorBase::Arg<float>("beta", 1.0)),
normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -37,10 +38,11 @@ class SmoothL1LossOp final : public Operator<Context> {
template <class Context>
class SmoothL1LossGradientOp final : public Operator<Context> {
public:
SmoothL1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
beta(OperatorBase::GetSingleArg<float>("beta", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
beta(OperatorBase::Arg<float>("beta", 1.0)),
normalization(OperatorBase::Arg<string>(
"normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,11 +19,11 @@ namespace dragon {
template <class Context>
class SoftmaxCrossEntropyOp final : public Operator<Context> {
public:
SoftmaxCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {
}
SoftmaxCrossEntropyOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::Arg<string>(
"normalization", "FULL")) {}
USE_OPERATOR_FUNCTIONS;
void SoftmaxRun();
......@@ -41,10 +41,11 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> {
template <class Context>
class SoftmaxCrossEntropyGradientOp final : public Operator<Context> {
public:
SoftmaxCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
SoftmaxCrossEntropyGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::Arg<string>(
"normalization", "FULL")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,15 +19,16 @@ namespace dragon {
template <class Context>
class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
public:
SparseSoftmaxCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {
vector<int> args = OperatorBase::GetRepeatedArg<int>("ignore_labels");
if (args.size()) {
ignore.Reshape(vector<TIndex>(1, args.size()));
int* ignore_data = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
SparseSoftmaxCrossEntropyOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::Arg<string>(
"normalization", "VALID")) {
vector<int> ignores = OperatorBase::Args<int>("ignore_labels");
if (ignores.size()) {
ignore.Reshape({ (TIndex)ignores.size() });
auto* Idata = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < ignores.size(); i++) Idata[i] = ignores[i];
}
}
USE_OPERATOR_FUNCTIONS;
......@@ -49,15 +50,16 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
template <class Context>
class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
public:
SparseSoftmaxCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {
vector<int> args = OperatorBase::GetRepeatedArg<int>("ignore_labels");
if (args.size()) {
ignore.Reshape(vector<TIndex>(1, args.size()));
int* ignore_data = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
SparseSoftmaxCrossEntropyGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::Arg<string>(
"normalization", "VALID")) {
vector<int> ignores = OperatorBase::Args<int>("ignore_labels");
if (ignores.size()) {
ignore.Reshape({ (TIndex)ignores.size() });
auto* Idata = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < ignores.size(); i++) Idata[i] = ignores[i];
}
}
USE_OPERATOR_FUNCTIONS;
......
......@@ -19,13 +19,14 @@ namespace dragon {
template <class Context>
class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Context> {
public:
SparseSoftmaxFocalLossOp(const OperatorDef& op_def, Workspace* ws)
: SparseSoftmaxCrossEntropyOp<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
alpha(OperatorBase::GetSingleArg<float>("alpha", 0.5)),
gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {
SparseSoftmaxFocalLossOp(const OperatorDef& def, Workspace* ws)
: SparseSoftmaxCrossEntropyOp<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::Arg<string>(
"normalization", "VALID")),
alpha(OperatorBase::Arg<float>("alpha", 0.5)),
gamma(OperatorBase::Arg<float>("gamma", 0.0)),
neg_id(OperatorBase::Arg<int>("neg_id", -1)) {
pos_alpha = alpha * 2.0;
neg_alpha = (1 - alpha) * 2.0;
}
......@@ -46,13 +47,14 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
template <class Context>
class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyGradientOp<Context> {
public:
SparseSoftmaxFocalLossGradientOp(const OperatorDef& op_def, Workspace* ws)
: SparseSoftmaxCrossEntropyGradientOp<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {}
SparseSoftmaxFocalLossGradientOp(const OperatorDef& def, Workspace* ws)
: SparseSoftmaxCrossEntropyGradientOp<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
normalization(OperatorBase::Arg<string>(
"normalization", "VALID")),
gamma(OperatorBase::Arg<float>("gamma", 0.0)),
eps(OperatorBase::Arg<float>("eps", float(1e-10))),
neg_id(OperatorBase::Arg<int>("neg_id", -1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
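For reference, the alpha/gamma arguments above parameterize the standard focal loss of Lin et al.:

FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

The constructor's pos_alpha = 2 * alpha and neg_alpha = 2 * (1 - alpha) read as a rescaling so that alpha = 0.5 recovers the unweighted cross entropy; the kernels that consume these constants are outside this diff, so take that reading as an assumption.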
......@@ -17,17 +17,17 @@
namespace dragon {
template <class Context>
class AccuracyOp final: public Operator<Context> {
class AccuracyOp final : public Operator<Context> {
public:
AccuracyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {
vector<int> args = OperatorBase::GetRepeatedArg<int>("ignore_labels");
if (args.size()) {
ignore_labels.Reshape(vector<TIndex>(1, args.size()));
int* ignore_data = ignore_labels.mutable_data<int, CPUContext>();
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
AccuracyOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
top_k(OperatorBase::Arg<int>("top_k", 1)),
axis(OperatorBase::Arg<int>("axis", 1)) {
vector<int> ignores = OperatorBase::Args<int>("ignore_labels");
if (ignores.size()) {
ignore.Reshape({ (TIndex)ignores.size() });
auto* Idata = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < ignores.size(); i++) Idata[i] = ignores[i];
}
}
USE_OPERATOR_FUNCTIONS;
......@@ -37,7 +37,7 @@ class AccuracyOp final: public Operator<Context> {
protected:
TIndex top_k, axis, outer_dim, inner_dim, num_classes;
Tensor ignore_labels;
Tensor ignore;
};
} // namespace dragon
......
......@@ -19,10 +19,10 @@ namespace dragon {
template <class Context>
class AsTypeOp final : public Operator<Context> {
public:
AsTypeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
dtype(OperatorBase::GetSingleArg<string>("dtype", "float32")),
inplace(OperatorBase::GetSingleArg<bool>("inplace", false)) {}
AsTypeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
dtype(OperatorBase::Arg<string>("dtype", "float32")),
inplace(OperatorBase::Arg<bool>("inplace", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class GradientGenerateOp final : public Operator<Context> {
public:
GradientGenerateOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
defaults(OperatorBase::GetRepeatedArg<float>("defaults")) {
GradientGenerateOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
defaults(OperatorBase::Args<float>("defaults")) {
CHECK_EQ(InputSize(), OutputSize());
CHECK_EQ(defaults.size(), OutputSize());
}
......@@ -37,8 +37,8 @@ class GradientGenerateOp final: public Operator<Context> {
template <class Context>
class GradientGatherOp final : public Operator<Context> {
public:
GradientGatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
GradientGatherOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
for (int i = 0; i < InputSize(); i++)
if (Input(i).name() != "ignore") indices.push_back(i);
}
......
......@@ -19,23 +19,23 @@ namespace dragon {
template <class Context>
class ImageDataOp final : public Operator<Context> {
public:
ImageDataOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")),
mean_values(OperatorBase::GetRepeatedArg<float>("mean_values")),
std_values(OperatorBase::GetRepeatedArg<float>("std_values")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
ImageDataOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
dtype(OperatorBase::Arg<string>("dtype", "FLOAT32")),
mean_values(OperatorBase::Args<float>("mean_values")),
std_values(OperatorBase::Args<float>("std_values")),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {
if (mean_values.size() > 0) {
CHECK_EQ((int)mean_values.size(), 3)
<< "The mean values should be a list with length 3.";
mean.Reshape(vector<TIndex>(1, 3));
mean.Reshape({ 3 });
for (int i = 0; i < 3; i++)
mean.mutable_data<float, CPUContext>()[i] = mean_values[i];
}
if (std_values.size() > 0) {
CHECK_EQ((int)std_values.size(), 3)
<< "The std values should be a list with length 3.";
std.Reshape(vector<TIndex>(1, 3));
std.Reshape({ 3 });
for (int i = 0; i < 3; i++)
std.mutable_data<float, CPUContext>()[i] = std_values[i];
}
......
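The mean/std tensors built above feed a per-channel standardization. A sketch of the presumed per-pixel transform, assuming NCHW layout and a 3-channel input (the actual kernel is not shown in this diff):

// Sketch: per-channel standardization, for c in {0, 1, 2}:
// y[n][c][h][w] = (x[n][c][h][w] - mean[c]) / std[c];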
......@@ -18,11 +18,11 @@
namespace dragon {
template <class Context>
class InitializeOp: public Operator<Context> {
class InitializeOp : public Operator<Context> {
public:
InitializeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) {
InitializeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
shape_desc(OperatorBase::Arg<string>("shape", "")) {
GET_ARGUMENTS_WITH_DESC(int, dims);
}
USE_OPERATOR_FUNCTIONS;
......@@ -38,11 +38,11 @@ class InitializeOp: public Operator<Context> {
template <class Context>
class FillOp final : public InitializeOp<Context> {
public:
FillOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
public:
FillOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
this->filler.set_type("constant");
this->filler.set_value(OperatorBase::GetSingleArg<float>("value", 0.0));
this->filler.set_value(OperatorBase::Arg<float>("value", 0.0));
}
USE_OPERATOR_FUNCTIONS;
};
......@@ -50,11 +50,11 @@ public:
template <class Context>
class RandomUniformOp final : public InitializeOp<Context> {
public:
RandomUniformOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
RandomUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
this->filler.set_type("uniform");
this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -1.0));
this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 1.0));
this->filler.set_low(OperatorBase::Arg<float>("low", -1.0));
this->filler.set_high(OperatorBase::Arg<float>("high", 1.0));
}
USE_OPERATOR_FUNCTIONS;
};
......@@ -62,11 +62,11 @@ public:
template <class Context>
class RandomNormalOp final : public InitializeOp<Context> {
public:
RandomNormalOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
RandomNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
this->filler.set_type("normal");
this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0));
this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0));
this->filler.set_mean(OperatorBase::Arg<float>("mean", 0.0));
this->filler.set_std(OperatorBase::Arg<float>("std", 1.0));
}
USE_OPERATOR_FUNCTIONS;
};
......@@ -74,11 +74,11 @@ public:
template <class Context>
class TruncatedNormalOp final : public InitializeOp<Context> {
public:
TruncatedNormalOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
TruncatedNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
this->filler.set_type("truncated_normal");
float mu = OperatorBase::GetSingleArg<float>("mean", 0.0);
float sigma = OperatorBase::GetSingleArg<float>("std", 1.0);
float mu = OperatorBase::Arg<float>("mean", 0.0);
float sigma = OperatorBase::Arg<float>("std", 1.0);
this->filler.set_mean(mu);
this->filler.set_std(sigma);
this->filler.set_low(mu - 2 * sigma);
......@@ -90,10 +90,10 @@ public:
template <class Context>
class GlorotUniformOp final : public InitializeOp<Context> {
public:
GlorotUniformOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
string mode = OperatorBase::GetSingleArg<string>("mode", "fan_in");
float scale = OperatorBase::GetSingleArg<float>("scale", 3.0);
GlorotUniformOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
string mode = OperatorBase::Arg<string>("mode", "fan_in");
float scale = OperatorBase::Arg<float>("scale", 3.0);
this->filler.set_type("xavier");
if (mode == "fan_avg") {
......@@ -111,10 +111,10 @@ public:
template <class Context>
class GlorotNormalOp final : public InitializeOp<Context> {
public:
GlorotNormalOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
string mode = OperatorBase::GetSingleArg<string>("mode", "fan_in");
float scale = OperatorBase::GetSingleArg<float>("scale", 2.0);
GlorotNormalOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
string mode = OperatorBase::Arg<string>("mode", "fan_in");
float scale = OperatorBase::Arg<float>("scale", 2.0);
this->filler.set_type("msra");
if (mode == "fan_avg") {
......
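The Glorot initializers above only select a filler type and a scale; under the Caffe-style filler convention these presumably expand to a fan-based variance rule, with fan taken from the mode argument ("fan_in", "fan_out", or their average for "fan_avg"). This is an assumption about the filler implementation rather than code from this commit:

// "xavier":  w ~ Uniform(-limit, limit),  limit = sqrt(scale / fan)
// "msra":    w ~ Normal(0, sigma^2),      sigma = sqrt(scale / fan)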
......@@ -23,7 +23,7 @@ namespace dragon {
template <class Context>
class RunOp : public Operator<Context> {
public:
RunOp(const OperatorDef& op_def, Workspace* ws);
RunOp(const OperatorDef& def, Workspace* ws);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -36,16 +36,16 @@ class RunOp : public Operator<Context> {
template <class Context>
class TemplateOp : public RunOp<Context> {
public:
TemplateOp(const OperatorDef& op_def, Workspace* ws)
: RunOp<Context>(op_def, ws) {}
TemplateOp(const OperatorDef& def, Workspace* ws)
: RunOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
};
template <class Context>
class TemplateGradientOp : public TemplateOp<Context> {
class TemplateGradientOp final : public TemplateOp<Context> {
public:
TemplateGradientOp(const OperatorDef& op_def, Workspace* ws)
: TemplateOp<Context>(op_def, ws) {}
TemplateGradientOp(const OperatorDef& def, Workspace* ws)
: TemplateOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -22,11 +22,11 @@ namespace dragon {
template <class Context>
class ModelMPIBase : public Operator<Context> {
public:
ModelMPIBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
comm((MPI_Comm)OperatorBase::GetSingleArg<int64_t>("comm", 0)),
group((MPI_Group)OperatorBase::GetSingleArg<int64_t>("group", 0)),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")) {
ModelMPIBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
comm((MPI_Comm)OperatorBase::Arg<int64_t>("comm", 0)),
group((MPI_Group)OperatorBase::Arg<int64_t>("group", 0)),
dtype(OperatorBase::Arg<string>("dtype", "FLOAT32")) {
if (comm == MPI_COMM_NULL) return;
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
......@@ -36,7 +36,7 @@ class ModelMPIBase : public Operator<Context> {
MPI_Group world_group;
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
int world_root = OperatorBase::GetSingleArg<int>("root", 0);
int world_root = OperatorBase::Arg<int>("root", 0);
MPI_Group_translate_ranks(world_group, 1, &world_root, group, &comm_root);
CHECK(comm_root != MPI_UNDEFINED) << "MPI root is not included in layer group.";
......
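The root translation above is the standard MPI idiom for mapping a world rank into a sub-communicator's group: MPI_Group_translate_ranks writes MPI_UNDEFINED into comm_root when world_root is not a member of `group`, which is exactly the condition the CHECK rejects. The contract, restated:

// MPI_Group_translate_ranks(world_group, 1, &world_root, group, &comm_root);
// comm_root == MPI_UNDEFINED  <=>  world_root is not a member of `group`.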
......@@ -21,8 +21,8 @@ namespace dragon {
template <class Context>
class MPIBroadcastOp final : public ModelMPIBase<Context> {
public:
MPIBroadcastOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {}
MPIBroadcastOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
......@@ -33,8 +33,8 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> {
template <class Context>
class MPIBroadcastGradientOp final : public ModelMPIBase<Context> {
public:
MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {}
MPIBroadcastGradientOp(const OperatorDef& def, Workspace* ws)
: ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
......
......@@ -21,8 +21,8 @@ namespace dragon {
template <class Context>
class MPIGatherOp final : public ModelMPIBase<Context> {
public:
MPIGatherOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {}
MPIGatherOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
......@@ -33,8 +33,8 @@ class MPIGatherOp final : public ModelMPIBase<Context> {
template <class Context>
class MPIGatherGradientOp final : public ModelMPIBase<Context> {
public:
MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {}
MPIGatherGradientOp(const OperatorDef& def, Workspace *ws)
: ModelMPIBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
......
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class ArangeOp final : public Operator<Context> {
public:
ArangeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")) {
ArangeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
dtype(OperatorBase::Arg<string>("dtype", "FLOAT32")) {
GET_ARGUMENT_WITH_DESC(int, start, 0);
GET_ARGUMENT_WITH_DESC(int, stop, 0);
GET_ARGUMENT_WITH_DESC(int, step, 1);
......
......@@ -19,12 +19,12 @@ namespace dragon {
template <class Context>
class ArgReduceOp final : public Operator<Context> {
public:
ArgReduceOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {}
ArgReduceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)),
operation(OperatorBase::Arg<string>("operation", "NONE")),
keep_dims(OperatorBase::Arg<bool>("keep_dims", false)),
top_k(OperatorBase::Arg<int>("top_k", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,16 +19,17 @@ namespace dragon {
template <class Context>
class ConcatOp : public Operator<Context> {
public:
ConcatOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
ConcatOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, outer_dim, inner_dim, x_concat_dim, y_concat_dim;
TIndex axis, outer_dim, inner_dim;
TIndex x_concat_dim, y_concat_dim;
TIndex x_offset, y_offset, concat_offset;
vector<TIndex> concat_dims;
};
......@@ -36,16 +37,17 @@ class ConcatOp : public Operator<Context> {
template <class Context>
class ConcatGradientOp : public Operator<Context> {
public:
ConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
ConcatGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, outer_dim, inner_dim, x_concat_dim, y_concat_dim;
TIndex axis, outer_dim, inner_dim;
TIndex x_concat_dim, y_concat_dim;
TIndex x_offset, y_offset, concat_offset;
vector<TIndex> concat_dims;
};
......
......@@ -17,14 +17,14 @@
namespace dragon {
template <class Context>
class CropOp: public Operator<Context> {
class CropOp final : public Operator<Context> {
public:
CropOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)),
offsets(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
CropOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
start_axis(OperatorBase::Arg<int>("start_axis", -1)),
offsets(OperatorBase::Args<int>("offsets")),
shape(OperatorBase::Args<int>("shape")),
shape_like(OperatorBase::Arg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, starts);
GET_ARGUMENTS_WITH_DESC(int, ends);
}
......
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class ExpandDimsOp final : public Operator<Context> {
public:
ExpandDimsOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {}
ExpandDimsOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,11 +19,11 @@ namespace dragon {
template <class Context>
class FlattenOp final : public Operator<Context> {
public:
FlattenOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
keep_axes(OperatorBase::GetSingleArg<int>("keep_axes", INT_MAX)) {}
FlattenOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)),
num_axes(OperatorBase::Arg<int>("num_axes", -1)),
keep_axes(OperatorBase::Arg<int>("keep_axes", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class GatherOp final : public Operator<Context> {
public:
GatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
GatherOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -35,10 +35,10 @@ class GatherOp final : public Operator<Context> {
template <class Context>
class GatherGradientOp final : public Operator<Context> {
public:
GatherGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {}
GatherGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)),
acc_grad(OperatorBase::Arg<bool>("acc_gradient", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,11 +19,11 @@ namespace dragon {
template <class Context>
class OneHotOp final : public Operator<Context> {
public:
OneHotOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
depth(OperatorBase::GetSingleArg<int>("depth", -1)),
on_value(OperatorBase::GetSingleArg<int>("on_value", 1)),
off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {}
OneHotOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
depth(OperatorBase::Arg<int>("depth", -1)),
on_value(OperatorBase::Arg<int>("on_value", 1)),
off_value(OperatorBase::Arg<int>("off_value", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,12 +19,12 @@ namespace dragon {
template <class Context>
class PadOp final : public Operator<Context> {
public:
PadOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
pad_l(OperatorBase::GetRepeatedArg<int>("pad_l")),
pad_r(OperatorBase::GetRepeatedArg<int>("pad_r")),
mode(OperatorBase::GetSingleArg<string>("mode", "CONSTANT")),
value(OperatorBase::GetSingleArg<float>("value", 0.0f)) {
PadOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
pad_l(OperatorBase::Args<int>("pad_l")),
pad_r(OperatorBase::Args<int>("pad_r")),
mode(OperatorBase::Arg<string>("mode", "CONSTANT")),
value(OperatorBase::Arg<float>("value", 0.0f)) {
if (pad_r.size() == 0) pad_r = pad_l;
else CHECK_EQ(pad_l.size(), pad_r.size())
<< "The pad_l and pad_r should have the same length.";
......@@ -54,11 +54,11 @@ class PadOp final : public Operator<Context> {
template <class Context>
class PadGradientOp final : public Operator<Context> {
public:
PadGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
pad_l(OperatorBase::GetRepeatedArg<int>("pad_l")),
pad_r(OperatorBase::GetRepeatedArg<int>("pad_r")),
mode(OperatorBase::GetSingleArg<string>("mode", "CONSTANT")) {
PadGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
pad_l(OperatorBase::Args<int>("pad_l")),
pad_r(OperatorBase::Args<int>("pad_r")),
mode(OperatorBase::Arg<string>("mode", "CONSTANT")) {
if (pad_r.size() == 0) pad_r = pad_l;
else CHECK_EQ(pad_l.size(), pad_r.size())
<< "The pad_l and pad_r should have the same length.";
......
......@@ -17,20 +17,20 @@
namespace dragon {
template <class Context>
class RandomPickOp : public Operator<Context> {
class RandomPickOp final : public Operator<Context> {
public:
RandomPickOp(const OperatorDef& op_def, Workspace* ws) :
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
max_samples(OperatorBase::GetSingleArg<int>("max_samples", 1)) {}
RandomPickOp(const OperatorDef& def, Workspace* ws) :
Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)),
max_samples(OperatorBase::Arg<int>("max_samples", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, max_samples;
TIndex outer_dim, inner_dim, x_slice_dim, y_slice_dim;
TIndex axis, outer_dim, inner_dim, max_samples;
TIndex x_slice_dim, y_slice_dim;
vector<TIndex> output_dims;
Tensor* pick_indices;
};
......@@ -38,17 +38,17 @@ class RandomPickOp : public Operator<Context> {
template <class Context>
class RandomPickGradientOp final : public Operator<Context> {
public:
RandomPickGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
RandomPickGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis;
TIndex outer_dim, inner_dim, x_slice_dim, y_slice_dim;
TIndex axis, outer_dim, inner_dim;
TIndex x_slice_dim, y_slice_dim;
Tensor* pick_indices;
};
......
......@@ -19,11 +19,11 @@ namespace dragon {
template <class Context>
class ReduceOp final : public Operator<Context> {
public:
ReduceOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)) {}
ReduceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)),
operation(OperatorBase::Arg<string>("operation", "NONE")),
keep_dims(OperatorBase::Arg<bool>("keep_dims", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -38,10 +38,10 @@ class ReduceOp final : public Operator<Context> {
template <class Context>
class ReduceGradientOp final : public Operator<Context> {
public:
ReduceGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
ReduceGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)),
operation(OperatorBase::Arg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -17,11 +17,11 @@
namespace dragon {
template <class Context>
class RepeatOp : public Operator<Context> {
class RepeatOp final : public Operator<Context> {
public:
RepeatOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
RepeatOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1);
}
USE_OPERATOR_FUNCTIONS;
......@@ -35,11 +35,11 @@ class RepeatOp : public Operator<Context> {
};
template <class Context>
class RepeatGradientOp : public Operator<Context> {
class RepeatGradientOp final : public Operator<Context> {
public:
RepeatGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
RepeatGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1);
}
USE_OPERATOR_FUNCTIONS;
......
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class ReshapeOp final : public Operator<Context> {
public:
ReshapeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) {
ReshapeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
shape_like_desc(OperatorBase::Arg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, shape);
}
USE_OPERATOR_FUNCTIONS;
......
......@@ -17,12 +17,12 @@
namespace dragon {
template <class Context>
class SliceOp : public Operator<Context> {
class SliceOp final : public Operator<Context> {
public:
SliceOp(const OperatorDef& op_def, Workspace* ws):
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
SliceOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
nout(OperatorBase::Arg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -38,10 +38,10 @@ class SliceOp : public Operator<Context> {
template <class Context>
class SliceGradientOp final : public Operator<Context> {
public:
SliceGradientOp(const OperatorDef& op_def, Workspace* ws):
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
SliceGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 1)),
nout(OperatorBase::Arg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -17,37 +17,37 @@
namespace dragon {
template <class Context>
class StackOp : public Operator<Context> {
class StackOp final : public Operator<Context> {
public:
StackOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
StackOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, nin, outer_dim, inner_dim, x_concat_dim, y_concat_dim;
TIndex axis, outer_dim, inner_dim;
TIndex x_concat_dim, y_concat_dim;
TIndex x_offset, y_offset, concat_offset;
vector<TIndex> stack_dims, concat_dims;
};
template <class Context>
class StackGradientOp : public Operator<Context> {
class StackGradientOp final : public Operator<Context> {
public:
StackGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
StackGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, nin, outer_dim, inner_dim, x_concat_dim, y_concat_dim;
TIndex axis, outer_dim, inner_dim;
TIndex x_concat_dim, y_concat_dim;
TIndex x_offset, y_offset, concat_offset;
vector<TIndex> concat_dims;
};
......
......@@ -17,10 +17,10 @@
namespace dragon {
template <class Context>
class TileOp : public Operator<Context> {
class TileOp final : public Operator<Context> {
public:
TileOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
TileOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples);
}
USE_OPERATOR_FUNCTIONS;
......@@ -35,10 +35,10 @@ class TileOp : public Operator<Context> {
};
template <class Context>
class TileGradientOp : public Operator<Context> {
class TileGradientOp final : public Operator<Context> {
public:
TileGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
TileGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples);
}
USE_OPERATOR_FUNCTIONS;
......
......@@ -19,11 +19,9 @@ namespace dragon {
template <class Context>
class TransposeOp final : public Operator<Context> {
public:
TransposeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
perms(OperatorBase::GetRepeatedArg<int>("perms")) {
if (perms.size() > 0) reverse_dims = false;
else reverse_dims = true;
TransposeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int, perms);
}
USE_OPERATOR_FUNCTIONS;
......@@ -31,17 +29,17 @@ class TransposeOp final: public Operator<Context> {
template <typename T> void RunWithType();
protected:
vector<int> perms;
bool reverse_dims;
DECLARE_ARGUMENTS_WITH_DESC(int, perms);
Tensor* order, *old_steps, *new_steps;
};
DEFINE_ARGUMENTS_WITH_DESC(int, TransposeOp, perms);
template <class Context>
class TransposeGradientOp final : public Operator<Context> {
public:
TransposeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
TransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
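TransposeOp now fetches perms through the *_WITH_DESC machinery rather than a plain repeated argument. Judging from its use here and in Arange, Repeat, and Tile above, the macros appear to declare both a literal vector and a descriptor string so a value can be given inline or resolved from a named tensor at run time; a sketch of that presumed expansion (illustrative only, not the macro's actual text):

// DECLARE_ARGUMENTS_WITH_DESC(int, perms) -- assumed shape:
//   vector<int>    perms;       // literal values, if provided
//   vector<string> perms_desc;  // tensor names to resolve at run time
// GET_ARGUMENTS_WITH_DESC(int, perms) then fills both from the OperatorDef,
// and DEFINE_ARGUMENTS_WITH_DESC emits the accessor that prefers the
// runtime tensors whenever descriptors are present.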
......@@ -19,15 +19,15 @@
namespace dragon {
template <class Context>
class BatchNormOp : public Operator<Context> {
class BatchNormOp final : public Operator<Context> {
public:
BatchNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", 0.9f)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) {
BatchNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)),
momentum(OperatorBase::Arg<float>("momentum", 0.9f)),
eps(OperatorBase::Arg<float>("eps", 1e-3f)),
use_stats(OperatorBase::Arg<int>("use_stats", -1)),
mode(OperatorBase::Arg<string>("mode", "DEFAULT")) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
......@@ -51,10 +51,10 @@ class BatchNormOp : public Operator<Context> {
template <class Context>
class BatchNormGradientOp final : public Operator<Context> {
public:
BatchNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {
BatchNormGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)),
use_stats(OperatorBase::Arg<int>("use_stats", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
......@@ -77,12 +77,12 @@ class BatchNormGradientOp final : public Operator<Context> {
template <class Context>
class FusedBatchNormOp : public Operator<Context> {
public:
FusedBatchNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", 0.9f)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
FusedBatchNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)),
momentum(OperatorBase::Arg<float>("momentum", 0.9f)),
eps(OperatorBase::Arg<float>("eps", 1e-3f)),
use_stats(OperatorBase::Arg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -102,11 +102,11 @@ class FusedBatchNormOp : public Operator<Context> {
template <class Context>
class FusedBatchNormGradientOp : public Operator<Context> {
public:
FusedBatchNormGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
FusedBatchNormGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)),
eps(OperatorBase::Arg<float>("eps", 1e-3f)),
use_stats(OperatorBase::Arg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -132,9 +132,9 @@ class FusedBatchNormGradientOp : public Operator<Context> {
template <class Context>
class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
public:
CuDNNBatchNormOp(const OperatorDef& op_def, Workspace* ws)
: FusedBatchNormOp<Context>(op_def, ws),
eps64(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {
CuDNNBatchNormOp(const OperatorDef& def, Workspace* ws)
: FusedBatchNormOp<Context>(def, ws),
eps64(OperatorBase::Arg<float>("eps", 1e-3f)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
......@@ -169,9 +169,9 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
template <class Context>
class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context> {
public:
CuDNNBatchNormGradientOp(const OperatorDef& op_def, Workspace* ws)
: FusedBatchNormGradientOp<Context>(op_def, ws),
eps64(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {
CuDNNBatchNormGradientOp(const OperatorDef& def, Workspace* ws)
: FusedBatchNormGradientOp<Context>(def, ws),
eps64(OperatorBase::Arg<float>("eps", 1e-3f)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
......
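A sketch of the descriptor setup the CuDNN batch-norm pair implies: cuDNN derives the scale/bias/mean/var descriptor from the input descriptor and requires epsilon to be at least CUDNN_BN_MIN_EPSILON as a double, which is presumably why eps is widened into the separate eps64 member. The mode choice and the clamp below are assumptions about Setup(), which is not part of this diff:

// Sketch, assuming `input_desc` already describes an NCHW input.
CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(
    bn_desc, input_desc, CUDNN_BATCHNORM_SPATIAL));
if (eps64 < CUDNN_BN_MIN_EPSILON) eps64 = CUDNN_BN_MIN_EPSILON;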
......@@ -17,19 +17,19 @@
namespace dragon {
template <class Context>
class BatchRenormOp : public Operator<Context> {
class BatchRenormOp final : public Operator<Context> {
public:
BatchRenormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", 0.9f)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)),
r_max(OperatorBase::GetSingleArg<float>("r_max", 3.f)),
d_max(OperatorBase::GetSingleArg<float>("d_max", 5.f)),
t_delta(OperatorBase::GetSingleArg<float>("t_delta", 1.f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
BatchRenormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)),
momentum(OperatorBase::Arg<float>("momentum", 0.9f)),
eps(OperatorBase::Arg<float>("eps", 1e-3f)),
r_max(OperatorBase::Arg<float>("r_max", 3.f)),
d_max(OperatorBase::Arg<float>("d_max", 5.f)),
t_delta(OperatorBase::Arg<float>("t_delta", 1.f)),
use_stats(OperatorBase::Arg<int>("use_stats", -1)),
t_r_max(1.f), t_d_max(0.f), t_val(0.f),
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) {
mode(OperatorBase::Arg<string>("mode", "DEFAULT")) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
......@@ -55,10 +55,10 @@ class BatchRenormOp : public Operator<Context> {
template <class Context>
class BatchRenormGradientOp final : public Operator<Context> {
public:
BatchRenormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {
BatchRenormGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)),
use_stats(OperatorBase::Arg<int>("use_stats", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
......
......@@ -17,13 +17,13 @@
namespace dragon {
template <class Context>
class GroupNormOp : public Operator<Context> {
class GroupNormOp final : public Operator<Context> {
public:
GroupNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {
GroupNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
group(OperatorBase::Arg<int>("group", 32)),
axis(OperatorBase::Arg<int>("axis", -1)),
eps(OperatorBase::Arg<float>("eps", 1e-3f)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
......@@ -45,10 +45,10 @@ class GroupNormOp : public Operator<Context> {
template <class Context>
class GroupNormGradientOp final : public Operator<Context> {
public:
GroupNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
GroupNormGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
group(OperatorBase::Arg<int>("group", 32)),
axis(OperatorBase::Arg<int>("axis", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
......@@ -67,13 +67,13 @@ class GroupNormGradientOp final : public Operator<Context> {
};
template <class Context>
class FusedGroupNormOp : public Operator<Context> {
class FusedGroupNormOp final : public Operator<Context> {
public:
FusedGroupNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {}
FusedGroupNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
group(OperatorBase::Arg<int>("group", 32)),
axis(OperatorBase::Arg<int>("axis", -1)),
eps(OperatorBase::Arg<float>("eps", 1e-3f)) {}
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -89,12 +89,12 @@ class FusedGroupNormOp : public Operator<Context> {
};
template <class Context>
class FusedGroupNormGradientOp : public Operator<Context> {
class FusedGroupNormGradientOp final : public Operator<Context> {
public:
FusedGroupNormGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {}
FusedGroupNormGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
group(OperatorBase::Arg<int>("group", 32)),
axis(OperatorBase::Arg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS;
void Setup();
......
......@@ -17,12 +17,12 @@
namespace dragon {
template <class Context>
class InstanceNormOp : public Operator<Context> {
class InstanceNormOp final : public Operator<Context> {
public:
InstanceNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {
InstanceNormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)),
eps(OperatorBase::Arg<float>("eps", 1e-3f)) {
if (axis != -1)
CHECK_EQ(axis, 1) << "\nThe axis can only be set to 1.";
}
......@@ -43,9 +43,9 @@ class InstanceNormOp : public Operator<Context> {
template <class Context>
class InstanceNormGradientOp final : public Operator<Context> {
public:
InstanceNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
InstanceNormGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1) << "\nThe axis can only be set to 1.";
}
......
......@@ -19,12 +19,12 @@ namespace dragon {
template <class Context>
class L2NormOp final : public Operator<Context> {
public:
L2NormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-5f)),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
L2NormOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)),
num_axes(OperatorBase::Arg<int>("num_axes", -1)),
eps(OperatorBase::Arg<float>("eps", 1e-5f)),
mode(OperatorBase::Arg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -42,11 +42,11 @@ class L2NormOp final : public Operator<Context> {
template <class Context>
class L2NormGradientOp final : public Operator<Context> {
public:
L2NormGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
L2NormGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int>("axis", 0)),
num_axes(OperatorBase::Arg<int>("num_axes", -1)),
mode(OperatorBase::Arg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -49,29 +49,29 @@ class cudnnTensorDescriptors {
template <class Context>
class CuDNNRecurrentOpBase : public Operator<Context> {
public:
CuDNNRecurrentOpBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), states_initialized(false),
hidden_size(OperatorBase::GetSingleArg<int>("hidden_size", 0)),
num_layers(OperatorBase::GetSingleArg<int>("num_layers", 1)),
bidirectional(OperatorBase::GetSingleArg<bool>("bidirectional", false)),
dropout_ratio(OperatorBase::GetSingleArg<float>("dropout_ratio", 1.0)),
random_seed(op_def.device_option().random_seed()) {
CuDNNRecurrentOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), states_initialized(false),
hidden_size(OperatorBase::Arg<int>("hidden_size", 0)),
num_layers(OperatorBase::Arg<int>("num_layers", 1)),
bidirectional(OperatorBase::Arg<bool>("bidirectional", false)),
dropout_ratio(OperatorBase::Arg<float>("dropout_ratio", 1.0)),
random_seed(def.device_option().random_seed()) {
// determine the rnn direction
rnn_direction = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
// determine the rnn mode
const string mode = OperatorBase::GetSingleArg<string>("rnn_mode", "");
const string mode = OperatorBase::Arg<string>("rnn_mode", "");
if (mode == "rnn_tanh") rnn_mode = CUDNN_RNN_TANH;
else if (mode == "rnn_relu") rnn_mode = CUDNN_RNN_RELU;
else if (mode == "lstm") rnn_mode = CUDNN_LSTM;
else if (mode == "gru") rnn_mode = CUDNN_GRU;
else LOG(FATAL) << "Unsupported rnn mode: " << mode;
// determine the rnn input mode
const string input_mode = OperatorBase::GetSingleArg<string>("rnn_input_mode", "linear");
const string input_mode = OperatorBase::Arg<string>("rnn_input_mode", "linear");
if (input_mode == "skip") rnn_input_mode = CUDNN_SKIP_INPUT;
else if (input_mode == "linear") rnn_input_mode = CUDNN_LINEAR_INPUT;
else LOG(FATAL) << "Unsupported rnn input mode: " << input_mode;
// override the running phase
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
SwitchToPhase(OperatorBase::Arg<string>("phase", ""));
CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc));
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
CUDNN_CHECK(cudnnCreateFilterDescriptor(&w_desc));
......@@ -92,8 +92,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc));
}
template <typename T> void ResetDesc(Tensor* X, Tensor* Hx, Tensor* Cx,
Tensor* Y, Tensor* Hy, Tensor* Cy);
template <typename T> void ResetDesc();
public:
TIndex hidden_size, num_layers;
......@@ -109,7 +108,7 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
cudnnFilterDescriptor_t w_desc;
cudnnTensorDescriptor_t hx_desc, cx_desc;
cudnnTensorDescriptor_t hy_desc, cy_desc;
vector<TIndex> input_dims;
vector<TIndex> input_dims, output_dims, hidden_dims;
size_t workspace_size, reserve_size, states_size;
std::unique_ptr<cudnnTensorDescriptors> xs_desc;
......@@ -128,14 +127,16 @@ class CuDNNRecurrentOpBase : public Operator<Context> {
using CuDNNRecurrentOpBase<Context>::xs_desc; \
using CuDNNRecurrentOpBase<Context>::ys_desc; \
using CuDNNRecurrentOpBase<Context>::input_dims; \
using CuDNNRecurrentOpBase<Context>::output_dims; \
using CuDNNRecurrentOpBase<Context>::hidden_dims; \
using CuDNNRecurrentOpBase<Context>::workspace_size; \
using CuDNNRecurrentOpBase<Context>::reserve_size
template <class Context>
class CuDNNRecurrentOp : public CuDNNRecurrentOpBase<Context> {
class CuDNNRecurrentOp final : public CuDNNRecurrentOpBase<Context> {
public:
CuDNNRecurrentOp(const OperatorDef& op_def, Workspace* ws)
: CuDNNRecurrentOpBase<Context>(op_def, ws) {}
CuDNNRecurrentOp(const OperatorDef& def, Workspace* ws)
: CuDNNRecurrentOpBase<Context>(def, ws) {}
USE_CUDNN_RECURRENT_FUNCTIONS;
void RunOnDevice() override;
......@@ -143,10 +144,10 @@ class CuDNNRecurrentOp : public CuDNNRecurrentOpBase<Context> {
};
template <class Context>
class CuDNNRecurrentGradientOp : public CuDNNRecurrentOpBase<Context> {
class CuDNNRecurrentGradientOp final : public CuDNNRecurrentOpBase<Context> {
public:
CuDNNRecurrentGradientOp(const OperatorDef& op_def, Workspace* ws)
: CuDNNRecurrentOpBase<Context>(op_def, ws) {}
CuDNNRecurrentGradientOp(const OperatorDef& def, Workspace* ws)
: CuDNNRecurrentOpBase<Context>(def, ws) {}
USE_CUDNN_RECURRENT_FUNCTIONS;
void RunOnDevice() override;
......
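The string dispatch in CuDNNRecurrentOpBase's constructor, factored into a standalone helper for reference. This is a sketch assuming only cudnn.h and the same mode strings used above:

#include <stdexcept>
#include <string>
#include <cudnn.h>

// Mirrors the rnn_mode dispatch in the constructor above.
inline cudnnRNNMode_t ParseRNNMode(const std::string& mode) {
    if (mode == "rnn_tanh") return CUDNN_RNN_TANH;
    if (mode == "rnn_relu") return CUDNN_RNN_RELU;
    if (mode == "lstm")     return CUDNN_LSTM;
    if (mode == "gru")      return CUDNN_GRU;
    throw std::invalid_argument("Unsupported rnn mode: " + mode);
}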
......@@ -17,10 +17,10 @@
namespace dragon {
template <class Context>
class LSTMCellOp : public Operator<Context> {
class LSTMCellOp final : public Operator<Context> {
public:
LSTMCellOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
LSTMCellOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -29,10 +29,10 @@ class LSTMCellOp : public Operator<Context> {
};
template <class Context>
class LSTMCellGradientOp : public Operator<Context> {
class LSTMCellGradientOp final : public Operator<Context> {
public:
LSTMCellGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
LSTMCellGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -19,8 +19,8 @@ namespace dragon {
template <class Context>
class RecurrentOp : public Operator<Context> {
public:
RecurrentOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
RecurrentOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
LOG(FATAL) << "RNN Operators require CuDNN support.";
}
USE_OPERATOR_FUNCTIONS;
......@@ -31,8 +31,8 @@ class RecurrentOp : public Operator<Context> {
template <class Context>
class RecurrentGradientOp : public Operator<Context> {
public:
RecurrentGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
RecurrentGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
LOG(FATAL) << "RNN Operators require CuDNN support.";
}
USE_OPERATOR_FUNCTIONS;
......
......@@ -17,18 +17,18 @@
namespace dragon {
template <class Context>
class RNNParamSetOp : public Operator<Context> {
class RNNParamSetOp final : public Operator<Context> {
public:
RNNParamSetOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
param_type(OperatorBase::GetSingleArg<string>("param_type", "matrix")),
rnn_mode(OperatorBase::GetSingleArg<string>("rnn_mode", "rnn_tanh")),
num_layers(OperatorBase::GetSingleArg<int>("num_layers", 1)),
num_directions(OperatorBase::GetSingleArg<int>("num_directions", 1)),
input_size(OperatorBase::GetSingleArg<int>("input_size", 0)),
hidden_size(OperatorBase::GetSingleArg<int>("hidden_size", 0)),
layer_id(OperatorBase::GetSingleArg<int>("layer_id", 0)),
param_id(OperatorBase::GetSingleArg<int>("param_id", 0)) {
RNNParamSetOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
param_type(OperatorBase::Arg<string>("param_type", "matrix")),
rnn_mode(OperatorBase::Arg<string>("rnn_mode", "rnn_tanh")),
num_layers(OperatorBase::Arg<int>("num_layers", 1)),
num_directions(OperatorBase::Arg<int>("num_directions", 1)),
input_size(OperatorBase::Arg<int>("input_size", 0)),
hidden_size(OperatorBase::Arg<int>("hidden_size", 0)),
layer_id(OperatorBase::Arg<int>("layer_id", 0)),
param_id(OperatorBase::Arg<int>("param_id", 0)) {
if (rnn_mode == "rnn_tanh") { num_params = 2; spliter = 1; }
else if (rnn_mode == "rnn_relu") { num_params = 2; spliter = 1; }
else if (rnn_mode == "lstm") { num_params = 8; spliter = 4; }
......
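The num_params/spliter bookkeeping above, condensed for reference. The `gru` branch is truncated out of this hunk, so its 6/3 layout is an assumption based on cuDNN's three-gate GRU:

#include <string>
#include <utility>

// Returns {num_params, spliter}: tensors per layer and gate grouping.
inline std::pair<int, int> RNNParamLayout(const std::string& rnn_mode) {
    if (rnn_mode == "rnn_tanh" || rnn_mode == "rnn_relu") return { 2, 1 };
    if (rnn_mode == "lstm") return { 8, 4 };
    if (rnn_mode == "gru")  return { 6, 3 };  // assumed, not shown in hunk
    return { 0, 0 };  // unknown mode
}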
......@@ -19,8 +19,8 @@ namespace dragon {
template <class Context>
class AdamUpdateOp final : public UpdateOpBase<Context> {
public:
AdamUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws), t(0) {}
AdamUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws), t(0) {}
USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context);
......
......@@ -19,16 +19,26 @@ namespace dragon {
#ifdef WITH_MPI
template <class Context>
class CollectiveUpdateOp : public Operator<Context> {
class CollectiveUpdateOp final : public Operator<Context> {
public:
CollectiveUpdateOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
mode(OperatorBase::GetSingleArg<string>("mode", "UNKNOWN")) {
CollectiveUpdateOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
mode(OperatorBase::Arg<string>("mode", "UNKNOWN")) {
InitMPI();
if (mode.find("NCCL") != string::npos) InitNCCL();
}
USE_OPERATOR_FUNCTIONS;
~CollectiveUpdateOp() {
/* TODO(PhyscalX): Temporarily disable it,
to avoid an unhandled error. */
#ifdef WITH_MPI_NCCL
if (mode.find("NCCL") != string::npos) {
/* ncclCommDestroy(nccl_comm); */
}
#endif
}
void InitMPI();
void InitNCCL();
......@@ -48,7 +58,7 @@ class CollectiveUpdateOp : public Operator<Context> {
#ifdef WITH_MPI_NCCL
ncclComm_t nccl_comm;
cudaStream_t stream;
CUDAClosure<Context> closure;
#endif
};
......
......@@ -19,9 +19,9 @@ namespace dragon {
template <class Context>
class MovingAverageOp final : public Operator<Context> {
public:
MovingAverageOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
decay(OperatorBase::GetSingleArg<float>("decay", 1.0)) {}
MovingAverageOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
decay(OperatorBase::Arg<float>("decay", 1.0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -29,7 +29,6 @@ class MovingAverageOp final : public Operator<Context> {
protected:
float decay;
};
} // namespace dragon
......
......@@ -19,8 +19,8 @@ namespace dragon {
template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> {
public:
NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws) {}
NesterovUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context);
......
......@@ -19,8 +19,8 @@ namespace dragon {
template <class Context>
class RMSPropUpdateOp final : public UpdateOpBase<Context> {
public:
RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws) {}
RMSPropUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context);
......
......@@ -19,8 +19,8 @@ namespace dragon {
template <class Context>
class SGDUpdateOp final : public UpdateOpBase<Context> {
public:
SGDUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
SGDUpdateOp(const OperatorDef& def, Workspace* ws)
: UpdateOpBase<Context>(def, ws),
old_lr(-1.f), correction(1.f) {}
USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context);
......
......@@ -19,12 +19,12 @@ namespace dragon {
template <class Context>
class UpdateOpBase : public Operator<Context> {
public:
UpdateOpBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
lr_mult(OperatorBase::GetSingleArg<float>("lr_mult", 1.0)),
decay_mult(OperatorBase::GetSingleArg<float>("decay_mult", 1.0)),
slot(OperatorBase::GetSingleArg<string>("slot", "")),
zero_grad(OperatorBase::GetSingleArg<bool>("zero_grad", true)) {
UpdateOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
lr_mult(OperatorBase::Arg<float>("lr_mult", 1.0)),
decay_mult(OperatorBase::Arg<float>("decay_mult", 1.0)),
slot(OperatorBase::Arg<string>("slot", "")),
zero_grad(OperatorBase::Arg<bool>("zero_grad", true)) {
CHECK(!slot.empty()) << "\nA non-empty slot is required";
}
USE_OPERATOR_FUNCTIONS;
......@@ -34,8 +34,13 @@ class UpdateOpBase : public Operator<Context> {
void RunOnDevice() override;
template <typename T> void PreprocessRunWithType();
virtual void ComputeRunWithFloat() = 0;
virtual void ComputeRunWithFloat16() { LOG(FATAL) << "This Updater does not support FP16."; }
virtual void ComputeRunWithFloat16() {
LOG(FATAL) << "This Updater does not support FP16.";
}
template <typename T> void UpdateRunWithType();
protected:
......
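How the multipliers taken above are conventionally consumed. The actual PreprocessRunWithType() body is not in this diff, so treat this as a sketch of the usual Caffe-style convention rather than this repo's code:

// Per-parameter hyper-parameters derived from the base values:
inline float EffectiveLR(float base_lr, float lr_mult) {
    return base_lr * lr_mult;
}
inline float EffectiveDecay(float base_decay, float decay_mult) {
    return base_decay * decay_mult;
}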
......@@ -17,11 +17,11 @@
namespace dragon {
template <class Context>
class BiasAddOp : public Operator<Context> {
class BiasAddOp final : public Operator<Context> {
public:
BiasAddOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
BiasAddOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -35,9 +35,9 @@ class BiasAddOp : public Operator<Context> {
template <class Context>
class BiasAddGradientOp final : public Operator<Context> {
public:
BiasAddGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
BiasAddGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -17,14 +17,14 @@
namespace dragon {
template <class Context>
class BilinearResizeOp : public Operator<Context> {
class BilinearResizeOp final : public Operator<Context> {
public:
BilinearResizeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)),
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
BilinearResizeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
fy(OperatorBase::Arg<float>("fy", -1.0)),
fx(OperatorBase::Arg<float>("fx", -1.0)),
shape_like_desc(OperatorBase::Arg<string>("shape_like", "")),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {
GET_ARGUMENTS_WITH_DESC(int, dsize);
if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1;
......@@ -43,11 +43,11 @@ class BilinearResizeOp : public Operator<Context> {
};
template <class Context>
class BilinearResizeGradientOp : public Operator<Context> {
class BilinearResizeGradientOp final : public Operator<Context> {
public:
BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
BilinearResizeGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
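The data_format branch above in one place, for reference. This sketch assumes only the two layouts the constructor accepts (anything else is fatal there):

#include <string>

// NCHW lays out [N, C, H, W], so spatial dims start at axis 2;
// NHWC lays out [N, H, W, C], so they start at axis 1.
inline int SpatialAxis(const std::string& data_format) {
    return data_format == "NCHW" ? 2 : 1;
}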
......@@ -53,10 +53,11 @@ class Conv2dGradientOp : public Conv2dOp<Context> {
#include "utils/cudnn_device.h"
template <class Context>
class CuDNNConv2dOp : public Conv2dOp<Context> {
class CuDNNConv2dOp final : public Conv2dOp<Context> {
public:
CuDNNConv2dOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws), enable_tensor_core(true) {
: Conv2dOp<Context>(def, ws),
enable_tensor_core(true) {
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE();
......@@ -64,14 +65,6 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
cudnn_group = group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group];
stream = new cudaStream_t[cudnn_group];
ctx().SwitchToDevice();
for (int g = 0; g < cudnn_group; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
}
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
......@@ -90,10 +83,6 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
for (int g = 0; g < cudnn_group; g++) {
cudaStreamDestroy(stream[g]);
CUDNN_CHECK(cudnnDestroy(handle[g]));
}
}
void RunOnDevice() override;
......@@ -101,8 +90,6 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
template <typename T> void RunWithType();
protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnDataType_t compute_type;
cudnnTensorFormat_t format;
cudnnConvolutionFwdAlgo_t fwd_algo;
......@@ -116,10 +103,11 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
};
template <class Context>
class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
class CuDNNConv2dGradientOp final : public Conv2dGradientOp<Context> {
public:
CuDNNConv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dGradientOp<Context>(def, ws), enable_tensor_core(true) {
: Conv2dGradientOp<Context>(def, ws),
enable_tensor_core(true) {
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE();
......@@ -127,13 +115,6 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
cudnn_group = group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group * 3];
stream = new cudaStream_t[cudnn_group * 3];
for (int g = 0; g < cudnn_group * 3; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
}
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
......@@ -152,10 +133,6 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
for (int g = 0; g < cudnn_group * 3; g++) {
cudaStreamDestroy(stream[g]);
CUDNN_CHECK(cudnnDestroy(handle[g]));
}
}
void RunOnDevice() override;
......@@ -163,8 +140,6 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
template <typename T> void RunWithType();
protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnDataType_t compute_type;
cudnnTensorFormat_t format;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
......
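A note on the deletions above: the per-group cudnnHandle_t/cudaStream_t arrays are gone from both conv ops. Consistent with this commit's Context/Stream refactor, the handle and stream are presumably supplied by the execution context now; the sketch below is an assumption, not code from this diff (ctx().cudnn_handle() and ctx().cuda_stream() are hypothetical names):

// Hypothetical post-refactor call site inside RunWithType():
// auto handle = ctx().cudnn_handle();                       // context-owned
// CUDNN_CHECK(cudnnSetStream(handle, ctx().cuda_stream())); // shared stream
// CUDNN_CHECK(cudnnConvolutionForward(handle, ...));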
......@@ -21,14 +21,14 @@ namespace dragon {
template <class Context>
class ConvOpBase : public Operator<Context> {
public:
ConvOpBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
num_output(OperatorBase::GetSingleArg<int>("num_output", 1)),
group(OperatorBase::GetSingleArg<int>("group", 1)) {
output_dims_value = OperatorBase::GetRepeatedArg<int>("output_shape");
output_dims_desc = OperatorBase::GetRepeatedArg<string>("output_shape_desc");
ConvOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")),
padding(OperatorBase::Arg<string>("padding", "VALID")),
num_output(OperatorBase::Arg<int>("num_output", 1)),
group(OperatorBase::Arg<int>("group", 1)) {
output_dims_value = OperatorBase::Args<int>("output_shape");
output_dims_desc = OperatorBase::Args<string>("output_shape_desc");
if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format;
......
......@@ -17,7 +17,7 @@
namespace dragon {
template <class Context>
class Conv2dTransposeOp: public ConvOpBase<Context> {
class Conv2dTransposeOp : public ConvOpBase<Context> {
public:
Conv2dTransposeOp(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) {
......@@ -57,7 +57,7 @@ class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> {
#include "utils/cudnn_device.h"
template <class Context>
class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
class CuDNNConv2dTransposeOp final : public Conv2dTransposeOp<Context> {
public:
CuDNNConv2dTransposeOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeOp<Context>(def, ws), enable_tensor_core(true) {
......@@ -68,13 +68,6 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
cudnn_group = group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group];
stream = new cudaStream_t[cudnn_group];
for (int g = 0; g < cudnn_group; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
}
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
......@@ -93,10 +86,6 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
for (int g = 0; g < cudnn_group; g++) {
cudaStreamDestroy(stream[g]);
CUDNN_CHECK(cudnnDestroy(handle[g]));
}
}
void RunOnDevice() override;
......@@ -104,8 +93,6 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
template <typename T> void RunWithType();
protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnDataType_t compute_type;
cudnnTensorFormat_t format;
cudnnConvolutionBwdDataAlgo_t fwd_algo;
......@@ -119,7 +106,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
};
template <class Context>
class CuDNNConv2dTransposeGradientOp : public Conv2dTransposeGradientOp<Context> {
class CuDNNConv2dTransposeGradientOp final : public Conv2dTransposeGradientOp<Context> {
public:
CuDNNConv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeGradientOp<Context>(def, ws), enable_tensor_core(true) {
......@@ -130,13 +117,6 @@ public:
cudnn_group = group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group * 3];
stream = new cudaStream_t[cudnn_group * 3];
for (int g = 0; g < cudnn_group * 3; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
}
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
......@@ -155,10 +135,6 @@ public:
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
for (int g = 0; g < cudnn_group * 3; g++) {
cudaStreamDestroy(stream[g]);
CUDNN_CHECK(cudnnDestroy(handle[g]));
}
}
void RunOnDevice() override;
......@@ -166,8 +142,6 @@ public:
template <typename T> void RunWithType();
protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnDataType_t compute_type;
cudnnTensorFormat_t format;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
......
......@@ -19,17 +19,17 @@ namespace dragon {
template <class Context>
class DenseConcatOp final : public ConcatOp<Context> {
public:
DenseConcatOp(const OperatorDef& op_def, Workspace* ws)
: ConcatOp<Context>(op_def, ws) {}
DenseConcatOp(const OperatorDef& def, Workspace* ws)
: ConcatOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS;
};
template <class Context>
class DenseConcatGradientOp : public ConcatGradientOp<Context> {
class DenseConcatGradientOp final : public ConcatGradientOp<Context> {
public:
DenseConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: ConcatGradientOp<Context>(op_def, ws),
growth_rate(OperatorBase::GetSingleArg<int>("growth_rate", 0)) {}
DenseConcatGradientOp(const OperatorDef& def, Workspace* ws)
: ConcatGradientOp<Context>(def, ws),
growth_rate(OperatorBase::Arg<int>("growth_rate", 0)) {}
USE_OPERATOR_FUNCTIONS;
void ElimateCorruption() override;
......
......@@ -21,14 +21,14 @@ enum LRNMode { ACROSS_CHANNELS, WITHIN_CHANNEL };
template <class Context>
class LRNOp : public Operator<Context> {
public:
LRNOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
local_size(OperatorBase::GetSingleArg<int>("local_size", 5)),
alpha(OperatorBase::GetSingleArg<float>("alpha", float(0.0001))),
beta(OperatorBase::GetSingleArg<float>("beta", float(0.75))),
k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
LRNOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
local_size(OperatorBase::Arg<int>("local_size", 5)),
alpha(OperatorBase::Arg<float>("alpha", float(0.0001))),
beta(OperatorBase::Arg<float>("beta", float(0.75))),
k(OperatorBase::Arg<float>("k", float(2.0))),
mode(OperatorBase::Arg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -52,14 +52,14 @@ class LRNOp : public Operator<Context> {
template <class Context>
class LRNGradientOp : public Operator<Context> {
public:
LRNGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
local_size(OperatorBase::GetSingleArg<int>("local_size", 5)),
alpha(OperatorBase::GetSingleArg<float>("alpha", float(0.0001))),
beta(OperatorBase::GetSingleArg<float>("beta", float(0.75))),
k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
LRNGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
local_size(OperatorBase::Arg<int>("local_size", 5)),
alpha(OperatorBase::Arg<float>("alpha", float(0.0001))),
beta(OperatorBase::Arg<float>("beta", float(0.75))),
k(OperatorBase::Arg<float>("k", float(2.0))),
mode(OperatorBase::Arg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -85,17 +85,15 @@ class LRNGradientOp : public Operator<Context> {
#include "utils/cudnn_device.h"
template <class Context>
class CuDNNLRNOp : public LRNOp<Context> {
class CuDNNLRNOp final : public LRNOp<Context> {
public:
CuDNNLRNOp(const OperatorDef& op_def, Workspace* ws)
: LRNOp<Context>(op_def, ws) {
CuDNNLRNOp(const OperatorDef& def, Workspace* ws)
: LRNOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc));
CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc, this->local_size,
this->alpha,
this->beta,
this->k));
CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc,
this->local_size, this->alpha, this->beta, this->k));
}
USE_OPERATOR_FUNCTIONS;
......@@ -114,17 +112,15 @@ class CuDNNLRNOp : public LRNOp<Context> {
};
template <class Context>
class CuDNNLRNGradientOp : public LRNGradientOp<Context > {
class CuDNNLRNGradientOp final : public LRNGradientOp<Context > {
public:
CuDNNLRNGradientOp(const OperatorDef& op_def, Workspace* ws) :
LRNGradientOp<Context>(op_def, ws) {
CuDNNLRNGradientOp(const OperatorDef& def, Workspace* ws)
: LRNGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc));
CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc, this->local_size,
this->alpha,
this->beta,
this->k));
CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc,
this->local_size, this->alpha, this->beta, this->k));
}
USE_OPERATOR_FUNCTIONS;
......
......@@ -17,14 +17,14 @@
namespace dragon {
template <class Context>
class NNResizeOp : public Operator<Context> {
class NNResizeOp final : public Operator<Context> {
public:
NNResizeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)),
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
NNResizeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
fy(OperatorBase::Arg<float>("fy", -1.0)),
fx(OperatorBase::Arg<float>("fx", -1.0)),
shape_like_desc(OperatorBase::Arg<string>("shape_like", "")),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {
GET_ARGUMENTS_WITH_DESC(int, dsize);
if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1;
......@@ -43,11 +43,11 @@ class NNResizeOp : public Operator<Context> {
};
template <class Context>
class NNResizeGradientOp : public Operator<Context> {
class NNResizeGradientOp final : public Operator<Context> {
public:
NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
NNResizeGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -17,18 +17,18 @@
namespace dragon {
template <class Context>
class Pooling2dOp: public Operator <Context> {
class Pooling2dOp : public Operator<Context> {
public:
Pooling2dOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
mode(OperatorBase::GetSingleArg<string>("mode", "MAX")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
global_pooling(OperatorBase::GetSingleArg<bool>("global_pooling", false)),
ceil_mode(OperatorBase::GetSingleArg<bool>("ceil", true)) {
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size");
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride");
vector<int> p = OperatorBase::GetRepeatedArg<int>("pad");
Pooling2dOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
mode(OperatorBase::Arg<string>("mode", "MAX")),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")),
padding(OperatorBase::Arg<string>("padding", "VALID")),
global_pooling(OperatorBase::Arg<bool>("global_pooling", false)),
ceil_mode(OperatorBase::Arg<bool>("ceil", true)) {
vector<int> ks = OperatorBase::Args<int>("kernel_size");
vector<int> s = OperatorBase::Args<int>("stride");
vector<int> p = OperatorBase::Args<int>("pad");
for (int i = 0; i < 2; i++) {
if (global_pooling) {
kernel_size.push_back(-1);
......@@ -57,18 +57,18 @@ class Pooling2dOp: public Operator <Context> {
};
template <class Context>
class Pooling2dGradientOp: public Operator<Context> {
class Pooling2dGradientOp : public Operator<Context> {
public:
Pooling2dGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
mode(OperatorBase::GetSingleArg<string>("mode", "MAX")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
global_pooling(OperatorBase::GetSingleArg<bool>("global_pooling", false)),
ceil_mode(OperatorBase::GetSingleArg<bool>("ceil", true)) {
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size");
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride");
vector<int> p = OperatorBase::GetRepeatedArg<int>("pad");
Pooling2dGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
mode(OperatorBase::Arg<string>("mode", "MAX")),
data_format(OperatorBase::Arg<string>("data_format", "NCHW")),
padding(OperatorBase::Arg<string>("padding", "VALID")),
global_pooling(OperatorBase::Arg<bool>("global_pooling", false)),
ceil_mode(OperatorBase::Arg<bool>("ceil", true)) {
vector<int> ks = OperatorBase::Args<int>("kernel_size");
vector<int> s = OperatorBase::Args<int>("stride");
vector<int> p = OperatorBase::Args<int>("pad");
for (int i = 0; i < 2; i++) {
if (global_pooling) {
kernel_size.push_back(-1);
......@@ -101,8 +101,8 @@ class Pooling2dGradientOp: public Operator<Context> {
template <class Context>
class CuDNNPooling2dOp final : public Pooling2dOp<Context> {
public:
CuDNNPooling2dOp(const OperatorDef& op_def, Workspace* ws)
: Pooling2dOp<Context>(op_def, ws) {
CuDNNPooling2dOp(const OperatorDef& def, Workspace* ws)
: Pooling2dOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc));
......@@ -136,8 +136,8 @@ class CuDNNPooling2dOp final : public Pooling2dOp<Context> {
template <class Context>
class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> {
public:
CuDNNPooling2dGradientOp(const OperatorDef& op_def, Workspace* ws)
: Pooling2dGradientOp<Context>(op_def, ws) {
CuDNNPooling2dGradientOp(const OperatorDef& def, Workspace* ws)
: Pooling2dGradientOp<Context>(def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc));
......
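The truncated constructor loops above broadcast `kernel_size`/`stride`/`pad` over the two spatial dims. Since the loop body is cut off by the hunk boundary, this sketch spells out the assumed rule (repeat a single value across dims; -1 marks a global-pooling kernel to be resolved against the input shape):

#include <vector>

// Expand a possibly-scalar argument to one value per spatial dim.
inline std::vector<int> Expand2d(
    const std::vector<int>& v, bool global_pooling, int fallback) {
    std::vector<int> out;
    for (int i = 0; i < 2; i++) {
        if (global_pooling) out.push_back(-1);    // resolved at run time
        else if (v.empty()) out.push_back(fallback);
        else out.push_back(i < (int)v.size() ? v[i] : v[0]);
    }
    return out;
}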
......@@ -17,14 +17,14 @@
namespace dragon {
template <class Context>
class ROIAlignOp : public Operator<Context> {
class ROIAlignOp final : public Operator<Context> {
public:
ROIAlignOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)),
sampling_ratio(OperatorBase::GetSingleArg<int>("sampling_ratio", 2)) {
ROIAlignOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
pool_h(OperatorBase::Arg<int>("pool_h", 0)),
pool_w(OperatorBase::Arg<int>("pool_w", 0)),
spatial_scale(OperatorBase::Arg<float>("spatial_scale", 1.0)),
sampling_ratio(OperatorBase::Arg<int>("sampling_ratio", 2)) {
CHECK_GT(pool_h, 0) << "\npool_h must be > 0";
CHECK_GT(pool_w, 0) << "\npool_w must be > 0";
}
......@@ -39,14 +39,14 @@ class ROIAlignOp : public Operator<Context> {
};
template <class Context>
class ROIAlignGradientOp : public Operator<Context> {
class ROIAlignGradientOp final : public Operator<Context> {
public:
ROIAlignGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)),
sampling_ratio(OperatorBase::GetSingleArg<int>("sampling_ratio", 2)) {
ROIAlignGradientOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
pool_h(OperatorBase::Arg<int>("pool_h", 0)),
pool_w(OperatorBase::Arg<int>("pool_w", 0)),
spatial_scale(OperatorBase::Arg<float>("spatial_scale", 1.f)),
sampling_ratio(OperatorBase::Arg<int>("sampling_ratio", 2)) {
CHECK_GT(pool_h, 0) << "\npool_h must be > 0";
CHECK_GT(pool_w, 0) << "\npool_w must be > 0";
}
......
......@@ -17,13 +17,13 @@
namespace dragon {
template <class Context>
class ROIPoolingOp : public Operator<Context> {
class ROIPoolingOp final : public Operator<Context> {
public:
ROIPoolingOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {
ROIPoolingOp(const OperatorDef& def, Workspace *ws)
: Operator<Context>(def, ws),
pool_h(OperatorBase::Arg<int>("pool_h", 0)),
pool_w(OperatorBase::Arg<int>("pool_w", 0)),
spatial_scale(OperatorBase::Arg<float>("spatial_scale", 1.0)) {
CHECK_GT(pool_h, 0) << "\npool_h must be > 0";
CHECK_GT(pool_w, 0) << "\npool_w must be > 0";
}
......@@ -40,11 +40,11 @@ class ROIPoolingOp : public Operator<Context> {
template <class Context>
class ROIPoolingGradientOp final : public Operator<Context> {
public:
ROIPoolingGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {}
ROIPoolingGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
pool_h(OperatorBase::Arg<int>("pool_h", 0)),
pool_w(OperatorBase::Arg<int>("pool_w", 0)),
spatial_scale(OperatorBase::Arg<float>("spatial_scale", 1.f)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......
......@@ -33,19 +33,26 @@ using google::protobuf::io::ZeroCopyInputStream;
using google::protobuf::io::CodedInputStream;
using google::protobuf::io::FileInputStream;
inline void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
std::fstream output(filename, std::ios::out | std::ios::trunc | std::ios::binary);
inline void WriteProtoToBinaryFile(
const Message& proto,
const char* filename) {
std::fstream output(filename,
std::ios::out |
std::ios::trunc |
std::ios::binary);
proto.SerializeToOstream(&output);
}
inline bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
inline bool ReadProtoFromBinaryFile(
const char* filename,
Message* proto) {
#ifdef _MSC_VER
int fd = _open(filename, O_RDONLY | O_BINARY);
#else
int fd = open(filename, O_RDONLY);
#endif
ZeroCopyInputStream *raw_input = new FileInputStream(fd);
CodedInputStream *coded_input = new CodedInputStream(raw_input);
ZeroCopyInputStream* raw_input = new FileInputStream(fd);
CodedInputStream* coded_input = new CodedInputStream(raw_input);
coded_input->SetTotalBytesLimit(INT_MAX, -1);
bool success = proto->ParseFromCodedStream(coded_input);
delete raw_input;
......@@ -54,7 +61,9 @@ inline bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
return success;
}
inline void LoadCaffeModel(string file, Workspace* ws) {
inline void LoadCaffeModel(
string file,
Workspace* ws) {
NetParameter net_param;
ReadProtoFromBinaryFile(file.c_str(), &net_param);
LOG(INFO) << "Restore From Model @: " << file << "......";
......@@ -73,36 +82,40 @@ inline void LoadCaffeModel(string file, Workspace* ws) {
vector<TIndex> dims;
for (auto dim : blob.shape().dim()) dims.push_back(dim);
Tensor* tensor = ws->GetTensor(tensor_name);
std::stringstream dim_string;
std::stringstream DimString;
if (dims.size() > 0) {
tensor->Reshape(dims);
CHECK_EQ(tensor->count(), blob.data_size())
<< "\nTensor(" << tensor_name << ") "
<< "failed to load, except size: "
<< tensor->count() << ", loaded: " << blob.data_size();
dim_string << tensor->dim_string();
<< tensor->count()
<< ", loaded: " << blob.data_size();
DimString << tensor->DimString();
} else {
tensor->Reshape(vector<TIndex>(1, blob.data_size()));
dim_string << "(missing)";
tensor->Reshape({ blob.data_size() });
DimString << "(missing)";
}
float* Xdata = tensor->mutable_data<float, CPUContext>();
for (int idx = 0; idx < blob.data_size(); idx++)
Xdata[idx] = blob.data(idx);
LOG(INFO) << "Tensor(" << tensor_name << ") "
<< "loaded, shape: " << dim_string.str()
<< "loaded, shape: " << DimString.str()
<< ", size: " << blob.data_size();
}
}
}
}
inline void SavaCaffeModel(string file, const vector<Tensor*>& tensors) {
inline void SavaCaffeModel(
string file,
const vector<Tensor*>& tensors) {
NetParameter net_param;
Map<string, int> layer_hash;
int layer_idx = -1;
for (int i = 0; i < tensors.size(); i++) {
if (tensors[i]->count() <= 0) continue;
vector<string> splits = SplitString(tensors[i]->name(), "/param:");
vector<string> splits = SplitString(
tensors[i]->name(), "/param:");
if (layer_hash.count(splits[0]) == 0) {
layer_hash[splits[0]] = ++layer_idx;
LayerParameter* layer = net_param.add_layer();
......@@ -110,11 +123,14 @@ inline void SavaCaffeModel(string file, const vector<Tensor*>& tensors) {
}
BlobProto* blob = net_param.mutable_layer(layer_idx)->add_blobs();
for (auto dim : tensors[i]->dims()) blob->mutable_shape()->add_dim(dim);
const float* Xdata = tensors[i]->data < float, CPUContext >();
const float* Xdata = tensors[i]->data<float, CPUContext>();
for (int id = 0; id < tensors[i]->count(); id++)
blob->mutable_data()->Add(Xdata[id]);
}
std::fstream output(file, std::ios::out | std::ios::trunc | std::ios::binary);
std::fstream output(file,
std::ios::out |
std::ios::trunc |
std::ios::binary);
CHECK(net_param.SerializeToOstream(&output));
LOG(INFO) << "Save the model @: " << file << "......";
LOG(INFO) << "Model format: caffemodel";
......
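Minimal usage sketch for the two helpers above. It assumes a live Workspace* ws and the `<layer>/param:<k>` tensor naming that SavaCaffeModel's SplitString call relies on:

// Restore a model into the workspace:
LoadCaffeModel("model.caffemodel", ws);

// Snapshot selected tensors back out, grouped by layer name:
std::vector<Tensor*> to_save{ ws->GetTensor("conv1/param:0"),
                              ws->GetTensor("conv1/param:1") };
SavaCaffeModel("snapshot.caffemodel", to_save);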
......@@ -26,9 +26,13 @@ template<> inline int dragon_cast<int, float>(float val) {
return static_cast<int>(val);
}
template<> inline float dragon_cast<float, float>(float val) { return val; }
template<> inline float dragon_cast<float, float>(float val) {
return val;
}
template<> inline float16 dragon_cast<float16, float16>(float16 val) { return val; }
template<> inline float16 dragon_cast<float16, float16>(float16 val) {
return val;
}
template<> inline float16 dragon_cast<float16, float>(float val) {
float16 ret;
......@@ -67,7 +71,8 @@ template<> inline float16 dragon_cast<float16, float>(float val) {
// Round to nearest even.
remainder = (mantissa & lsb_m1);
mantissa >>= shift;
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
if (remainder > lsb_s1 ||
(remainder == lsb_s1 && (mantissa & 0x1))) {
++mantissa;
if (!(mantissa & 0x3ff)) {
++exponent;
......
......@@ -29,9 +29,9 @@ namespace dragon {
#ifdef WITH_CUDA
static const int CUDA_NUM_THREADS = 1024;
static const int CUDA_THREADS = 1024;
// We do have a server with 10 GPUs :-)
#define MAX_GPUS 10
#define CUDA_MAX_DEVICES 10
#define CUDA_VERSION_MIN(major, minor, patch) \
(CUDA_VERSION >= (major * 1000 + minor * 100 + patch))
......@@ -42,7 +42,8 @@ static const int CUDA_NUM_THREADS = 1024;
#define CUDA_CHECK(condition) \
do { \
cudaError_t error = condition; \
CHECK_EQ(error, cudaSuccess) << "\n" << cudaGetErrorString(error); \
CHECK_EQ(error, cudaSuccess) \
<< "\n" << cudaGetErrorString(error); \
} while (0)
#define CUBLAS_CHECK(condition) \
......@@ -61,7 +62,8 @@ static const int CUDA_NUM_THREADS = 1024;
#define NCCL_CHECK(condition) \
do { \
ncclResult_t status = condition; \
CHECK_EQ(status, ncclSuccess) << "\n" << ncclGetErrorString(status); \
CHECK_EQ(status, ncclSuccess) \
<< "\n" << ncclGetErrorString(status); \
} while (0)
#endif // WITH_MPI_NCCL
......@@ -69,12 +71,10 @@ static const int CUDA_NUM_THREADS = 1024;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < n; i += blockDim.x * gridDim.x)
inline int GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
inline int CUDA_BLOCKS(const int N) {
return (N + CUDA_THREADS - 1) / CUDA_THREADS;
}
#define CUDA_POST_KERNEL_CHECK CUDA_CHECK(cudaPeekAtLastError())
#if CUDA_VERSION_MAX(9, 0, 0)
#define __hdiv hdiv
#endif
......@@ -83,18 +83,19 @@ inline int CUDA_NUM_DEVICES() {
static int count = -1;
if (count < 0) {
auto err = cudaGetDeviceCount(&count);
if (err == cudaErrorNoDevice || err == cudaErrorInsufficientDriver) count = 0;
if (err == cudaErrorNoDevice ||
err == cudaErrorInsufficientDriver) count = 0;
}
return count;
}
inline int CUDA_CURRENT_DEVICE() {
inline int CUDA_DEVICE() {
int gpu_id;
cudaGetDevice(&gpu_id);
return gpu_id;
}
inline int CUDA_POINTER_DEVICE(const void* ptr) {
inline int CUDA_DEVICE(const void* ptr) {
cudaPointerAttributes attr;
CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
return attr.device;
......@@ -108,16 +109,18 @@ struct CUDADeviceProps {
vector<cudaDeviceProp> props;
};
inline const cudaDeviceProp& GetDeviceProperty(const int device_id) {
inline const cudaDeviceProp& GetDeviceProperty(
const int device_id) {
static CUDADeviceProps props;
CHECK_LT(device_id, (int)props.props.size())
<< "Invalid device id: " << device_id
<< "\nDetected " << props.props.size() << " eligible cuda devices.";
<< "\nDetected " << props.props.size()
<< " eligible cuda devices.";
return props.props[device_id];
}
inline bool CUDA_TRUE_FP16_AVAILABLE() {
int device = CUDA_CURRENT_DEVICE();
int device = CUDA_DEVICE();
auto& prop = GetDeviceProperty(device);
return prop.major >= 6;
}
......@@ -126,7 +129,7 @@ inline bool TENSOR_CORE_AVAILABLE() {
#if CUDA_VERSION < 9000
return false;
#else
int device = CUDA_CURRENT_DEVICE();
int device = CUDA_DEVICE();
auto& prop = GetDeviceProperty(device);
return prop.major >= 7;
#endif
......@@ -134,11 +137,15 @@ inline bool TENSOR_CORE_AVAILABLE() {
class DeviceGuard {
public:
DeviceGuard(int newDevice) : previous_(CUDA_CURRENT_DEVICE()) {
DeviceGuard(int newDevice)
: previous_(CUDA_DEVICE()) {
if (previous_ != newDevice)
CUDA_CHECK(cudaSetDevice(newDevice));
}
~DeviceGuard() { CUDA_CHECK(cudaSetDevice(previous_)); }
~DeviceGuard() {
CUDA_CHECK(cudaSetDevice(previous_));
}
private:
int previous_;
......
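Launch-side view of the renamed helpers (CUDA_THREADS, CUDA_BLOCKS) together with DeviceGuard, as a sketch; the kernel is illustrative, not from this commit:

template <typename T>
__global__ void _Scale(const int n, const T alpha, const T* x, T* y) {
    CUDA_1D_KERNEL_LOOP(i, n) y[i] = alpha * x[i];
}

template <typename T>
void ScaleOnDevice(int device, const int n, const T alpha, const T* x, T* y) {
    DeviceGuard guard(device);  // restores the previous device on scope exit
    _Scale<T> <<< CUDA_BLOCKS(n), CUDA_THREADS >>> (n, alpha, x, y);
    CUDA_POST_KERNEL_CHECK;
}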
......@@ -21,8 +21,9 @@ namespace dragon {
template <typename T, class Context>
class Filler {
public:
Filler(const TensorFiller& filler): filler_(filler) {}
virtual void Fill(Tensor* tensor) = 0;
Filler(const TensorFiller& filler) : filler_(filler) {}
virtual void Fill(Tensor* tensor, Context* ctx) = 0;
inline TensorFiller& filler() { return filler_; }
......@@ -30,101 +31,125 @@ class Filler {
TensorFiller filler_;
};
template <typename T, class Context>
class ConstantFiller final : public Filler<T, Context> {
public:
ConstantFiller(const TensorFiller& filler): Filler<T, Context>(filler) {}
void Fill(Tensor* tensor) override {
ConstantFiller(const TensorFiller& filler)
: Filler<T, Context>(filler) {}
void Fill(Tensor* tensor, Context* ctx) override {
math::Set<T, Context>(tensor->count(),
dragon_cast<T, float>(this->filler().value()),
dragon_cast<T, float>(filler().value()),
tensor->mutable_data<T, Context>());
}
protected:
using Filler<T, Context>::filler;
};
template <typename T, class Context>
class NormalFiller final : public Filler<T, Context> {
public:
NormalFiller(const TensorFiller& filler): Filler<T, Context>(filler) {}
void Fill(Tensor* tensor) override {
NormalFiller(const TensorFiller& filler)
: Filler<T, Context>(filler) {}
void Fill(Tensor* tensor, Context* ctx) override {
math::RandomNormal<T, Context>(tensor->count(),
this->filler().mean(),
this->filler().std(),
tensor->mutable_data<T, Context>());
filler().mean(), filler().std(),
tensor->mutable_data<T, Context>(), ctx);
}
protected:
using Filler<T, Context>::filler;
};
template <typename T, class Context>
class TruncatedNormalFiller final : public Filler < T, Context > {
class TruncatedNormalFiller final : public Filler<T, Context> {
public:
TruncatedNormalFiller(const TensorFiller& filler): Filler<T, Context>(filler) {}
void Fill(Tensor* tensor) override {
TruncatedNormalFiller(const TensorFiller& filler)
: Filler<T, Context>(filler) {}
void Fill(Tensor* tensor, Context* ctx) override {
// implementing it on the GPU is difficult
static CPUContext cpu_ctx;
math::RandomTruncatedNormal<T, CPUContext>(tensor->count(),
this->filler().mean(),
this->filler().std(),
this->filler().low(),
this->filler().high(),
tensor->mutable_data<T, CPUContext>());
filler().mean(), filler().std(),
filler().low(), filler().high(),
tensor->mutable_data<T, CPUContext>(), &cpu_ctx);
}
protected:
using Filler<T, Context>::filler;
};
template <typename T, class Context>
class UniformFiller final : public Filler<T, Context> {
public:
UniformFiller(const TensorFiller& filler) : Filler<T, Context>(filler) {}
void Fill(Tensor* tensor) override {
UniformFiller(const TensorFiller& filler)
: Filler<T, Context>(filler) {}
void Fill(Tensor* tensor, Context* ctx) override {
math::RandomUniform<T, Context>(tensor->count(),
this->filler().low(),
this->filler().high(),
tensor->mutable_data<T, Context>());
filler().low(), filler().high(),
tensor->mutable_data<T, Context>(), ctx);
}
protected:
using Filler<T, Context>::filler;
};
template <typename T, class Context>
class XavierFiller final : public Filler<T, Context> {
public:
XavierFiller(const TensorFiller& filler) : Filler<T, Context>(filler) {}
using Filler<T, Context>::filler;
void Fill(Tensor* tensor) override {
XavierFiller(const TensorFiller& filler)
: Filler<T, Context>(filler) {}
void Fill(Tensor* tensor, Context* ctx) override {
int fan_in = tensor->count() / tensor->dim(0);
int fan_out = tensor->count() / tensor->dim(1);
float n = fan_in, scale = 3.0;
if (filler().has_scale()) scale = filler().scale();
if (filler().variance_norm() == TensorFiller_VarianceNorm_FAN_AVG) {
if (filler().variance_norm() ==
TensorFiller_VarianceNorm_FAN_AVG) {
n = (fan_in + fan_out) / float(2);
} else if (filler().variance_norm() == TensorFiller_VarianceNorm_FAN_OUT) {
} else if (filler().variance_norm() ==
TensorFiller_VarianceNorm_FAN_OUT) {
n = fan_out;
}
float limit = std::sqrt(scale / n);
math::RandomUniform<T, Context>(tensor->count(),
-limit,
limit,
tensor->mutable_data<T, Context>());
-limit, limit, tensor->mutable_data<T, Context>(), ctx);
}
protected:
using Filler<T, Context>::filler;
};
template <typename T, class Context>
class MSRAFiller final : public Filler <T, Context> {
public:
MSRAFiller(const TensorFiller& filler) : Filler<T, Context>(filler) {}
using Filler<T, Context>::filler;
void Fill(Tensor* tensor) override {
MSRAFiller(const TensorFiller& filler)
: Filler<T, Context>(filler) {}
void Fill(Tensor* tensor, Context* ctx) override {
int fan_in = tensor->count() / tensor->dim(0);
int fan_out = tensor->count() / tensor->dim(1);
float n = fan_in, scale = 2.0;
if (filler().has_scale()) scale = filler().scale();
if (filler().variance_norm() == TensorFiller_VarianceNorm_FAN_AVG) {
if (filler().variance_norm() ==
TensorFiller_VarianceNorm_FAN_AVG) {
n = (fan_in + fan_out) / float(2);
} else if (filler().variance_norm() == TensorFiller_VarianceNorm_FAN_OUT) {
} else if (filler().variance_norm() ==
TensorFiller_VarianceNorm_FAN_OUT) {
n = fan_out;
}
float std = std::sqrt(scale / n);
math::RandomNormal<T, Context>(tensor->count(),
float(0),
std,
tensor->mutable_data<T, Context>());
0.f, std, tensor->mutable_data<T, Context>(), ctx);
}
protected:
using Filler<T, Context>::filler;
};
......
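Call-site impact of the new Fill(Tensor*, Context*) signature, as a sketch. The TensorFiller setters are assumed from the fields read above (low/high); Tensor::Reshape is used as in LoadCaffeModel:

TensorFiller proto;
proto.set_low(-0.05f);
proto.set_high(0.05f);

UniformFiller<float, CPUContext> filler(proto);
Tensor W;
W.Reshape({ 256, 128 });

CPUContext ctx;
filler.Fill(&W, &ctx);  // randomness is now drawn through the context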
......@@ -43,14 +43,16 @@ void RandomUniform(
const int n,
const float low,
const float high,
T* x);
T* x,
Context* ctx);
template <typename T, class Context>
void RandomNormal(
const int n,
const float mu,
const float sigma,
T* x);
T* x,
Context* ctx);
template <typename T, class Context>
void RandomTruncatedNormal(
......@@ -59,13 +61,15 @@ void RandomTruncatedNormal(
const float sigma,
const float low,
const float high,
T* x);
T* x,
Context* ctx);
template <typename T, class Context>
void RandomBernoulli(
const int n,
const float p,
uint32_t* x);
uint32_t* x,
Context* ctx);
/******************** Level-1 ********************/
......@@ -148,14 +152,16 @@ template <typename T, class Context>
void Scal(
const int n,
const float alpha,
T* y);
T* y,
Context* ctx);
template <typename T, class Context>
void Scale(
const int n,
const float alpha,
const T* x,
T* y);
T* y,
Context* ctx);
template <typename T, class Context>
T StridedDot(
......@@ -163,13 +169,15 @@ T StridedDot(
const T* a,
const int incx,
const T* b,
const int incy);
const int incy,
Context* ctx);
template <typename T, class Context>
float Dot(
const int n,
const T* a,
const T* b);
const T* b,
Context* ctx);
template<typename T, class Context>
float ASum(
......@@ -193,7 +201,8 @@ void Axpy(
const int n,
float alpha,
const T* x,
T* y);
T* y,
Context* ctx);
template<typename T, class Context>
void Axpby(
......@@ -201,7 +210,8 @@ void Axpby(
float alpha,
const T* x,
float beta,
T* y);
T* y,
Context* ctx);
/******************** Level-3 ********************/
......@@ -217,6 +227,7 @@ void Gemm(
const T* B,
const float beta,
T* C,
Context* ctx,
TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
template<typename T, class Context>
......@@ -229,6 +240,7 @@ void Gemv(
const T* x,
const float beta,
T* y,
Context* ctx,
TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
} // namespace math
......
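Call sites now thread a Context* through every stateful math routine. A sketch against the signatures above, assuming n, x, and y are already declared:

CPUContext ctx;
math::RandomUniform<float, CPUContext>(n, -1.f, 1.f, x, &ctx);
math::Axpy<float, CPUContext>(n, 2.f, x, y, &ctx);         // y += 2 * x
math::Axpby<float, CPUContext>(n, 1.f, x, 0.5f, y, &ctx);  // y = x + 0.5 * y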
......@@ -20,9 +20,6 @@ namespace kernel {
typedef int64_t TIndex;
template <typename T, class Context>
void Empty();
/******************** activation.dropout ********************/
template <typename T, class Context>
......@@ -32,7 +29,8 @@ void Dropout(
T scale,
const T* x,
uint32_t* mask,
T* y);
T* y,
Context* ctx);
template <typename T, class Context>
void DropoutGrad(
......@@ -41,7 +39,8 @@ void DropoutGrad(
T scale,
const T* dy,
const uint32_t* mask,
T* dx);
T* dx,
Context* ctx);
/******************** activation.elu ********************/
......@@ -97,7 +96,8 @@ void PReluWGrad(
const T* x,
const T* multiplier,
T* bcast_dw,
T* dw);
T* dw,
Context* ctx);
/******************** activation.relu ********************/
......@@ -157,7 +157,8 @@ void Softmax(
const T* sum_multiplier,
const T* x,
T* scale,
T* y);
T* y,
Context* ctx);
template <typename T, class Context>
void SoftmaxGrad(
......@@ -169,7 +170,8 @@ void SoftmaxGrad(
const T* dy,
const T* y,
T* scale,
T* dx);
T* dx,
Context* ctx);
/******************** activation.tanh ********************/
......@@ -198,7 +200,8 @@ void Affine(
const T* alpha,
const T* beta,
const T* beta_multiplier,
T* y);
T* y,
Context* ctx);
template <typename T, class Context>
void AffineGrad(
......@@ -208,7 +211,8 @@ void AffineGrad(
const int inner_dim,
const T* dy,
const T* alpha,
T* dx);
T* dx,
Context* ctx);
/******************** arithmetic.clip ********************/
......@@ -293,7 +297,8 @@ void SparseSoftmaxCrossEntropy(
const Ty* labels,
Tx* loss,
Tx* valid,
Tensor* ignore);
Tensor* ignore,
Context* ctx);
template <typename Tx, typename Ty, class Context>
void SparseSoftmaxCrossEntropyGrad(
......@@ -305,7 +310,8 @@ void SparseSoftmaxCrossEntropyGrad(
const Ty* labels,
Tx* valid,
Tensor* ignore,
Tx* dx);
Tx* dx,
Context* ctx);
/******************** loss.sparse_softmax_focal_loss ********************/
......@@ -585,7 +591,8 @@ void RepeatGrad(
const int inner_dim,
const int repeats,
const T* dy,
T* dx);
T* dx,
Context* ctx);
/******************** ndarray.slice ********************/
......@@ -629,7 +636,8 @@ void TileGrad(
const int ex_inner_dim,
const int multiple,
const T* dy,
T* dx);
T* dx,
Context* ctx);
/******************** ndarray.transpose ********************/
......@@ -733,7 +741,8 @@ void BiasAdd(
const string& data_format,
const T* bias,
const T* bias_multiplier,
T* y);
T* y,
Context* ctx);
/******************** vision.bilinear_resize ********************/
......
......@@ -22,8 +22,11 @@ namespace dragon {
using google::protobuf::Message;
template <class IterableInputs,class IterableOutputs,class IterableArgs>
inline OperatorDef MakeOperatorDef(const string& type,
template <class IterableInputs,
class IterableOutputs,
class IterableArgs>
inline OperatorDef MakeOperatorDef(
const string& type,
const string& name,
const IterableInputs& inputs,
const IterableOutputs& outputs,
......@@ -36,29 +39,42 @@ inline OperatorDef MakeOperatorDef(const string& type,
for (const string& in : inputs) def.add_input(in);
for (const string& out : outputs) def.add_output(out);
for (const Argument& arg : args) def.add_arg()->CopyFrom(arg);
if (device_option.has_device_type()) def.mutable_device_option()->CopyFrom(device_option);
if (device_option.has_device_type())
def.mutable_device_option()->CopyFrom(device_option);
return def;
}
template <class IterableInputs, class IterableOutputs, class IterableArgs>
inline OperatorDef MakeOperatorDef(const string& type,
template <class IterableInputs,
class IterableOutputs,
class IterableArgs>
inline OperatorDef MakeOperatorDef(
const string& type,
const string& name,
const IterableInputs& inputs,
const IterableOutputs& outputs,
const IterableArgs& args) {
return MakeOperatorDef(type, name, inputs, outputs, args, DeviceOption(), "");
return MakeOperatorDef(
type, name, inputs, outputs, args,
DeviceOption(), "");
}
template <class IterableInputs, class IterableOutputs>
inline OperatorDef MakeOperatorDef(const string& type,
template <class IterableInputs,
class IterableOutputs>
inline OperatorDef MakeOperatorDef(
const string& type,
const string& name,
const IterableInputs& inputs,
const IterableOutputs& outputs) {
return MakeOperatorDef(type, name, inputs, outputs, vector<Argument>(), DeviceOption(), "");
return MakeOperatorDef(
type, name, inputs, outputs,
vector<Argument>(), DeviceOption(), "");
}
inline void ParseProtoFromText(string text, Message* proto) {
google::protobuf::TextFormat::ParseFromString(text, proto);
inline void ParseProtoFromText(
string text,
Message* proto) {
google::protobuf::TextFormat
::ParseFromString(text, proto);
}
} // namespace dragon
......
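Usage sketch for the reflowed overloads above; the operator type and names are illustrative only:

OperatorDef def = MakeOperatorDef(
    "Relu", "relu1",
    vector<string>({ "data" }),
    vector<string>({ "data" }));

NetParameter net;  // any protobuf Message works here
ParseProtoFromText("name: \"demo\"", &net);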