Commit 72f3c4ba by Ting PAN

Add License into headers

1 parent 44906e17
Showing with 1300 additions and 566 deletions
# ---------------- Welcom To Use Dragon ---------------- # ---------------- Welcom To Use Dragon ----------------
PROJECT(dragon) PROJECT(dragon)
CMAKE_MINIMUM_REQUIRED(VERSION 3.0.0) CMAKE_MINIMUM_REQUIRED(VERSION 3.0.0)
...@@ -158,7 +158,7 @@ if(WIN32) ...@@ -158,7 +158,7 @@ if(WIN32)
endif() endif()
if(UNIX) if(UNIX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -O2 -m64 -fpermissive -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w -fPIC -O2 -m64 -std=c++11")
if (WITH_OMP) if (WITH_OMP)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_COMMON_H_ #ifndef DRAGON_CORE_COMMON_H_
#define DRAGON_CORE_COMMON_H_ #define DRAGON_CORE_COMMON_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_CONTEXT_H_ #ifndef DRAGON_CORE_CONTEXT_H_
#define DRAGON_CORE_CONTEXT_H_ #define DRAGON_CORE_CONTEXT_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_CONTEXT_CUDA_H_ #ifndef DRAGON_CORE_CONTEXT_CUDA_H_
#define DRAGON_CORE_CONTEXT_CUDA_H_ #define DRAGON_CORE_CONTEXT_CUDA_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_GRAPH_H_ #ifndef DRAGON_CORE_GRAPH_H_
#define DRAGON_CORE_GRAPH_H_ #define DRAGON_CORE_GRAPH_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_GRAPH_GRADIENT_H_ #ifndef DRAGON_CORE_GRAPH_GRADIENT_H_
#define DRAGON_CORE_GRAPH_GRADIENT_H_ #define DRAGON_CORE_GRAPH_GRADIENT_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_MIXEDMEM_H_ #ifndef DRAGON_CORE_MIXEDMEM_H_
#define DRAGON_CORE_MIXEDMEM_H_ #define DRAGON_CORE_MIXEDMEM_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_OPERATOR_H_ #ifndef DRAGON_CORE_OPERATOR_H_
#define DRAGON_CORE_OPERATOR_H_ #define DRAGON_CORE_OPERATOR_H_
...@@ -27,8 +32,8 @@ class OperatorBase { ...@@ -27,8 +32,8 @@ class OperatorBase {
OperatorBase(const OperatorDef& op_def, Workspace* ws); OperatorBase(const OperatorDef& op_def, Workspace* ws);
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
Tensor& input(int idx); Tensor& Input(int idx);
Tensor* output(int idx); Tensor* Output(int idx);
inline size_t InputSize() { return inputs_.size(); } inline size_t InputSize() { return inputs_.size(); }
inline size_t OutputSize() { return outputs_.size(); } inline size_t OutputSize() { return outputs_.size(); }
...@@ -55,7 +60,7 @@ class OperatorBase { ...@@ -55,7 +60,7 @@ class OperatorBase {
void set_recompute_map(RecomputeMap recompute_map) { recompute_map_ = recompute_map; } void set_recompute_map(RecomputeMap recompute_map) { recompute_map_ = recompute_map; }
inline const OperatorDef& op_def() const { return op_def_; } inline const OperatorDef& op_def() const { return op_def_; }
inline const string debug_string() const { return op_def_.DebugString(); } inline const string DebugString() const { return op_def_.DebugString(); }
protected: protected:
string phase_; string phase_;
...@@ -73,7 +78,7 @@ class Operator : public OperatorBase { ...@@ -73,7 +78,7 @@ class Operator : public OperatorBase {
: OperatorBase(op_def, ws), ctx_(op_def.device_option()) { : OperatorBase(op_def, ws), ctx_(op_def.device_option()) {
allow_run_ = true; allow_run_ = true;
allow_run_ &= _MPICheck(); allow_run_ &= _MPICheck();
allow_run_ &= (!(OutputSize() == 1 && output(0)->name() == "ignore")); allow_run_ &= (!(OutputSize() == 1 && Output(0)->name() == "ignore"));
allow_share_grads_ = (!op_def.debug_mode()); allow_share_grads_ = (!op_def.debug_mode());
allow_share_grads_ &= op_def.share_grads(); allow_share_grads_ &= op_def.share_grads();
allow_share_grads_ &= (type().find("Gradient") != string::npos); allow_share_grads_ &= (type().find("Gradient") != string::npos);
...@@ -97,9 +102,9 @@ class Operator : public OperatorBase { ...@@ -97,9 +102,9 @@ class Operator : public OperatorBase {
void MemorySwitch() { void MemorySwitch() {
for (int i = 0; i < InputSize(); i++) for (int i = 0; i < InputSize(); i++)
if (input(i).name() != "ignore") input(i).SwitchToDevice(); if (Input(i).name() != "ignore") Input(i).SwitchToDevice();
for (int i = 0; i < OutputSize(); i++) for (int i = 0; i < OutputSize(); i++)
if (output(i)->name() != "ignore") output(i)->SwitchToDevice(); if (Output(i)->name() != "ignore") Output(i)->SwitchToDevice();
} }
virtual void RunOnDevice() = 0; virtual void RunOnDevice() = 0;
...@@ -135,6 +140,23 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws); ...@@ -135,6 +140,23 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
: Operator<Context>(op_def, ws) {} \ : Operator<Context>(op_def, ws) {} \
virtual ~name() {} virtual ~name() {}
#define USE_OPERATOR_BASE_FUNCTIONS \
using OperatorBase::Input; \
using OperatorBase::Output; \
using OperatorBase::ws; \
using OperatorBase::name; \
using OperatorBase::type; \
using OperatorBase::phase; \
using OperatorBase::op_def; \
using OperatorBase::InputSize; \
using OperatorBase::OutputSize; \
using OperatorBase::DebugString \
#define USE_OPERATOR_FUNCTIONS(context) \
USE_OPERATOR_BASE_FUNCTIONS; \
using Operator<context>::ctx; \
using Operator<context>::anchor
DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*); DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*);
DECLARE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*); DECLARE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*); DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
...@@ -242,6 +264,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp ...@@ -242,6 +264,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
INSTANTIATE_OPERATOR(name, CUDAContext); \ INSTANTIATE_OPERATOR(name, CUDAContext); \
#define DEPLOY_CPU_CUDA(name) \ #define DEPLOY_CPU_CUDA(name) \
REGISTER_CPU_OPERATOR(name, name##Op<CPUContext>); \
REGISTER_CUDA_OPERATOR(name, name##Op<CPUContext>); \ REGISTER_CUDA_OPERATOR(name, name##Op<CPUContext>); \
INSTANTIATE_OPERATOR(name, CPUContext); \ INSTANTIATE_OPERATOR(name, CPUContext); \
...@@ -250,4 +273,4 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp ...@@ -250,4 +273,4 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
INSTANTIATE_CUDNN_OPERATOR(name); INSTANTIATE_CUDNN_OPERATOR(name);
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_OPERATOR_H_ #endif // DRAGON_CORE_OPERATOR_H_
\ No newline at end of file
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_ #ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_
#define DRAGON_CORE_OPERATOR_GRADIENT_H_ #define DRAGON_CORE_OPERATOR_GRADIENT_H_
...@@ -18,7 +23,7 @@ struct Gradient { ...@@ -18,7 +23,7 @@ struct Gradient {
vector<OperatorDef> ops; vector<OperatorDef> ops;
vector<string> g_inputs; vector<string> g_inputs;
vector<float> defaults; vector<float> defaults;
Gradient(const vector<OperatorDef>& ops, Gradient(const vector<OperatorDef>& ops,
const vector<string>& g_inputs, const vector<string>& g_inputs,
const vector<float>& defaults) const vector<float>& defaults)
: ops(ops), g_inputs(g_inputs), defaults(defaults) {} : ops(ops), g_inputs(g_inputs), defaults(defaults) {}
...@@ -37,20 +42,20 @@ class GradientMakerBase { ...@@ -37,20 +42,20 @@ class GradientMakerBase {
inline virtual Gradient Make() { inline virtual Gradient Make() {
vector<OperatorDef> new_defs = MakeDefs(); vector<OperatorDef> new_defs = MakeDefs();
Argument anchor; Argument anchor;
anchor.set_name("anchor"); anchor.set_s(def.name()); anchor.set_name("anchor"); anchor.set_s(def.name());
for (int i = 0; i < new_defs.size(); i++) for (int i = 0; i < new_defs.size(); i++)
new_defs[i].add_arg()->CopyFrom(anchor); new_defs[i].add_arg()->CopyFrom(anchor);
return Gradient(new_defs, g_inputs_, DefaultValues()); return Gradient(new_defs, g_inputs_, DefaultValues());
}; };
virtual inline vector<OperatorDef> MakeDefs() { virtual inline vector<OperatorDef> MakeDefs() {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
return vector<OperatorDef>(); return vector<OperatorDef>();
} }
virtual inline vector<float> DefaultValues() { virtual inline vector<float> DefaultValues() {
return vector<float>(g_outputs_.size(), 1.0); return vector<float>(g_outputs_.size(), 1.0);
} }
template <class... Args> template <class... Args>
...@@ -79,24 +84,24 @@ Gradient MakeGradientForOp(const OperatorDef& op_def, const vector<string>& g_ou ...@@ -79,24 +84,24 @@ Gradient MakeGradientForOp(const OperatorDef& op_def, const vector<string>& g_ou
# define GRADIENT_MAKER_CTOR(name) \ # define GRADIENT_MAKER_CTOR(name) \
name(const OperatorDef& def, const vector<string>& g_output) \ name(const OperatorDef& def, const vector<string>& g_output) \
: GradientMakerBase(def, g_output) {} : GradientMakerBase(def, g_output) {}
class NoGradient : public GradientMakerBase { class NoGradient : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(NoGradient); GRADIENT_MAKER_CTOR(NoGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return vector<OperatorDef>(); return vector<OperatorDef>();
} }
}; };
DECLARE_REGISTRY(GradientRegistry, DECLARE_REGISTRY(GradientRegistry,
GradientMakerBase, GradientMakerBase,
const OperatorDef&, const OperatorDef&,
const vector<string>&); const vector<string>&);
DECLARE_REGISTRY(NoGradientRegistry, DECLARE_REGISTRY(NoGradientRegistry,
GradientMakerBase, GradientMakerBase,
const OperatorDef&, const OperatorDef&,
const vector<string>&); const vector<string>&);
// define in the operator.cc // define in the operator.cc
...@@ -109,4 +114,4 @@ DECLARE_REGISTRY(NoGradientRegistry, ...@@ -109,4 +114,4 @@ DECLARE_REGISTRY(NoGradientRegistry,
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_OPERATOR_GRADIENT_H_ #endif // DRAGON_CORE_OPERATOR_GRADIENT_H_
\ No newline at end of file
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_OPERATOR_SCHEMA_H_ #ifndef DRAGON_CORE_OPERATOR_SCHEMA_H_
#define DRAGON_CORE_OPERATOR_SCHEMA_H_ #define DRAGON_CORE_OPERATOR_SCHEMA_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_REGISTRY_H_ #ifndef DRAGON_CORE_REGISTRY_H_
#define DRAGON_CORE_REGISTRY_H_ #define DRAGON_CORE_REGISTRY_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_TENSOR_H_ #ifndef DRAGON_CORE_TENSOR_H_
#define DRAONG_CORE_TENSOR_H_ #define DRAONG_CORE_TENSOR_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_TYPEID_H_ #ifndef DRAGON_CORE_TYPEID_H_
#define DRAGON_CORE_TYPEID_H_ #define DRAGON_CORE_TYPEID_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_TYPES_H_ #ifndef DRAGON_CORE_TYPES_H_
#define DRAGON_CORE_TYPES_H_ #define DRAGON_CORE_TYPES_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_WORKSPACE_H_ #ifndef DRAGON_CORE_WORKSPACE_H_
#define DRAGON_CORE_WORKSPACE_H_ #define DRAGON_CORE_WORKSPACE_H_
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
...@@ -20,6 +25,7 @@ class DropoutOp final : public Operator<Context> { ...@@ -20,6 +25,7 @@ class DropoutOp final : public Operator<Context> {
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) { use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5); GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -39,6 +45,7 @@ class DropoutGradientOp final : public Operator<Context> { ...@@ -39,6 +45,7 @@ class DropoutGradientOp final : public Operator<Context> {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5); GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
...@@ -17,6 +22,7 @@ class EluOp : public Operator<Context> { ...@@ -17,6 +22,7 @@ class EluOp : public Operator<Context> {
EluOp(const OperatorDef& op_def, Workspace* ws) EluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {} alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -33,6 +39,7 @@ class EluGradientOp : public Operator<Context> { ...@@ -33,6 +39,7 @@ class EluGradientOp : public Operator<Context> {
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) { alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -56,6 +63,7 @@ public: ...@@ -56,6 +63,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha)); CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNEluOp() { ~CuDNNEluOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -82,6 +90,7 @@ class CuDNNEluGradientOp final : public EluGradientOp<Context> { ...@@ -82,6 +90,7 @@ class CuDNNEluGradientOp final : public EluGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha)); CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNEluGradientOp() { ~CuDNNEluGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
...@@ -18,6 +23,7 @@ class PReluOp : public Operator<Context> { ...@@ -18,6 +23,7 @@ class PReluOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)), channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -35,6 +41,7 @@ class PReluGradientOp : public Operator<Context> { ...@@ -35,6 +41,7 @@ class PReluGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)), channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
...@@ -17,6 +22,7 @@ class ReluOp : public Operator<Context> { ...@@ -17,6 +22,7 @@ class ReluOp : public Operator<Context> {
ReluOp(const OperatorDef& op_def, Workspace* ws) ReluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {} slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -33,6 +39,7 @@ class ReluGradientOp : public Operator<Context> { ...@@ -33,6 +39,7 @@ class ReluGradientOp : public Operator<Context> {
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) { slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -54,6 +61,7 @@ public: ...@@ -54,6 +61,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNReluOp() { ~CuDNNReluOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -80,6 +88,7 @@ class CuDNNReluGradientOp final : public ReluGradientOp<Context> { ...@@ -80,6 +88,7 @@ class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNReluGradientOp() { ~CuDNNReluGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
...@@ -16,6 +21,7 @@ class SEluOp : public Operator<Context> { ...@@ -16,6 +21,7 @@ class SEluOp : public Operator<Context> {
public: public:
SEluOp(const OperatorDef& op_def, Workspace* ws) SEluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -28,6 +34,7 @@ class SEluGradientOp : public Operator<Context> { ...@@ -28,6 +34,7 @@ class SEluGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP #ifndef DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
#define DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP #define DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class SigmoidOp : public Operator<Context> { class SigmoidOp : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SigmoidOp); USE_SIMPLE_CTOR_DTOR(SigmoidOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -27,6 +33,7 @@ class SigmoidGradientOp : public Operator<Context> { ...@@ -27,6 +33,7 @@ class SigmoidGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -45,6 +52,7 @@ public: ...@@ -45,6 +52,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNSigmoidOp() { ~CuDNNSigmoidOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -71,6 +79,7 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> { ...@@ -71,6 +79,7 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNSigmoidGradientOp() { ~CuDNNSigmoidGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
...@@ -17,6 +22,7 @@ class SoftmaxOp final : public Operator<Context> { ...@@ -17,6 +22,7 @@ class SoftmaxOp final : public Operator<Context> {
SoftmaxOp(const OperatorDef& op_def, Workspace* ws) SoftmaxOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -35,6 +41,7 @@ class SoftmaxGradientOp final : public Operator<Context> { ...@@ -35,6 +41,7 @@ class SoftmaxGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)) { axis(OperatorBase::GetSingleArg<int>("axis", 1)) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -58,6 +65,7 @@ class CuDNNSoftmaxOp final : public Operator<Context> { ...@@ -58,6 +65,7 @@ class CuDNNSoftmaxOp final : public Operator<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNSoftmaxOp() { ~CuDNNSoftmaxOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -82,6 +90,7 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> { ...@@ -82,6 +90,7 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNSoftmaxGradientOp() { ~CuDNNSoftmaxGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_ #ifndef DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_ #define DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class TanhOp : public Operator<Context> { class TanhOp : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(TanhOp); USE_SIMPLE_CTOR_DTOR(TanhOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -27,6 +33,7 @@ class TanhGradientOp : public Operator<Context> { ...@@ -27,6 +33,7 @@ class TanhGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -45,6 +52,7 @@ public: ...@@ -45,6 +52,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNTanhOp() { ~CuDNNTanhOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -71,6 +79,7 @@ class CuDNNTanhGradientOp final : public TanhGradientOp<Context> { ...@@ -71,6 +79,7 @@ class CuDNNTanhGradientOp final : public TanhGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0)); CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNTanhGradientOp() { ~CuDNNTanhGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_ADD_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_ADD_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_ADD_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_ADD_OP_H_
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class AddOp final : public Operator<Context> { class AddOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(AddOp); USE_SIMPLE_CTOR_DTOR(AddOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
...@@ -28,6 +34,7 @@ template <class Context> ...@@ -28,6 +34,7 @@ template <class Context>
class AddGradientOp final : public Operator<Context> { class AddGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(AddGradientOp); USE_SIMPLE_CTOR_DTOR(AddGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -42,6 +49,7 @@ template <class Context> ...@@ -42,6 +49,7 @@ template <class Context>
class RAddOp final : public Operator<Context> { class RAddOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(RAddOp); USE_SIMPLE_CTOR_DTOR(RAddOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
...@@ -55,6 +63,7 @@ template <class Context> ...@@ -55,6 +63,7 @@ template <class Context>
class RAddGradientOp final : public Operator<Context> { class RAddGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(RAddGradientOp); USE_SIMPLE_CTOR_DTOR(RAddGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
...@@ -17,6 +22,7 @@ class BiasAddOp : public Operator<Context> { ...@@ -17,6 +22,7 @@ class BiasAddOp : public Operator<Context> {
BiasAddOp(const OperatorDef& op_def, Workspace* ws) BiasAddOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -35,6 +41,7 @@ class BiasAddGradientOp final : public Operator<Context> { ...@@ -35,6 +41,7 @@ class BiasAddGradientOp final : public Operator<Context> {
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) { data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
...@@ -19,6 +24,7 @@ class ClipOp final : public Operator<Context> { ...@@ -19,6 +24,7 @@ class ClipOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
low(OperatorBase::GetSingleArg<float>("low", -FLT_MAX)), low(OperatorBase::GetSingleArg<float>("low", -FLT_MAX)),
high(OperatorBase::GetSingleArg<float>("high", FLT_MAX)) {} high(OperatorBase::GetSingleArg<float>("high", FLT_MAX)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -32,6 +38,7 @@ template <class Context> ...@@ -32,6 +38,7 @@ template <class Context>
class ClipGradientOp final : public Operator<Context> { class ClipGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ClipGradientOp); USE_SIMPLE_CTOR_DTOR(ClipGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_DIV_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_DIV_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_DIV_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_DIV_OP_H_
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class DivOp final : public Operator<Context> { class DivOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(DivOp); USE_SIMPLE_CTOR_DTOR(DivOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
...@@ -28,6 +34,7 @@ template <class Context> ...@@ -28,6 +34,7 @@ template <class Context>
class DivGradientOp final : public Operator<Context> { class DivGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(DivGradientOp); USE_SIMPLE_CTOR_DTOR(DivGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -42,6 +49,7 @@ template <class Context> ...@@ -42,6 +49,7 @@ template <class Context>
class RDivOp final : public Operator<Context> { class RDivOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(RDivOp); USE_SIMPLE_CTOR_DTOR(RDivOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
...@@ -55,6 +63,7 @@ template <class Context> ...@@ -55,6 +63,7 @@ template <class Context>
class RDivGradientOp final : public Operator<Context> { class RDivGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(RDivGradientOp); USE_SIMPLE_CTOR_DTOR(RDivGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
...@@ -18,6 +23,7 @@ class DotOp final : public Operator<Context> { ...@@ -18,6 +23,7 @@ class DotOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)), transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void DotRunWithType(); template <typename T> void DotRunWithType();
...@@ -36,6 +42,7 @@ class DotGradientOp final : public Operator<Context> { ...@@ -36,6 +42,7 @@ class DotGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)), transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
...@@ -24,6 +29,7 @@ class EltwiseOp final : public Operator<Context> { ...@@ -24,6 +29,7 @@ class EltwiseOp final : public Operator<Context> {
<< "but provided " << coeffs.size() << " coeffs."; << "but provided " << coeffs.size() << " coeffs.";
} else coeffs.resize(InputSize(), float(1)); } else coeffs.resize(InputSize(), float(1));
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void SumRunWithType(); template <typename T> void SumRunWithType();
...@@ -47,6 +53,7 @@ class EltwiseGradientOp final : public Operator<Context> { ...@@ -47,6 +53,7 @@ class EltwiseGradientOp final : public Operator<Context> {
<< "but provided " << coeffs.size() << " coeffs."; << "but provided " << coeffs.size() << " coeffs.";
} else coeffs.resize(InputSize(), float(1)); } else coeffs.resize(InputSize(), float(1));
} }
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class ExpOp final : public Operator<Context> { class ExpOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ExpOp); USE_SIMPLE_CTOR_DTOR(ExpOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -24,6 +30,7 @@ template <class Context> ...@@ -24,6 +30,7 @@ template <class Context>
class ExpGradientOp final : public Operator<Context> { class ExpGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ExpGradientOp); USE_SIMPLE_CTOR_DTOR(ExpGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
...@@ -17,6 +22,7 @@ class GramMatrixOp final : public Operator<Context> { ...@@ -17,6 +22,7 @@ class GramMatrixOp final : public Operator<Context> {
GramMatrixOp(const OperatorDef& op_def, Workspace* ws) GramMatrixOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -32,6 +38,7 @@ class GramMatrixGradientOp final : public Operator<Context> { ...@@ -32,6 +38,7 @@ class GramMatrixGradientOp final : public Operator<Context> {
GramMatrixGradientOp(const OperatorDef& op_def, Workspace* ws) GramMatrixGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
...@@ -19,6 +24,7 @@ class InnerProductOp: public Operator<Context> { ...@@ -19,6 +24,7 @@ class InnerProductOp: public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)), num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {} transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice(); void RunOnDevice();
template <typename T> void TransRunWithType(); template <typename T> void TransRunWithType();
...@@ -38,6 +44,7 @@ class InnerProductGradientOp final : public Operator<Context> { ...@@ -38,6 +44,7 @@ class InnerProductGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)), num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {} transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class LogOp final : public Operator<Context> { class LogOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(LogOp); USE_SIMPLE_CTOR_DTOR(LogOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -24,6 +30,7 @@ template <class Context> ...@@ -24,6 +30,7 @@ template <class Context>
class LogGradientOp final : public Operator<Context> { class LogGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(LogGradientOp); USE_SIMPLE_CTOR_DTOR(LogGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
...@@ -18,6 +23,7 @@ class MatmulOp final : public Operator<Context> { ...@@ -18,6 +23,7 @@ class MatmulOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)), transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -35,6 +41,7 @@ class MatmulGradientOp final : public Operator<Context> { ...@@ -35,6 +41,7 @@ class MatmulGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)), transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_MUL_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_MUL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MUL_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_MUL_OP_H_
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class MulOp final : public Operator<Context> { class MulOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(MulOp); USE_SIMPLE_CTOR_DTOR(MulOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
...@@ -28,6 +34,7 @@ template <class Context> ...@@ -28,6 +34,7 @@ template <class Context>
class MulGradientOp final : public Operator<Context> { class MulGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(MulGradientOp); USE_SIMPLE_CTOR_DTOR(MulGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -42,6 +49,7 @@ template <class Context> ...@@ -42,6 +49,7 @@ template <class Context>
class RMulOp final : public Operator<Context> { class RMulOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(RMulOp); USE_SIMPLE_CTOR_DTOR(RMulOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
...@@ -55,6 +63,7 @@ template <class Context> ...@@ -55,6 +63,7 @@ template <class Context>
class RMulGradientOp final : public Operator<Context> { class RMulGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(RMulGradientOp); USE_SIMPLE_CTOR_DTOR(RMulGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
...@@ -21,6 +26,7 @@ class PowOp: public Operator<Context> { ...@@ -21,6 +26,7 @@ class PowOp: public Operator<Context> {
power(OperatorBase::GetSingleArg<float>("power", 1.0)) { power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
power_scale = power * scale; power_scale = power * scale;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -39,6 +45,7 @@ class PowGradientOp final : public Operator<Context> { ...@@ -39,6 +45,7 @@ class PowGradientOp final : public Operator<Context> {
power(OperatorBase::GetSingleArg<float>("power", 1.0)) { power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
power_scale = power * scale; power_scale = power * scale;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_
...@@ -18,6 +23,7 @@ class ScaleOp : public Operator<Context> { ...@@ -18,6 +23,7 @@ class ScaleOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", 1)) {} num_axes(OperatorBase::GetSingleArg<int>("num_axes", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -35,6 +41,7 @@ class ScaleGradientOp final : public Operator<Context> { ...@@ -35,6 +41,7 @@ class ScaleGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {} num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void BiasRunWithType(); template <typename T> void BiasRunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class SquareOp final : public Operator<Context> { class SquareOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SquareOp); USE_SIMPLE_CTOR_DTOR(SquareOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -24,6 +30,7 @@ template <class Context> ...@@ -24,6 +30,7 @@ template <class Context>
class SquareGradientOp final : public Operator<Context> { class SquareGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SquareGradientOp); USE_SIMPLE_CTOR_DTOR(SquareGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_SUB_OP_H_ #ifndef DRAGON_OPERATORS_ARITHMETIC_SUB_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SUB_OP_H_ #define DRAGON_OPERATORS_ARITHMETIC_SUB_OP_H_
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class SubOp final : public Operator<Context> { class SubOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SubOp); USE_SIMPLE_CTOR_DTOR(SubOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
...@@ -28,6 +34,7 @@ template <class Context> ...@@ -28,6 +34,7 @@ template <class Context>
class SubGradientOp final : public Operator<Context> { class SubGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SubGradientOp); USE_SIMPLE_CTOR_DTOR(SubGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -42,6 +49,7 @@ template <class Context> ...@@ -42,6 +49,7 @@ template <class Context>
class RSubOp final : public Operator<Context> { class RSubOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(RSubOp); USE_SIMPLE_CTOR_DTOR(RSubOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
...@@ -55,6 +63,7 @@ template <class Context> ...@@ -55,6 +63,7 @@ template <class Context>
class RSubGradientOp final : public Operator<Context> { class RSubGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(RSubGradientOp); USE_SIMPLE_CTOR_DTOR(RSubGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_CAST_FLOAT2HALF_OP_H_ #ifndef DRAGON_OPERATORS_CAST_FLOAT2HALF_OP_H_
#define DRAGON_OPERATORS_CAST_FLOAT2HALF_OP_H_ #define DRAGON_OPERATORS_CAST_FLOAT2HALF_OP_H_
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class FloatToHalfOp final : public Operator<Context> { class FloatToHalfOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(FloatToHalfOp); USE_SIMPLE_CTOR_DTOR(FloatToHalfOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_ #ifndef DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_ #define DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
...@@ -17,6 +22,7 @@ class CompareOp final : public Operator<Context> { ...@@ -17,6 +22,7 @@ class CompareOp final : public Operator<Context> {
CompareOp(const OperatorDef& op_def, Workspace* ws) CompareOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {} operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EqualRunWithType(); template <typename T> void EqualRunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_ #ifndef DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_ #define DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
...@@ -15,6 +20,8 @@ template <class Context> ...@@ -15,6 +20,8 @@ template <class Context>
class CopyOp final : public Operator<Context> { class CopyOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(CopyOp); USE_SIMPLE_CTOR_DTOR(CopyOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_ #ifndef DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_ #define DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
...@@ -26,6 +31,7 @@ class ScanOp final: public Operator<Context> { ...@@ -26,6 +31,7 @@ class ScanOp final: public Operator<Context> {
debug_mode(OperatorBase::GetSingleArg<bool>("debug_mode", false)) { debug_mode(OperatorBase::GetSingleArg<bool>("debug_mode", false)) {
InitTemplate(); InitTemplate();
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
void InitTemplate(); void InitTemplate();
...@@ -56,14 +62,15 @@ class ScanGradientOp final: public Operator<Context> { ...@@ -56,14 +62,15 @@ class ScanGradientOp final: public Operator<Context> {
forward_outputs(OperatorBase::GetRepeatedArg<string>("outputs_name")) { forward_outputs(OperatorBase::GetRepeatedArg<string>("outputs_name")) {
// handle GO(x) // handle GO(x)
for (int i = 0; i < forward_outputs.size(); i++) for (int i = 0; i < forward_outputs.size(); i++)
terms[forward_outputs[i] + "_grad"] = input(i + (int)OutputSize()).name(); terms[forward_outputs[i] + "_grad"] = Input(i + (int)OutputSize()).name();
// handle GI(x) // handle GI(x)
for (int i = 0; i < forward_inputs.size(); i++) for (int i = 0; i < forward_inputs.size(); i++)
terms[forward_inputs[i] + "_grad"] = output(i)->name(); terms[forward_inputs[i] + "_grad"] = Output(i)->name();
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
void MakeGradientOps(); void MakeGradientOps();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
...@@ -17,6 +22,7 @@ class L1LossOp : public Operator<Context> { ...@@ -17,6 +22,7 @@ class L1LossOp : public Operator<Context> {
L1LossOp(const OperatorDef& op_def, Workspace* ws) L1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -32,6 +38,7 @@ class L1LossGradientOp final : public Operator<Context> { ...@@ -32,6 +38,7 @@ class L1LossGradientOp final : public Operator<Context> {
L1LossGradientOp(const OperatorDef& op_def, Workspace* ws) L1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
...@@ -17,6 +22,7 @@ class L2LossOp : public Operator<Context> { ...@@ -17,6 +22,7 @@ class L2LossOp : public Operator<Context> {
L2LossOp(const OperatorDef& op_def, Workspace* ws) L2LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -32,6 +38,7 @@ class L2LossGradientOp final : public Operator<Context> { ...@@ -32,6 +38,7 @@ class L2LossGradientOp final : public Operator<Context> {
L2LossGradientOp(const OperatorDef& op_def, Workspace* ws) L2LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_ #define DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
...@@ -17,6 +22,7 @@ class SigmoidCrossEntropyOp final : public Operator<Context> { ...@@ -17,6 +22,7 @@ class SigmoidCrossEntropyOp final : public Operator<Context> {
SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws) SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -32,6 +38,7 @@ class SigmoidCrossEntropyGradientOp final : public Operator<Context> { ...@@ -32,6 +38,7 @@ class SigmoidCrossEntropyGradientOp final : public Operator<Context> {
SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws) SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
...@@ -20,6 +25,7 @@ class SmoothL1LossOp final : public Operator<Context> { ...@@ -20,6 +25,7 @@ class SmoothL1LossOp final : public Operator<Context> {
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) { normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {
sigma2 *= sigma2; sigma2 *= sigma2;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -39,6 +45,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> { ...@@ -39,6 +45,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) { normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {
sigma2 *= sigma2; sigma2 *= sigma2;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_ #define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
...@@ -19,13 +24,14 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> { ...@@ -19,13 +24,14 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) { normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {
OperatorDef softmax_def = MakeOperatorDef("Softmax", "", OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ input(0).name() }), vector<string>({ Input(0).name() }),
vector<string>({ "/mnt/" + anchor() + "/softmax_prob" })); vector<string>({ "/mnt/" + anchor() + "/softmax_prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis")); softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (op_def.has_device_option()) if (op_def.has_device_option())
softmax_def.mutable_device_option()->CopyFrom(op_def.device_option()); softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
softmax_op.reset(CreateOperator(softmax_def, ws)); softmax_op.reset(CreateOperator(softmax_def, ws));
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -45,6 +51,7 @@ class SoftmaxCrossEntropyGradientOp final : public Operator<Context> { ...@@ -45,6 +51,7 @@ class SoftmaxCrossEntropyGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_ #define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
...@@ -14,7 +19,7 @@ namespace dragon { ...@@ -14,7 +19,7 @@ namespace dragon {
template <class Context> template <class Context>
class SparseSoftmaxCrossEntropyOp : public Operator<Context> { class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
public: public:
SparseSoftmaxCrossEntropyOp(const OperatorDef& op_def, Workspace* ws) SparseSoftmaxCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) { normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {
...@@ -25,13 +30,14 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> { ...@@ -25,13 +30,14 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i]; for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
} }
OperatorDef softmax_def = MakeOperatorDef("Softmax", "", OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ input(0).name() }), vector<string>({ Input(0).name() }),
vector<string>({ "/mnt/" + anchor() + "/softmax_prob" })); vector<string>({ "/mnt/" + anchor() + "/softmax_prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis")); softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (op_def.has_device_option()) if (op_def.has_device_option())
softmax_def.mutable_device_option()->CopyFrom(op_def.device_option()); softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
softmax_op.reset(CreateOperator(softmax_def, ws)); softmax_op.reset(CreateOperator(softmax_def, ws));
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -47,7 +53,7 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> { ...@@ -47,7 +53,7 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
template <class Context> template <class Context>
class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> { class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
public: public:
SparseSoftmaxCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws) SparseSoftmaxCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) { normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {
...@@ -58,6 +64,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> { ...@@ -58,6 +64,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i]; for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
} }
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_FOCAL_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_FOCAL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_FOCAL_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_FOCAL_LOSS_OP_H_
...@@ -14,8 +19,8 @@ namespace dragon { ...@@ -14,8 +19,8 @@ namespace dragon {
template <class Context> template <class Context>
class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Context> { class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Context> {
public: public:
SparseSoftmaxFocalLossOp(const OperatorDef& op_def, Workspace* ws) SparseSoftmaxFocalLossOp(const OperatorDef& op_def, Workspace* ws)
: SparseSoftmaxCrossEntropyOp<Context>(op_def, ws), : SparseSoftmaxCrossEntropyOp<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")), normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
alpha(OperatorBase::GetSingleArg<float>("alpha", 0.5)), alpha(OperatorBase::GetSingleArg<float>("alpha", 0.5)),
...@@ -23,7 +28,8 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex ...@@ -23,7 +28,8 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) { neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {
pos_alpha = alpha * 2.0; pos_alpha = alpha * 2.0;
neg_alpha = (1 - alpha) * 2.0; neg_alpha = (1 - alpha) * 2.0;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -40,13 +46,14 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex ...@@ -40,13 +46,14 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
template <class Context> template <class Context>
class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyGradientOp<Context> { class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyGradientOp<Context> {
public: public:
SparseSoftmaxFocalLossGradientOp(const OperatorDef& op_def, Workspace* ws) SparseSoftmaxFocalLossGradientOp(const OperatorDef& op_def, Workspace* ws)
: SparseSoftmaxCrossEntropyGradientOp<Context>(op_def, ws), : SparseSoftmaxCrossEntropyGradientOp<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")), normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)), gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {} neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_ACCURACY_OP_H_ #ifndef DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
#define DRAGON_OPERATORS_MISC_ACCURACY_OP_H_ #define DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
...@@ -25,6 +30,7 @@ class AccuracyOp final: public Operator<Context> { ...@@ -25,6 +30,7 @@ class AccuracyOp final: public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i]; for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
} }
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_GRADIENT_GENERATE_OP_H_ #ifndef DRAGON_OPERATORS_MISC_GRADIENT_GENERATE_OP_H_
#define DRAGON_OPERATORS_MISC_GRADIENT_GENERATE_OP_H_ #define DRAGON_OPERATORS_MISC_GRADIENT_GENERATE_OP_H_
...@@ -21,6 +26,7 @@ class GradientGenerateOp final: public Operator<Context> { ...@@ -21,6 +26,7 @@ class GradientGenerateOp final: public Operator<Context> {
CHECK_EQ(defaults.size(), OutputSize()); CHECK_EQ(defaults.size(), OutputSize());
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -35,9 +41,10 @@ class GradientGatherOp final : public Operator<Context> { ...@@ -35,9 +41,10 @@ class GradientGatherOp final : public Operator<Context> {
GradientGatherOp(const OperatorDef& op_def, Workspace* ws) GradientGatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
for (int i = 0; i < InputSize(); i++) for (int i = 0; i < InputSize(); i++)
if (input(i).name() != "ignore") indices.push_back(i); if (Input(i).name() != "ignore") indices.push_back(i);
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -53,6 +60,7 @@ class StopGradientOp final : public Operator<Context> { ...@@ -53,6 +60,7 @@ class StopGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_ #ifndef DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
#define DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_ #define DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
...@@ -35,6 +40,7 @@ class ImageDataOp final : public Operator<Context> { ...@@ -35,6 +40,7 @@ class ImageDataOp final : public Operator<Context> {
std.mutable_data<float, CPUContext>()[i] = std_values[i]; std.mutable_data<float, CPUContext>()[i] = std_values[i];
} }
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_ #ifndef DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
#define DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_ #define DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
...@@ -20,6 +25,7 @@ class InitializeOp: public Operator<Context> { ...@@ -20,6 +25,7 @@ class InitializeOp: public Operator<Context> {
shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) { shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) {
GET_ARGUMENTS_WITH_DESC(int, dims); GET_ARGUMENTS_WITH_DESC(int, dims);
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -38,6 +44,7 @@ public: ...@@ -38,6 +44,7 @@ public:
this->filler.set_type("constant"); this->filler.set_type("constant");
this->filler.set_value(OperatorBase::GetSingleArg<float>("value", 0.0)); this->filler.set_value(OperatorBase::GetSingleArg<float>("value", 0.0));
} }
USE_OPERATOR_FUNCTIONS(Context);
}; };
template <class Context> template <class Context>
...@@ -49,6 +56,7 @@ public: ...@@ -49,6 +56,7 @@ public:
this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -1.0)); this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -1.0));
this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 1.0)); this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 1.0));
} }
USE_OPERATOR_FUNCTIONS(Context);
}; };
template <class Context> template <class Context>
...@@ -60,6 +68,7 @@ public: ...@@ -60,6 +68,7 @@ public:
this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0)); this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0));
this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0)); this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0));
} }
USE_OPERATOR_FUNCTIONS(Context);
}; };
template <class Context> template <class Context>
...@@ -75,6 +84,7 @@ public: ...@@ -75,6 +84,7 @@ public:
this->filler.set_low(mu - 2 * sigma); this->filler.set_low(mu - 2 * sigma);
this->filler.set_high(mu + 2 * sigma); this->filler.set_high(mu + 2 * sigma);
} }
USE_OPERATOR_FUNCTIONS(Context);
}; };
template <class Context> template <class Context>
...@@ -95,6 +105,7 @@ public: ...@@ -95,6 +105,7 @@ public:
} }
this->filler.set_scale(scale); this->filler.set_scale(scale);
} }
USE_OPERATOR_FUNCTIONS(Context);
}; };
template <class Context> template <class Context>
...@@ -115,6 +126,7 @@ public: ...@@ -115,6 +126,7 @@ public:
} }
this->filler.set_scale(scale); this->filler.set_scale(scale);
} }
USE_OPERATOR_FUNCTIONS(Context);
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, InitializeOp, dims); DEFINE_ARGUMENTS_WITH_DESC(int, InitializeOp, dims);
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_PYTHON_OP_H_ #ifndef DRAGON_OPERATORS_MISC_PYTHON_OP_H_
#define DRAGON_OPERATORS_MISC_PYTHON_OP_H_ #define DRAGON_OPERATORS_MISC_PYTHON_OP_H_
...@@ -11,19 +16,13 @@ ...@@ -11,19 +16,13 @@
#include "core/operator.h" #include "core/operator.h"
#ifdef WITH_PYTHON3
#define PyBytes_FromStringAndSize PyUnicode_FromStringAndSize
#endif
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class RunOp : public Operator<Context> { class RunOp : public Operator<Context> {
public: public:
RunOp(const OperatorDef& op_def, Workspace* ws); RunOp(const OperatorDef& op_def, Workspace* ws);
PyObject* String(const char* str) { USE_OPERATOR_FUNCTIONS(Context);
return PyBytes_FromStringAndSize(str, string(str).size());
}
void RunOnDevice() override; void RunOnDevice() override;
...@@ -37,6 +36,7 @@ class TemplateOp : public RunOp<Context> { ...@@ -37,6 +36,7 @@ class TemplateOp : public RunOp<Context> {
public: public:
TemplateOp(const OperatorDef& op_def, Workspace* ws) TemplateOp(const OperatorDef& op_def, Workspace* ws)
: RunOp<Context>(op_def, ws) {} : RunOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
}; };
template <class Context> template <class Context>
...@@ -46,10 +46,11 @@ public: ...@@ -46,10 +46,11 @@ public:
: TemplateOp<Context>(op_def, ws) { : TemplateOp<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_MISC_PYTHON_OP_H_ #endif // DRAGON_OPERATORS_MISC_PYTHON_OP_H_
\ No newline at end of file
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_ #ifndef DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
#define DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_ #define DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
...@@ -51,6 +56,9 @@ class ModelMPIBase : public Operator<Context> { ...@@ -51,6 +56,9 @@ class ModelMPIBase : public Operator<Context> {
string dtype; string dtype;
}; };
#define USE_MPIMODEL_FUNCTIONS(context) \
using ModelMPIBase<context>::mpi_dtype
} // namespace dragon } // namespace dragon
#endif // WITH_MPI #endif // WITH_MPI
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_ #ifndef DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_ #define DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
...@@ -18,6 +23,8 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> { ...@@ -18,6 +23,8 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> {
public: public:
MPIBroadcastOp(const OperatorDef& op_def, Workspace* ws) MPIBroadcastOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -30,6 +37,8 @@ public: ...@@ -30,6 +37,8 @@ public:
: ModelMPIBase<Context>(op_def, ws) { : ModelMPIBase<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_ #ifndef DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_ #define DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
...@@ -18,6 +23,8 @@ class MPIGatherOp final : public ModelMPIBase<Context> { ...@@ -18,6 +23,8 @@ class MPIGatherOp final : public ModelMPIBase<Context> {
public: public:
MPIGatherOp(const OperatorDef& op_def, Workspace *ws) MPIGatherOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -30,6 +37,8 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> { ...@@ -30,6 +37,8 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> {
: ModelMPIBase<Context>(op_def, ws) { : ModelMPIBase<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_ #define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
...@@ -21,6 +26,7 @@ class ArangeOp final : public Operator<Context> { ...@@ -21,6 +26,7 @@ class ArangeOp final : public Operator<Context> {
GET_ARGUMENT_WITH_DESC(int, stop, 0); GET_ARGUMENT_WITH_DESC(int, stop, 0);
GET_ARGUMENT_WITH_DESC(int, step, 1); GET_ARGUMENT_WITH_DESC(int, step, 1);
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_ #define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
...@@ -19,6 +24,7 @@ class ArgmaxOp final : public Operator<Context> { ...@@ -19,6 +24,7 @@ class ArgmaxOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)), keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {} top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_ARGMIN_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_ARGMIN_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGMIN_OP_H_ #define DRAGON_OPERATORS_NDARRAY_ARGMIN_OP_H_
...@@ -19,6 +24,7 @@ class ArgminOp final : public Operator<Context> { ...@@ -19,6 +24,7 @@ class ArgminOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)), keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {} top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_ #define DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
...@@ -18,6 +23,7 @@ class ConcatOp : public Operator<Context> { ...@@ -18,6 +23,7 @@ class ConcatOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {} nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -37,6 +43,7 @@ class ConcatGradientOp : public Operator<Context> { ...@@ -37,6 +43,7 @@ class ConcatGradientOp : public Operator<Context> {
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) { nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_CROP_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
#define DRAGON_OPERATORS_NDARRAY_CROP_OP_H_ #define DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
...@@ -22,6 +27,7 @@ class CropOp: public Operator<Context> { ...@@ -22,6 +27,7 @@ class CropOp: public Operator<Context> {
offsets(OperatorBase::GetRepeatedArg<int>("offsets")), offsets(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")), shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {} shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
void RunOnDevice() override; void RunOnDevice() override;
...@@ -49,6 +55,7 @@ class CropGradientOp final : public Operator<Context > { ...@@ -49,6 +55,7 @@ class CropGradientOp final : public Operator<Context > {
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) { shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_
#define DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_ #define DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_
...@@ -17,6 +22,7 @@ class ExpandDimsOp final : public Operator<Context> { ...@@ -17,6 +22,7 @@ class ExpandDimsOp final : public Operator<Context> {
ExpandDimsOp(const OperatorDef& op_def, Workspace* ws) ExpandDimsOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {} axis(OperatorBase::GetSingleArg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
...@@ -31,6 +37,7 @@ class ExpandDimsGradientOp final : public Operator<Context> { ...@@ -31,6 +37,7 @@ class ExpandDimsGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_FLATTEN_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_FLATTEN_OP_H_
#define DRAGON_OPERATORS_NDARRAY_FLATTEN_OP_H_ #define DRAGON_OPERATORS_NDARRAY_FLATTEN_OP_H_
...@@ -19,6 +24,7 @@ class FlattenOp final : public Operator<Context> { ...@@ -19,6 +24,7 @@ class FlattenOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)), num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
keep_axes(OperatorBase::GetSingleArg<int>("keep_axes", INT_MAX)) {} keep_axes(OperatorBase::GetSingleArg<int>("keep_axes", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
void SqueezeRun(); void SqueezeRun();
...@@ -35,6 +41,7 @@ class FlattenGradientOp final : public Operator<Context> { ...@@ -35,6 +41,7 @@ class FlattenGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
#define DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_ #define DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
...@@ -17,6 +22,7 @@ class GatherOp final : public Operator<Context> { ...@@ -17,6 +22,7 @@ class GatherOp final : public Operator<Context> {
GatherOp(const OperatorDef& op_def, Workspace* ws) GatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {} axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -33,6 +39,7 @@ class GatherGradientOp final : public Operator<Context> { ...@@ -33,6 +39,7 @@ class GatherGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {} acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_ #define DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
...@@ -19,6 +24,7 @@ class OneHotOp final : public Operator < Context > { ...@@ -19,6 +24,7 @@ class OneHotOp final : public Operator < Context > {
depth(OperatorBase::GetSingleArg<int>("depth", -1)), depth(OperatorBase::GetSingleArg<int>("depth", -1)),
on_value(OperatorBase::GetSingleArg<int>("on_value", 1)), on_value(OperatorBase::GetSingleArg<int>("on_value", 1)),
off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {} off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_PAD_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
#define DRAGON_OPERATORS_NDARRAY_PAD_OP_H_ #define DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
...@@ -30,6 +35,7 @@ class PadOp final : public Operator<Context> { ...@@ -30,6 +35,7 @@ class PadOp final : public Operator<Context> {
} }
std::sort(process_axes.begin(), process_axes.end()); std::sort(process_axes.begin(), process_axes.end());
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void ConstRunWithType(); template <typename T> void ConstRunWithType();
...@@ -65,6 +71,7 @@ class PadGradientOp final : public Operator<Context> { ...@@ -65,6 +71,7 @@ class PadGradientOp final : public Operator<Context> {
std::reverse(process_axes.begin(), process_axes.end()); std::reverse(process_axes.begin(), process_axes.end());
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void ConstRunWithType(); template <typename T> void ConstRunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
#define DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_ #define DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
...@@ -18,6 +23,7 @@ class RandomPickOp : public Operator<Context> { ...@@ -18,6 +23,7 @@ class RandomPickOp : public Operator<Context> {
Operator<Context>(op_def, ws), Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
max_samples(OperatorBase::GetSingleArg<int>("max_samples", 1)) {} max_samples(OperatorBase::GetSingleArg<int>("max_samples", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -37,6 +43,7 @@ public: ...@@ -37,6 +43,7 @@ public:
axis(OperatorBase::GetSingleArg<int>("axis", 0)) { axis(OperatorBase::GetSingleArg<int>("axis", 0)) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
...@@ -19,6 +24,7 @@ class ReduceOp final : public Operator<Context> { ...@@ -19,6 +24,7 @@ class ReduceOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")), operation(OperatorBase::GetSingleArg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)) {} keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void SumRunWithType(); template <typename T> void SumRunWithType();
...@@ -38,6 +44,7 @@ class ReduceGradientOp final : public Operator<Context> { ...@@ -38,6 +44,7 @@ class ReduceGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {} operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void SumRunWithType(); template <typename T> void SumRunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_ #define DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
...@@ -19,6 +24,7 @@ class RepeatOp : public Operator<Context> { ...@@ -19,6 +24,7 @@ class RepeatOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)) { axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1); GET_ARGUMENT_WITH_DESC(int, repeats, 1);
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void RunWithType(); template<typename T> void RunWithType();
...@@ -36,6 +42,7 @@ class RepeatGradientOp : public Operator<Context> { ...@@ -36,6 +42,7 @@ class RepeatGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)) { axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1); GET_ARGUMENT_WITH_DESC(int, repeats, 1);
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void RunWithType(); template<typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
...@@ -19,6 +24,7 @@ class ReshapeOp final : public Operator<Context> { ...@@ -19,6 +24,7 @@ class ReshapeOp final : public Operator<Context> {
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) { shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, shape); GET_ARGUMENTS_WITH_DESC(int, shape);
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
...@@ -35,6 +41,7 @@ class ReshapeGradientOp final : public Operator<Context> { ...@@ -35,6 +41,7 @@ class ReshapeGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
...@@ -15,6 +20,7 @@ template <class Context> ...@@ -15,6 +20,7 @@ template <class Context>
class ShapeOp final : public Operator<Context> { class ShapeOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ShapeOp); USE_SIMPLE_CTOR_DTOR(ShapeOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
...@@ -18,6 +23,7 @@ class SliceOp : public Operator<Context> { ...@@ -18,6 +23,7 @@ class SliceOp : public Operator<Context> {
Operator<Context>(op_def, ws), Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {} nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -38,6 +44,7 @@ class SliceGradientOp final : public Operator<Context> { ...@@ -38,6 +44,7 @@ class SliceGradientOp final : public Operator<Context> {
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) { nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_STACK_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
#define DRAGON_OPERATORS_NDARRAY_STACK_OP_H_ #define DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
...@@ -18,6 +23,7 @@ class StackOp : public Operator<Context> { ...@@ -18,6 +23,7 @@ class StackOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {} nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -35,6 +41,7 @@ class StackGradientOp : public Operator<Context> { ...@@ -35,6 +41,7 @@ class StackGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {} nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_TILE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_TILE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
...@@ -18,6 +23,7 @@ class TileOp : public Operator<Context> { ...@@ -18,6 +23,7 @@ class TileOp : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples); GET_ARGUMENTS_WITH_DESC(int, multiples);
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void TileRunWithType(); template<typename T> void TileRunWithType();
...@@ -36,6 +42,7 @@ class TileGradientOp : public Operator<Context> { ...@@ -36,6 +42,7 @@ class TileGradientOp : public Operator<Context> {
GET_ARGUMENTS_WITH_DESC(int, multiples); GET_ARGUMENTS_WITH_DESC(int, multiples);
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void TileRunWithType(); template<typename T> void TileRunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_ #define DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
...@@ -20,6 +25,7 @@ class TransposeOp final: public Operator<Context> { ...@@ -20,6 +25,7 @@ class TransposeOp final: public Operator<Context> {
if (perms.size() > 0) reverse_dims = false; if (perms.size() > 0) reverse_dims = false;
else reverse_dims = true; else reverse_dims = true;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -34,8 +40,9 @@ class TransposeOp final: public Operator<Context> { ...@@ -34,8 +40,9 @@ class TransposeOp final: public Operator<Context> {
template <class Context> template <class Context>
class TransposeGradientOp final : public Operator<Context> { class TransposeGradientOp final : public Operator<Context> {
public: public:
TransposeGradientOp(const OperatorDef& op_def, Workspace* ws) TransposeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_ #define DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
...@@ -25,6 +30,7 @@ class BatchNormOp : public Operator<Context> { ...@@ -25,6 +30,7 @@ class BatchNormOp : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -54,6 +60,7 @@ class BatchNormGradientOp final : public Operator<Context> { ...@@ -54,6 +60,7 @@ class BatchNormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -80,6 +87,7 @@ class FusedBatchNormOp : public Operator<Context> { ...@@ -80,6 +87,7 @@ class FusedBatchNormOp : public Operator<Context> {
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))), momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -106,6 +114,7 @@ class FusedBatchNormGradientOp : public Operator<Context> { ...@@ -106,6 +114,7 @@ class FusedBatchNormGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -142,6 +151,7 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> { ...@@ -142,6 +151,7 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
this->eps = std::max(this->eps, float(CUDNN_BN_MIN_EPSILON)); this->eps = std::max(this->eps, float(CUDNN_BN_MIN_EPSILON));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNBatchNormOp() { ~CuDNNBatchNormOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -172,6 +182,7 @@ class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context> ...@@ -172,6 +182,7 @@ class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context>
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
this->eps = std::max(this->eps, float(CUDNN_BN_MIN_EPSILON)); this->eps = std::max(this->eps, float(CUDNN_BN_MIN_EPSILON));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNBatchNormGradientOp() { ~CuDNNBatchNormGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_ #define DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
...@@ -29,6 +34,7 @@ class BatchRenormOp : public Operator<Context> { ...@@ -29,6 +34,7 @@ class BatchRenormOp : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -59,6 +65,7 @@ class BatchRenormGradientOp final : public Operator<Context> { ...@@ -59,6 +65,7 @@ class BatchRenormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_ #define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
...@@ -26,6 +31,7 @@ class GroupNormOp : public Operator<Context> { ...@@ -26,6 +31,7 @@ class GroupNormOp : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -56,6 +62,7 @@ class GroupNormGradientOp final : public Operator<Context> { ...@@ -56,6 +62,7 @@ class GroupNormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -83,6 +90,7 @@ class FusedGroupNormOp : public Operator<Context> { ...@@ -83,6 +90,7 @@ class FusedGroupNormOp : public Operator<Context> {
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))), momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -110,6 +118,7 @@ class FusedGroupNormGradientOp : public Operator<Context> { ...@@ -110,6 +118,7 @@ class FusedGroupNormGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_ #ifndef DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_ #define DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
...@@ -22,6 +27,7 @@ class InstanceNormOp : public Operator<Context> { ...@@ -22,6 +27,7 @@ class InstanceNormOp : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -46,6 +52,7 @@ class InstanceNormGradientOp final : public Operator<Context> { ...@@ -46,6 +52,7 @@ class InstanceNormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_L2_NORM_H_ #ifndef DRAGON_OPERATORS_NORM_L2_NORM_H_
#define DRAGON_OPERATORS_NORM_L2_NORM_H_ #define DRAGON_OPERATORS_NORM_L2_NORM_H_
...@@ -20,6 +25,7 @@ class L2NormOp final : public Operator<Context> { ...@@ -20,6 +25,7 @@ class L2NormOp final : public Operator<Context> {
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)), num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-5))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-5))),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {} mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -41,6 +47,7 @@ class L2NormGradientOp final : public Operator<Context> { ...@@ -41,6 +47,7 @@ class L2NormGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)), num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {} mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_ #ifndef DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_ #define DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_
...@@ -17,6 +22,7 @@ class LSTMUnitOp : public Operator<Context> { ...@@ -17,6 +22,7 @@ class LSTMUnitOp : public Operator<Context> {
LSTMUnitOp(const OperatorDef& op_def, Workspace* ws) LSTMUnitOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
has_cont(OperatorBase::GetSingleArg<string>("cont_t", "")) {} has_cont(OperatorBase::GetSingleArg<string>("cont_t", "")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -30,10 +36,11 @@ class LSTMUnitOp : public Operator<Context> { ...@@ -30,10 +36,11 @@ class LSTMUnitOp : public Operator<Context> {
template <class Context> template <class Context>
class LSTMUnitGradientOp : public Operator<Context> { class LSTMUnitGradientOp : public Operator<Context> {
public: public:
LSTMUnitGradientOp(const OperatorDef& op_def, Workspace* ws) LSTMUnitGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
this->allow_share_grads_ = false; this->allow_share_grads_ = false;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
...@@ -17,9 +22,11 @@ class AdamUpdateOp final : public UpdateOpBase<Context> { ...@@ -17,9 +22,11 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
AdamUpdateOp(const OperatorDef& op_def, Workspace* ws) AdamUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws), : UpdateOpBase<Context>(op_def, ws),
t(0), t(0),
eps(param("eps")), eps(Param("eps")),
beta1(param("beta1")), beta1(Param("beta1")),
beta2(param("beta2")) {} beta2(Param("beta2")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
...@@ -16,12 +21,13 @@ namespace dragon { ...@@ -16,12 +21,13 @@ namespace dragon {
template <class Context> template <class Context>
class CollectiveUpdateOp : public Operator<Context> { class CollectiveUpdateOp : public Operator<Context> {
public: public:
CollectiveUpdateOp(const OperatorDef& op_def, Workspace* ws) CollectiveUpdateOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
mode(OperatorBase::GetSingleArg<string>("mode", "UNKNOWN")) { mode(OperatorBase::GetSingleArg<string>("mode", "UNKNOWN")) {
InitMPI(); InitMPI();
if (mode.find("NCCL") != string::npos) InitNCCL(); if (mode.find("NCCL") != string::npos) InitNCCL();
} }
USE_OPERATOR_FUNCTIONS(Context);
void InitMPI(); void InitMPI();
void InitNCCL(); void InitNCCL();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
#define DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_ #define DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
...@@ -17,6 +22,7 @@ class MovingAverageOp final : public Operator<Context> { ...@@ -17,6 +22,7 @@ class MovingAverageOp final : public Operator<Context> {
MovingAverageOp(const OperatorDef& op_def, Workspace* ws) MovingAverageOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
decay(OperatorBase::GetSingleArg<float>("decay", 1.0)) {} decay(OperatorBase::GetSingleArg<float>("decay", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
...@@ -14,11 +19,13 @@ namespace dragon { ...@@ -14,11 +19,13 @@ namespace dragon {
template <class Context> template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> { class NesterovUpdateOp final : public UpdateOpBase<Context> {
public: public:
NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws) NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws), : UpdateOpBase<Context>(op_def, ws),
momentum(param("momentum")) {} momentum(Param("momentum")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat() override;
protected: protected:
float lr, momentum; float lr, momentum;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
...@@ -16,8 +21,10 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> { ...@@ -16,8 +21,10 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
public: public:
RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws) RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws), : UpdateOpBase<Context>(op_def, ws),
eps(param("eps")), eps(Param("eps")),
decay(param("decay")) {} decay(Param("decay")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_ #ifndef DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_ #define DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
...@@ -16,7 +21,9 @@ class SGDUpdateOp final : public UpdateOpBase<Context> { ...@@ -16,7 +21,9 @@ class SGDUpdateOp final : public UpdateOpBase<Context> {
public: public:
SGDUpdateOp(const OperatorDef& op_def, Workspace* ws) SGDUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws), : UpdateOpBase<Context>(op_def, ws),
momentum(param("momentum")) {} momentum(Param("momentum")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat() override;
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_ #ifndef DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
#define DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_ #define DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
...@@ -19,8 +24,9 @@ class UpdateOpBase : public Operator<Context> { ...@@ -19,8 +24,9 @@ class UpdateOpBase : public Operator<Context> {
lr_mult(OperatorBase::GetSingleArg<float>("lr_mult", 1.0)), lr_mult(OperatorBase::GetSingleArg<float>("lr_mult", 1.0)),
decay_mult(OperatorBase::GetSingleArg<float>("decay_mult", 1.0)), decay_mult(OperatorBase::GetSingleArg<float>("decay_mult", 1.0)),
domain(OperatorBase::GetSingleArg<string>("domain", "_")) {} domain(OperatorBase::GetSingleArg<string>("domain", "_")) {}
USE_OPERATOR_FUNCTIONS(Context);
float param(const string& name) const; float Param(const string& name) const;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void PreprocessRunWithType(); template <typename T> void PreprocessRunWithType();
...@@ -33,6 +39,9 @@ class UpdateOpBase : public Operator<Context> { ...@@ -33,6 +39,9 @@ class UpdateOpBase : public Operator<Context> {
string domain; string domain;
}; };
#define USE_UPDATER_FUNCTIONS(context) \
using UpdateOpBase<context>::Param
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_ #endif // DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
\ No newline at end of file
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_BILINEAR_RESIZE_OP_H_ #ifndef DRAGON_OPERATORS_VISION_BILINEAR_RESIZE_OP_H_
#define DRAGON_OPERATORS_VISION_BILINEAR_RESIZE_OP_H_ #define DRAGON_OPERATORS_VISION_BILINEAR_RESIZE_OP_H_
...@@ -25,6 +30,8 @@ class BilinearResizeOp : public Operator<Context> { ...@@ -25,6 +30,8 @@ class BilinearResizeOp : public Operator<Context> {
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -41,6 +48,7 @@ class BilinearResizeGradientOp : public Operator<Context> { ...@@ -41,6 +48,7 @@ class BilinearResizeGradientOp : public Operator<Context> {
BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws) BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_CONV_OP_H_ #ifndef DRAGON_OPERATORS_VISION_CONV_OP_H_
#define DRAGON_OPERATORS_VISION_CONV_OP_H_ #define DRAGON_OPERATORS_VISION_CONV_OP_H_
...@@ -19,9 +24,11 @@ class Conv2dOp : public ConvOpBase<Context> { ...@@ -19,9 +24,11 @@ class Conv2dOp : public ConvOpBase<Context> {
this->num_spatial_axes = 2; this->num_spatial_axes = 2;
Setup(); Setup();
} }
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
bool ReverseDimensions() override { return false; } bool ReverseDimensions() override { return false; }
virtual bool HasBias() { return InputSize() > 2; } bool HasBias() override { return InputSize() > 2; }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -32,8 +39,10 @@ class Conv2dGradientOp : public Conv2dOp<Context> { ...@@ -32,8 +39,10 @@ class Conv2dGradientOp : public Conv2dOp<Context> {
public: public:
Conv2dGradientOp(const OperatorDef& def, Workspace* ws) Conv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws) {} : Conv2dOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
bool HasBias() override { return output(2)->name() != "ignore"; } bool HasBias() override { return Output(2)->name() != "ignore"; }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -70,6 +79,8 @@ class CuDNNConv2dOp : public Conv2dOp<Context> { ...@@ -70,6 +79,8 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC; else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format; else LOG(FATAL) << "Unknown data format: " << this->data_format;
} }
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dOp() { ~CuDNNConv2dOp() {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc)); CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc));
...@@ -124,6 +135,8 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> { ...@@ -124,6 +135,8 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC; else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format; else LOG(FATAL) << "Unknown data format: " << this->data_format;
} }
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dGradientOp() { ~CuDNNConv2dGradientOp() {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc)); CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc));
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_ #ifndef DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_
#define DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_ #define DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_
...@@ -29,6 +34,7 @@ class ConvOpBase : public Operator<Context> { ...@@ -29,6 +34,7 @@ class ConvOpBase : public Operator<Context> {
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
num_spatial_axes = -1; // unknown num_spatial_axes = -1; // unknown
} }
USE_OPERATOR_FUNCTIONS(Context);
protected: protected:
vector<TIndex> kernel_size, stride, pad, dilation; vector<TIndex> kernel_size, stride, pad, dilation;
...@@ -50,6 +56,7 @@ class ConvOpBase : public Operator<Context> { ...@@ -50,6 +56,7 @@ class ConvOpBase : public Operator<Context> {
void GradientReshape(); void GradientReshape();
virtual void ComputeOutputShape(); virtual void ComputeOutputShape();
virtual bool ReverseDimensions() = 0; virtual bool ReverseDimensions() = 0;
virtual bool HasBias() = 0;
template <typename T> void Wx(const T* x, const T* weights, T* y, bool skip_im2col = false); template <typename T> void Wx(const T* x, const T* weights, T* y, bool skip_im2col = false);
template <typename T> void Pb(const T* bias, T* y); template <typename T> void Pb(const T* bias, T* y);
...@@ -59,7 +66,7 @@ class ConvOpBase : public Operator<Context> { ...@@ -59,7 +66,7 @@ class ConvOpBase : public Operator<Context> {
private: private:
template <typename T> void Im2Col(const T* im, T* col) { template <typename T> void Im2Col(const T* im, T* col) {
if (input(0).ndim() == 4) { if (Input(0).ndim() == 4) {
kernel::Im2Col2d<T, Context>(conv_in_channels, kernel::Im2Col2d<T, Context>(conv_in_channels,
input_shape[0], input_shape[1], input_shape[0], input_shape[1],
output_shape[0], output_shape[1], output_shape[0], output_shape[1],
...@@ -73,7 +80,7 @@ class ConvOpBase : public Operator<Context> { ...@@ -73,7 +80,7 @@ class ConvOpBase : public Operator<Context> {
} else LOG(FATAL) << "ConvNd has not been implemented yet"; } else LOG(FATAL) << "ConvNd has not been implemented yet";
} }
template <typename T> void Col2Im(const T* col, T* im) { template <typename T> void Col2Im(const T* col, T* im) {
if (input(0).ndim() == 4) { if (Input(0).ndim() == 4) {
kernel::Col2Im2d<T, Context>(conv_in_channels, kernel::Col2Im2d<T, Context>(conv_in_channels,
input_shape[0], input_shape[1], input_shape[0], input_shape[1],
output_shape[0], output_shape[1], output_shape[0], output_shape[1],
...@@ -90,6 +97,19 @@ class ConvOpBase : public Operator<Context> { ...@@ -90,6 +97,19 @@ class ConvOpBase : public Operator<Context> {
DEFINE_ARGUMENTS_WITH_DESC(int, ConvOpBase, output_dims); DEFINE_ARGUMENTS_WITH_DESC(int, ConvOpBase, output_dims);
#define USE_CONVOLUTION_FUNCTIONS(context) \
using ConvOpBase<context>::Setup; \
using ConvOpBase<context>::Reshape; \
using ConvOpBase<context>::GradientReshape; \
using ConvOpBase<context>::ComputeOutputShape; \
using ConvOpBase<Context>::ReverseDimensions; \
using ConvOpBase<Context>::HasBias; \
using ConvOpBase<context>::Wx; \
using ConvOpBase<context>::Pb; \
using ConvOpBase<context>::Dx; \
using ConvOpBase<context>::Dw; \
using ConvOpBase<context>::Db
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_ #endif // DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_
\ No newline at end of file
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_ #ifndef DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_
#define DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_ #define DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_
...@@ -19,9 +24,11 @@ class Conv2dTransposeOp: public ConvOpBase<Context> { ...@@ -19,9 +24,11 @@ class Conv2dTransposeOp: public ConvOpBase<Context> {
this->num_spatial_axes = 2; this->num_spatial_axes = 2;
Setup(); Setup();
} }
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
bool ReverseDimensions() override { return true; } bool ReverseDimensions() override { return true; }
virtual bool HasBias() { return InputSize() > 2; } bool HasBias() override { return InputSize() > 2; }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -36,8 +43,10 @@ class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> { ...@@ -36,8 +43,10 @@ class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> {
public: public:
Conv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws) Conv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeOp<Context>(def, ws) {} : Conv2dTransposeOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
bool HasBias() override { return output(2)->name() != "ignore"; } bool HasBias() override { return Output(2)->name() != "ignore"; }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -73,6 +82,8 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { ...@@ -73,6 +82,8 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC; else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format; else LOG(FATAL) << "Unknown data format: " << this->data_format;
} }
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dTransposeOp() { ~CuDNNConv2dTransposeOp() {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc)); CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc));
...@@ -127,6 +138,8 @@ public: ...@@ -127,6 +138,8 @@ public:
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC; else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format; else LOG(FATAL) << "Unknown data format: " << this->data_format;
} }
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dTransposeGradientOp() { ~CuDNNConv2dTransposeGradientOp() {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc)); CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc));
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_ #ifndef DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_
#define DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_ #define DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_
...@@ -14,8 +19,9 @@ namespace dragon { ...@@ -14,8 +19,9 @@ namespace dragon {
template <class Context> template <class Context>
class DenseConcatOp final : public ConcatOp<Context> { class DenseConcatOp final : public ConcatOp<Context> {
public: public:
DenseConcatOp(const OperatorDef& op_def, Workspace* ws) DenseConcatOp(const OperatorDef& op_def, Workspace* ws)
: ConcatOp<Context>(op_def, ws) {} : ConcatOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
}; };
template <class Context> template <class Context>
...@@ -24,12 +30,13 @@ class DenseConcatGradientOp : public ConcatGradientOp<Context> { ...@@ -24,12 +30,13 @@ class DenseConcatGradientOp : public ConcatGradientOp<Context> {
DenseConcatGradientOp(const OperatorDef& op_def, Workspace* ws) DenseConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: ConcatGradientOp<Context>(op_def, ws), : ConcatGradientOp<Context>(op_def, ws),
growth_rate(OperatorBase::GetSingleArg<int>("growth_rate", 0)) {} growth_rate(OperatorBase::GetSingleArg<int>("growth_rate", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ElimateCorruption() override; void ElimateCorruption() override;
template <typename T> void RestoreX1(); template <typename T> void RestoreX1();
protected: protected:
TIndex growth_rate; TIndex growth_rate;
}; };
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_LRN_OP_H_ #ifndef DRAGON_OPERATORS_VISION_LRN_OP_H_
#define DRAGON_OPERATORS_VISION_LRN_OP_H_ #define DRAGON_OPERATORS_VISION_LRN_OP_H_
...@@ -24,6 +29,7 @@ class LRNOp : public Operator<Context> { ...@@ -24,6 +29,7 @@ class LRNOp : public Operator<Context> {
k(OperatorBase::GetSingleArg<float>("k", float(2.0))), k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")), mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -54,6 +60,7 @@ class LRNGradientOp : public Operator<Context> { ...@@ -54,6 +60,7 @@ class LRNGradientOp : public Operator<Context> {
k(OperatorBase::GetSingleArg<float>("k", float(2.0))), k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")), mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -90,6 +97,7 @@ class CuDNNLRNOp : public LRNOp<Context> { ...@@ -90,6 +97,7 @@ class CuDNNLRNOp : public LRNOp<Context> {
this->beta, this->beta,
this->k)); this->k));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNLRNOp() { ~CuDNNLRNOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -118,6 +126,7 @@ class CuDNNLRNGradientOp : public LRNGradientOp<Context > { ...@@ -118,6 +126,7 @@ class CuDNNLRNGradientOp : public LRNGradientOp<Context > {
this->beta, this->beta,
this->k)); this->k));
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNLRNGradientOp() { ~CuDNNLRNGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_NN_RESIZE_OP_H_ #ifndef DRAGON_OPERATORS_VISION_NN_RESIZE_OP_H_
#define DRAGON_OPERATORS_VISION_NN_RESIZE_OP_H_ #define DRAGON_OPERATORS_VISION_NN_RESIZE_OP_H_
...@@ -25,6 +30,7 @@ class NNResizeOp : public Operator<Context> { ...@@ -25,6 +30,7 @@ class NNResizeOp : public Operator<Context> {
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -42,6 +48,7 @@ class NNResizeGradientOp : public Operator<Context> { ...@@ -42,6 +48,7 @@ class NNResizeGradientOp : public Operator<Context> {
NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws) NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_POOLING_OP_H_ #ifndef DRAGON_OPERATORS_VISION_POOLING_OP_H_
#define DRAGON_OPERATORS_VISION_POOLING_OP_H_ #define DRAGON_OPERATORS_VISION_POOLING_OP_H_
...@@ -35,6 +40,7 @@ class Pooling2dOp: public Operator <Context> { ...@@ -35,6 +40,7 @@ class Pooling2dOp: public Operator <Context> {
} }
} }
} }
USE_OPERATOR_FUNCTIONS(Context);
void Reshape(); void Reshape();
void RunOnDevice() override; void RunOnDevice() override;
...@@ -73,6 +79,7 @@ class Pooling2dGradientOp: public Operator<Context> { ...@@ -73,6 +79,7 @@ class Pooling2dGradientOp: public Operator<Context> {
} }
} }
} }
USE_OPERATOR_FUNCTIONS(Context);
void Reshape(); void Reshape();
void RunOnDevice() override; void RunOnDevice() override;
...@@ -107,6 +114,7 @@ class CuDNNPooling2dOp final : public Pooling2dOp<Context> { ...@@ -107,6 +114,7 @@ class CuDNNPooling2dOp final : public Pooling2dOp<Context> {
pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
} else LOG(FATAL) << "Unsupported pooling mode: " << this->mode; } else LOG(FATAL) << "Unsupported pooling mode: " << this->mode;
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNPooling2dOp() { ~CuDNNPooling2dOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -156,6 +164,7 @@ class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> { ...@@ -156,6 +164,7 @@ class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> {
this->stride[0], this->stride[1])); this->stride[0], this->stride[1]));
#endif #endif
} }
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNPooling2dGradientOp() { ~CuDNNPooling2dGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_ROI_ALIGN_OP_H_ #ifndef DRAGON_OPERATORS_VISION_ROI_ALIGN_OP_H_
#define DRAGON_OPERATORS_VISION_ROI_ALIGN_OP_H_ #define DRAGON_OPERATORS_VISION_ROI_ALIGN_OP_H_
...@@ -23,6 +28,7 @@ class ROIAlignOp : public Operator<Context> { ...@@ -23,6 +28,7 @@ class ROIAlignOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -44,6 +50,7 @@ class ROIAlignGradientOp : public Operator<Context> { ...@@ -44,6 +50,7 @@ class ROIAlignGradientOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// -------------------------------------------------------- // ------------------------------------------------------------
// Dragon // Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright(c) 2017 SeetaTech //
// Written by Ting Pan // Licensed under the BSD 2-Clause License.
// -------------------------------------------------------- // You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_ROI_POOLING_OP_H_ #ifndef DRAGON_OPERATORS_VISION_ROI_POOLING_OP_H_
#define DRAGON_OPERATORS_VISION_ROI_POOLING_OP_H_ #define DRAGON_OPERATORS_VISION_ROI_POOLING_OP_H_
...@@ -22,6 +27,7 @@ class ROIPoolingOp : public Operator<Context> { ...@@ -22,6 +27,7 @@ class ROIPoolingOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -40,6 +46,7 @@ class ROIPoolingGradientOp final : public Operator<Context> { ...@@ -40,6 +46,7 @@ class ROIPoolingGradientOp final : public Operator<Context> {
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)), pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)), pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {} spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!