Commit 72f3c4ba by Ting PAN

Add License into headers

1 parent 44906e17
Showing with 1300 additions and 566 deletions
# ---------------- Welcom To Use Dragon ----------------
# ---------------- Welcom To Use Dragon ----------------
PROJECT(dragon)
CMAKE_MINIMUM_REQUIRED(VERSION 3.0.0)
......@@ -158,7 +158,7 @@ if(WIN32)
endif()
if(UNIX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -O2 -m64 -fpermissive -std=c++11")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w -fPIC -O2 -m64 -std=c++11")
if (WITH_OMP)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_COMMON_H_
#define DRAGON_CORE_COMMON_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_CONTEXT_H_
#define DRAGON_CORE_CONTEXT_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_CONTEXT_CUDA_H_
#define DRAGON_CORE_CONTEXT_CUDA_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_GRAPH_H_
#define DRAGON_CORE_GRAPH_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_GRAPH_GRADIENT_H_
#define DRAGON_CORE_GRAPH_GRADIENT_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_MIXEDMEM_H_
#define DRAGON_CORE_MIXEDMEM_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_OPERATOR_H_
#define DRAGON_CORE_OPERATOR_H_
......@@ -27,8 +32,8 @@ class OperatorBase {
OperatorBase(const OperatorDef& op_def, Workspace* ws);
virtual ~OperatorBase() {}
Tensor& input(int idx);
Tensor* output(int idx);
Tensor& Input(int idx);
Tensor* Output(int idx);
inline size_t InputSize() { return inputs_.size(); }
inline size_t OutputSize() { return outputs_.size(); }
......@@ -55,7 +60,7 @@ class OperatorBase {
void set_recompute_map(RecomputeMap recompute_map) { recompute_map_ = recompute_map; }
inline const OperatorDef& op_def() const { return op_def_; }
inline const string debug_string() const { return op_def_.DebugString(); }
inline const string DebugString() const { return op_def_.DebugString(); }
protected:
string phase_;
......@@ -73,7 +78,7 @@ class Operator : public OperatorBase {
: OperatorBase(op_def, ws), ctx_(op_def.device_option()) {
allow_run_ = true;
allow_run_ &= _MPICheck();
allow_run_ &= (!(OutputSize() == 1 && output(0)->name() == "ignore"));
allow_run_ &= (!(OutputSize() == 1 && Output(0)->name() == "ignore"));
allow_share_grads_ = (!op_def.debug_mode());
allow_share_grads_ &= op_def.share_grads();
allow_share_grads_ &= (type().find("Gradient") != string::npos);
......@@ -97,9 +102,9 @@ class Operator : public OperatorBase {
void MemorySwitch() {
for (int i = 0; i < InputSize(); i++)
if (input(i).name() != "ignore") input(i).SwitchToDevice();
if (Input(i).name() != "ignore") Input(i).SwitchToDevice();
for (int i = 0; i < OutputSize(); i++)
if (output(i)->name() != "ignore") output(i)->SwitchToDevice();
if (Output(i)->name() != "ignore") Output(i)->SwitchToDevice();
}
virtual void RunOnDevice() = 0;
......@@ -135,6 +140,23 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
: Operator<Context>(op_def, ws) {} \
virtual ~name() {}
#define USE_OPERATOR_BASE_FUNCTIONS \
using OperatorBase::Input; \
using OperatorBase::Output; \
using OperatorBase::ws; \
using OperatorBase::name; \
using OperatorBase::type; \
using OperatorBase::phase; \
using OperatorBase::op_def; \
using OperatorBase::InputSize; \
using OperatorBase::OutputSize; \
using OperatorBase::DebugString \
#define USE_OPERATOR_FUNCTIONS(context) \
USE_OPERATOR_BASE_FUNCTIONS; \
using Operator<context>::ctx; \
using Operator<context>::anchor
DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*);
DECLARE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
......@@ -242,6 +264,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
INSTANTIATE_OPERATOR(name, CUDAContext); \
#define DEPLOY_CPU_CUDA(name) \
REGISTER_CPU_OPERATOR(name, name##Op<CPUContext>); \
REGISTER_CUDA_OPERATOR(name, name##Op<CPUContext>); \
INSTANTIATE_OPERATOR(name, CPUContext); \
......@@ -250,4 +273,4 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
INSTANTIATE_CUDNN_OPERATOR(name);
} // namespace dragon
#endif // DRAGON_CORE_OPERATOR_H_
#endif // DRAGON_CORE_OPERATOR_H_
\ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_
#define DRAGON_CORE_OPERATOR_GRADIENT_H_
......@@ -18,7 +23,7 @@ struct Gradient {
vector<OperatorDef> ops;
vector<string> g_inputs;
vector<float> defaults;
Gradient(const vector<OperatorDef>& ops,
Gradient(const vector<OperatorDef>& ops,
const vector<string>& g_inputs,
const vector<float>& defaults)
: ops(ops), g_inputs(g_inputs), defaults(defaults) {}
......@@ -37,20 +42,20 @@ class GradientMakerBase {
inline virtual Gradient Make() {
vector<OperatorDef> new_defs = MakeDefs();
Argument anchor;
Argument anchor;
anchor.set_name("anchor"); anchor.set_s(def.name());
for (int i = 0; i < new_defs.size(); i++)
new_defs[i].add_arg()->CopyFrom(anchor);
return Gradient(new_defs, g_inputs_, DefaultValues());
};
virtual inline vector<OperatorDef> MakeDefs() {
NOT_IMPLEMENTED;
return vector<OperatorDef>();
virtual inline vector<OperatorDef> MakeDefs() {
NOT_IMPLEMENTED;
return vector<OperatorDef>();
}
virtual inline vector<float> DefaultValues() {
return vector<float>(g_outputs_.size(), 1.0);
virtual inline vector<float> DefaultValues() {
return vector<float>(g_outputs_.size(), 1.0);
}
template <class... Args>
......@@ -79,24 +84,24 @@ Gradient MakeGradientForOp(const OperatorDef& op_def, const vector<string>& g_ou
# define GRADIENT_MAKER_CTOR(name) \
name(const OperatorDef& def, const vector<string>& g_output) \
: GradientMakerBase(def, g_output) {}
: GradientMakerBase(def, g_output) {}
class NoGradient : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(NoGradient);
vector<OperatorDef> MakeDefs() override {
return vector<OperatorDef>();
vector<OperatorDef> MakeDefs() override {
return vector<OperatorDef>();
}
};
DECLARE_REGISTRY(GradientRegistry,
GradientMakerBase,
const OperatorDef&,
DECLARE_REGISTRY(GradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<string>&);
DECLARE_REGISTRY(NoGradientRegistry,
GradientMakerBase,
const OperatorDef&,
DECLARE_REGISTRY(NoGradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<string>&);
// define in the operator.cc
......@@ -109,4 +114,4 @@ DECLARE_REGISTRY(NoGradientRegistry,
} // namespace dragon
#endif // DRAGON_CORE_OPERATOR_GRADIENT_H_
\ No newline at end of file
#endif // DRAGON_CORE_OPERATOR_GRADIENT_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_OPERATOR_SCHEMA_H_
#define DRAGON_CORE_OPERATOR_SCHEMA_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_REGISTRY_H_
#define DRAGON_CORE_REGISTRY_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_TENSOR_H_
#define DRAONG_CORE_TENSOR_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_TYPEID_H_
#define DRAGON_CORE_TYPEID_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_TYPES_H_
#define DRAGON_CORE_TYPES_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_CORE_WORKSPACE_H_
#define DRAGON_CORE_WORKSPACE_H_
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
......@@ -20,6 +25,7 @@ class DropoutOp final : public Operator<Context> {
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -39,6 +45,7 @@ class DropoutGradientOp final : public Operator<Context> {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_ELU_OP_H_
......@@ -17,6 +22,7 @@ class EluOp : public Operator<Context> {
EluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -33,6 +39,7 @@ class EluGradientOp : public Operator<Context> {
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -56,6 +63,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNEluOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -82,6 +90,7 @@ class CuDNNEluGradientOp final : public EluGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNEluGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
......@@ -18,6 +23,7 @@ class PReluOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -35,6 +41,7 @@ class PReluGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
......@@ -17,6 +22,7 @@ class ReluOp : public Operator<Context> {
ReluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -33,6 +39,7 @@ class ReluGradientOp : public Operator<Context> {
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -54,6 +61,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNReluOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -80,6 +88,7 @@ class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNReluGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
......@@ -16,6 +21,7 @@ class SEluOp : public Operator<Context> {
public:
SEluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -28,6 +34,7 @@ class SEluGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
#define DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_HPP
......@@ -15,6 +20,7 @@ template <class Context>
class SigmoidOp : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SigmoidOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -27,6 +33,7 @@ class SigmoidGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -45,6 +52,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNSigmoidOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -71,6 +79,7 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNSigmoidGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
......@@ -17,6 +22,7 @@ class SoftmaxOp final : public Operator<Context> {
SoftmaxOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -35,6 +41,7 @@ class SoftmaxGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -58,6 +65,7 @@ class CuDNNSoftmaxOp final : public Operator<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNSoftmaxOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -82,6 +90,7 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNSoftmaxGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
......@@ -15,6 +20,7 @@ template <class Context>
class TanhOp : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(TanhOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -27,6 +33,7 @@ class TanhGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -45,6 +52,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNTanhOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -71,6 +79,7 @@ class CuDNNTanhGradientOp final : public TanhGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNTanhGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_ADD_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_ADD_OP_H_
......@@ -15,6 +20,7 @@ template <class Context>
class AddOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(AddOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
......@@ -28,6 +34,7 @@ template <class Context>
class AddGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(AddGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......@@ -42,6 +49,7 @@ template <class Context>
class RAddOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RAddOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
......@@ -55,6 +63,7 @@ template <class Context>
class RAddGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RAddGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
......@@ -17,6 +22,7 @@ class BiasAddOp : public Operator<Context> {
BiasAddOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -35,6 +41,7 @@ class BiasAddGradientOp final : public Operator<Context> {
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
......@@ -19,6 +24,7 @@ class ClipOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
low(OperatorBase::GetSingleArg<float>("low", -FLT_MAX)),
high(OperatorBase::GetSingleArg<float>("high", FLT_MAX)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -32,6 +38,7 @@ template <class Context>
class ClipGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ClipGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_DIV_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_DIV_OP_H_
......@@ -15,6 +20,7 @@ template <class Context>
class DivOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(DivOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
......@@ -28,6 +34,7 @@ template <class Context>
class DivGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(DivGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......@@ -42,6 +49,7 @@ template <class Context>
class RDivOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RDivOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
......@@ -55,6 +63,7 @@ template <class Context>
class RDivGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RDivGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
......@@ -18,6 +23,7 @@ class DotOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void DotRunWithType();
......@@ -36,6 +42,7 @@ class DotGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
......@@ -24,6 +29,7 @@ class EltwiseOp final : public Operator<Context> {
<< "but provided " << coeffs.size() << " coeffs.";
} else coeffs.resize(InputSize(), float(1));
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void SumRunWithType();
......@@ -47,6 +53,7 @@ class EltwiseGradientOp final : public Operator<Context> {
<< "but provided " << coeffs.size() << " coeffs.";
} else coeffs.resize(InputSize(), float(1));
}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
......@@ -15,6 +20,7 @@ template <class Context>
class ExpOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ExpOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -24,6 +30,7 @@ template <class Context>
class ExpGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ExpGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
......@@ -17,6 +22,7 @@ class GramMatrixOp final : public Operator<Context> {
GramMatrixOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -32,6 +38,7 @@ class GramMatrixGradientOp final : public Operator<Context> {
GramMatrixGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
......@@ -19,6 +24,7 @@ class InnerProductOp: public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice();
template <typename T> void TransRunWithType();
......@@ -38,6 +44,7 @@ class InnerProductGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
......@@ -15,6 +20,7 @@ template <class Context>
class LogOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(LogOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -24,6 +30,7 @@ template <class Context>
class LogGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(LogGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
......@@ -18,6 +23,7 @@ class MatmulOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -35,6 +41,7 @@ class MatmulGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_MUL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MUL_OP_H_
......@@ -15,6 +20,7 @@ template <class Context>
class MulOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(MulOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
......@@ -28,6 +34,7 @@ template <class Context>
class MulGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(MulGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......@@ -42,6 +49,7 @@ template <class Context>
class RMulOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RMulOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
......@@ -55,6 +63,7 @@ template <class Context>
class RMulGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RMulGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
......@@ -21,6 +26,7 @@ class PowOp: public Operator<Context> {
power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
power_scale = power * scale;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -39,6 +45,7 @@ class PowGradientOp final : public Operator<Context> {
power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
power_scale = power * scale;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_
......@@ -18,6 +23,7 @@ class ScaleOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -35,6 +41,7 @@ class ScaleGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void BiasRunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
......@@ -15,6 +20,7 @@ template <class Context>
class SquareOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SquareOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -24,6 +30,7 @@ template <class Context>
class SquareGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SquareGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_SUB_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SUB_OP_H_
......@@ -15,6 +20,7 @@ template <class Context>
class SubOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SubOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
......@@ -28,6 +34,7 @@ template <class Context>
class SubGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SubGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......@@ -42,6 +49,7 @@ template <class Context>
class RSubOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RSubOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
......@@ -55,6 +63,7 @@ template <class Context>
class RSubGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RSubGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_CAST_FLOAT2HALF_OP_H_
#define DRAGON_OPERATORS_CAST_FLOAT2HALF_OP_H_
......@@ -15,6 +20,7 @@ template <class Context>
class FloatToHalfOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(FloatToHalfOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
};
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_COMPARE_OP_H_
......@@ -17,6 +22,7 @@ class CompareOp final : public Operator<Context> {
CompareOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void EqualRunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_COPY_OP_H_
......@@ -15,6 +20,8 @@ template <class Context>
class CopyOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(CopyOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_SCAN_OP_H_
......@@ -26,6 +31,7 @@ class ScanOp final: public Operator<Context> {
debug_mode(OperatorBase::GetSingleArg<bool>("debug_mode", false)) {
InitTemplate();
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
void InitTemplate();
......@@ -56,14 +62,15 @@ class ScanGradientOp final: public Operator<Context> {
forward_outputs(OperatorBase::GetRepeatedArg<string>("outputs_name")) {
// handle GO(x)
for (int i = 0; i < forward_outputs.size(); i++)
terms[forward_outputs[i] + "_grad"] = input(i + (int)OutputSize()).name();
terms[forward_outputs[i] + "_grad"] = Input(i + (int)OutputSize()).name();
// handle GI(x)
for (int i = 0; i < forward_inputs.size(); i++)
terms[forward_inputs[i] + "_grad"] = output(i)->name();
terms[forward_inputs[i] + "_grad"] = Output(i)->name();
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
void MakeGradientOps();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
......@@ -17,6 +22,7 @@ class L1LossOp : public Operator<Context> {
L1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -32,6 +38,7 @@ class L1LossGradientOp final : public Operator<Context> {
L1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
......@@ -17,6 +22,7 @@ class L2LossOp : public Operator<Context> {
L2LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -32,6 +38,7 @@ class L2LossGradientOp final : public Operator<Context> {
L2LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
......@@ -17,6 +22,7 @@ class SigmoidCrossEntropyOp final : public Operator<Context> {
SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -32,6 +38,7 @@ class SigmoidCrossEntropyGradientOp final : public Operator<Context> {
SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
......@@ -20,6 +25,7 @@ class SmoothL1LossOp final : public Operator<Context> {
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {
sigma2 *= sigma2;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -39,6 +45,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {
sigma2 *= sigma2;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
......@@ -19,13 +24,14 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {
OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ input(0).name() }),
vector<string>({ Input(0).name() }),
vector<string>({ "/mnt/" + anchor() + "/softmax_prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (op_def.has_device_option())
softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
softmax_op.reset(CreateOperator(softmax_def, ws));
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -45,6 +51,7 @@ class SoftmaxCrossEntropyGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
......@@ -14,7 +19,7 @@ namespace dragon {
template <class Context>
class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
public:
SparseSoftmaxCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
SparseSoftmaxCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {
......@@ -25,13 +30,14 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
}
OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ input(0).name() }),
vector<string>({ Input(0).name() }),
vector<string>({ "/mnt/" + anchor() + "/softmax_prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (op_def.has_device_option())
softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
softmax_op.reset(CreateOperator(softmax_def, ws));
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -47,7 +53,7 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
template <class Context>
class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
public:
SparseSoftmaxCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
SparseSoftmaxCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {
......@@ -58,6 +64,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
}
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_FOCAL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_FOCAL_LOSS_OP_H_
......@@ -14,8 +19,8 @@ namespace dragon {
template <class Context>
class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Context> {
public:
SparseSoftmaxFocalLossOp(const OperatorDef& op_def, Workspace* ws)
: SparseSoftmaxCrossEntropyOp<Context>(op_def, ws),
SparseSoftmaxFocalLossOp(const OperatorDef& op_def, Workspace* ws)
: SparseSoftmaxCrossEntropyOp<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
alpha(OperatorBase::GetSingleArg<float>("alpha", 0.5)),
......@@ -23,7 +28,8 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {
pos_alpha = alpha * 2.0;
neg_alpha = (1 - alpha) * 2.0;
}
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -40,13 +46,14 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
template <class Context>
class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyGradientOp<Context> {
public:
SparseSoftmaxFocalLossGradientOp(const OperatorDef& op_def, Workspace* ws)
SparseSoftmaxFocalLossGradientOp(const OperatorDef& op_def, Workspace* ws)
: SparseSoftmaxCrossEntropyGradientOp<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
#define DRAGON_OPERATORS_MISC_ACCURACY_OP_H_
......@@ -25,6 +30,7 @@ class AccuracyOp final: public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
}
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_GRADIENT_GENERATE_OP_H_
#define DRAGON_OPERATORS_MISC_GRADIENT_GENERATE_OP_H_
......@@ -21,6 +26,7 @@ class GradientGenerateOp final: public Operator<Context> {
CHECK_EQ(defaults.size(), OutputSize());
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -35,9 +41,10 @@ class GradientGatherOp final : public Operator<Context> {
GradientGatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
for (int i = 0; i < InputSize(); i++)
if (input(i).name() != "ignore") indices.push_back(i);
if (Input(i).name() != "ignore") indices.push_back(i);
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -53,6 +60,7 @@ class StopGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
};
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
#define DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
......@@ -35,6 +40,7 @@ class ImageDataOp final : public Operator<Context> {
std.mutable_data<float, CPUContext>()[i] = std_values[i];
}
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
#define DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
......@@ -20,6 +25,7 @@ class InitializeOp: public Operator<Context> {
shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) {
GET_ARGUMENTS_WITH_DESC(int, dims);
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -38,6 +44,7 @@ public:
this->filler.set_type("constant");
this->filler.set_value(OperatorBase::GetSingleArg<float>("value", 0.0));
}
USE_OPERATOR_FUNCTIONS(Context);
};
template <class Context>
......@@ -49,6 +56,7 @@ public:
this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -1.0));
this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 1.0));
}
USE_OPERATOR_FUNCTIONS(Context);
};
template <class Context>
......@@ -60,6 +68,7 @@ public:
this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0));
this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0));
}
USE_OPERATOR_FUNCTIONS(Context);
};
template <class Context>
......@@ -75,6 +84,7 @@ public:
this->filler.set_low(mu - 2 * sigma);
this->filler.set_high(mu + 2 * sigma);
}
USE_OPERATOR_FUNCTIONS(Context);
};
template <class Context>
......@@ -95,6 +105,7 @@ public:
}
this->filler.set_scale(scale);
}
USE_OPERATOR_FUNCTIONS(Context);
};
template <class Context>
......@@ -115,6 +126,7 @@ public:
}
this->filler.set_scale(scale);
}
USE_OPERATOR_FUNCTIONS(Context);
};
DEFINE_ARGUMENTS_WITH_DESC(int, InitializeOp, dims);
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_PYTHON_OP_H_
#define DRAGON_OPERATORS_MISC_PYTHON_OP_H_
......@@ -11,19 +16,13 @@
#include "core/operator.h"
#ifdef WITH_PYTHON3
#define PyBytes_FromStringAndSize PyUnicode_FromStringAndSize
#endif
namespace dragon {
template <class Context>
class RunOp : public Operator<Context> {
public:
RunOp(const OperatorDef& op_def, Workspace* ws);
PyObject* String(const char* str) {
return PyBytes_FromStringAndSize(str, string(str).size());
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......@@ -37,6 +36,7 @@ class TemplateOp : public RunOp<Context> {
public:
TemplateOp(const OperatorDef& op_def, Workspace* ws)
: RunOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
};
template <class Context>
......@@ -46,10 +46,11 @@ public:
: TemplateOp<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MISC_PYTHON_OP_H_
\ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
#define DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
......@@ -51,6 +56,9 @@ class ModelMPIBase : public Operator<Context> {
string dtype;
};
#define USE_MPIMODEL_FUNCTIONS(context) \
using ModelMPIBase<context>::mpi_dtype
} // namespace dragon
#endif // WITH_MPI
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
......@@ -18,6 +23,8 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> {
public:
MPIBroadcastOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -30,6 +37,8 @@ public:
: ModelMPIBase<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
......@@ -18,6 +23,8 @@ class MPIGatherOp final : public ModelMPIBase<Context> {
public:
MPIGatherOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -30,6 +37,8 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> {
: ModelMPIBase<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
......@@ -21,6 +26,7 @@ class ArangeOp final : public Operator<Context> {
GET_ARGUMENT_WITH_DESC(int, stop, 0);
GET_ARGUMENT_WITH_DESC(int, step, 1);
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGMAX_OP_H_
......@@ -19,6 +24,7 @@ class ArgmaxOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_ARGMIN_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ARGMIN_OP_H_
......@@ -19,6 +24,7 @@ class ArgminOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_CONCAT_OP_H_
......@@ -18,6 +23,7 @@ class ConcatOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -37,6 +43,7 @@ class ConcatGradientOp : public Operator<Context> {
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
#define DRAGON_OPERATORS_NDARRAY_CROP_OP_H_
......@@ -22,6 +27,7 @@ class CropOp: public Operator<Context> {
offsets(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
void RunOnDevice() override;
......@@ -49,6 +55,7 @@ class CropGradientOp final : public Operator<Context > {
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_
#define DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_
......@@ -17,6 +22,7 @@ class ExpandDimsOp final : public Operator<Context> {
ExpandDimsOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......@@ -31,6 +37,7 @@ class ExpandDimsGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
};
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_FLATTEN_OP_H_
#define DRAGON_OPERATORS_NDARRAY_FLATTEN_OP_H_
......@@ -19,6 +24,7 @@ class FlattenOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
keep_axes(OperatorBase::GetSingleArg<int>("keep_axes", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
void SqueezeRun();
......@@ -35,6 +41,7 @@ class FlattenGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
};
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
#define DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
......@@ -17,6 +22,7 @@ class GatherOp final : public Operator<Context> {
GatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -33,6 +39,7 @@ class GatherGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_ONE_HOT_OP_H_
......@@ -19,6 +24,7 @@ class OneHotOp final : public Operator < Context > {
depth(OperatorBase::GetSingleArg<int>("depth", -1)),
on_value(OperatorBase::GetSingleArg<int>("on_value", 1)),
off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
#define DRAGON_OPERATORS_NDARRAY_PAD_OP_H_
......@@ -30,6 +35,7 @@ class PadOp final : public Operator<Context> {
}
std::sort(process_axes.begin(), process_axes.end());
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void ConstRunWithType();
......@@ -65,6 +71,7 @@ class PadGradientOp final : public Operator<Context> {
std::reverse(process_axes.begin(), process_axes.end());
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void ConstRunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
#define DRAGON_OPERATORS_NDARRAY_RANDOM_PICK_OP_H_
......@@ -18,6 +23,7 @@ class RandomPickOp : public Operator<Context> {
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
max_samples(OperatorBase::GetSingleArg<int>("max_samples", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -37,6 +43,7 @@ public:
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_REDUCE_OP_H_
......@@ -19,6 +24,7 @@ class ReduceOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void SumRunWithType();
......@@ -38,6 +44,7 @@ class ReduceGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void SumRunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_REPEAT_OP_H_
......@@ -19,6 +24,7 @@ class RepeatOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1);
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template<typename T> void RunWithType();
......@@ -36,6 +42,7 @@ class RepeatGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1);
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template<typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
......@@ -19,6 +24,7 @@ class ReshapeOp final : public Operator<Context> {
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, shape);
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......@@ -35,6 +41,7 @@ class ReshapeGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
};
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_SHAPE_OP_H_
......@@ -15,6 +20,7 @@ template <class Context>
class ShapeOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ShapeOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
};
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_SLICE_OP_H_
......@@ -18,6 +23,7 @@ class SliceOp : public Operator<Context> {
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -38,6 +44,7 @@ class SliceGradientOp final : public Operator<Context> {
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
#define DRAGON_OPERATORS_NDARRAY_STACK_OP_H_
......@@ -18,6 +23,7 @@ class StackOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -35,6 +41,7 @@ class StackGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_TILE_OP_H_
......@@ -18,6 +23,7 @@ class TileOp : public Operator<Context> {
: Operator<Context>(op_def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples);
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template<typename T> void TileRunWithType();
......@@ -36,6 +42,7 @@ class TileGradientOp : public Operator<Context> {
GET_ARGUMENTS_WITH_DESC(int, multiples);
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template<typename T> void TileRunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
#define DRAGON_OPERATORS_NDARRAY_TRANSPOSE_OP_H_
......@@ -20,6 +25,7 @@ class TransposeOp final: public Operator<Context> {
if (perms.size() > 0) reverse_dims = false;
else reverse_dims = true;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -34,8 +40,9 @@ class TransposeOp final: public Operator<Context> {
template <class Context>
class TransposeGradientOp final : public Operator<Context> {
public:
TransposeGradientOp(const OperatorDef& op_def, Workspace* ws)
TransposeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
......@@ -25,6 +30,7 @@ class BatchNormOp : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -54,6 +60,7 @@ class BatchNormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -80,6 +87,7 @@ class FusedBatchNormOp : public Operator<Context> {
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -106,6 +114,7 @@ class FusedBatchNormGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -142,6 +151,7 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
this->eps = std::max(this->eps, float(CUDNN_BN_MIN_EPSILON));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNBatchNormOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -172,6 +182,7 @@ class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context>
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
this->eps = std::max(this->eps, float(CUDNN_BN_MIN_EPSILON));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNBatchNormGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
......@@ -29,6 +34,7 @@ class BatchRenormOp : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -59,6 +65,7 @@ class BatchRenormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
......@@ -26,6 +31,7 @@ class GroupNormOp : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -56,6 +62,7 @@ class GroupNormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -83,6 +90,7 @@ class FusedGroupNormOp : public Operator<Context> {
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -110,6 +118,7 @@ class FusedGroupNormGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
......@@ -22,6 +27,7 @@ class InstanceNormOp : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -46,6 +52,7 @@ class InstanceNormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_L2_NORM_H_
#define DRAGON_OPERATORS_NORM_L2_NORM_H_
......@@ -20,6 +25,7 @@ class L2NormOp final : public Operator<Context> {
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-5))),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -41,6 +47,7 @@ class L2NormGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// -// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_
......@@ -17,6 +22,7 @@ class LSTMUnitOp : public Operator<Context> {
LSTMUnitOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
has_cont(OperatorBase::GetSingleArg<string>("cont_t", "")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -30,10 +36,11 @@ class LSTMUnitOp : public Operator<Context> {
template <class Context>
class LSTMUnitGradientOp : public Operator<Context> {
public:
LSTMUnitGradientOp(const OperatorDef& op_def, Workspace* ws)
LSTMUnitGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
this->allow_share_grads_ = false;
}
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
......@@ -17,9 +22,11 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
AdamUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
t(0),
eps(param("eps")),
beta1(param("beta1")),
beta2(param("beta2")) {}
eps(Param("eps")),
beta1(Param("beta1")),
beta2(Param("beta2")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_COLLECTIVE_UPDATE_OP_H_
......@@ -16,12 +21,13 @@ namespace dragon {
template <class Context>
class CollectiveUpdateOp : public Operator<Context> {
public:
CollectiveUpdateOp(const OperatorDef& op_def, Workspace* ws)
CollectiveUpdateOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
mode(OperatorBase::GetSingleArg<string>("mode", "UNKNOWN")) {
InitMPI();
if (mode.find("NCCL") != string::npos) InitNCCL();
}
}
USE_OPERATOR_FUNCTIONS(Context);
void InitMPI();
void InitNCCL();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
#define DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
......@@ -17,6 +22,7 @@ class MovingAverageOp final : public Operator<Context> {
MovingAverageOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
decay(OperatorBase::GetSingleArg<float>("decay", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_NESTEROV_UPDATE_OP_H_
......@@ -14,11 +19,13 @@ namespace dragon {
template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> {
public:
NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws)
NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
momentum(param("momentum")) {}
momentum(Param("momentum")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override;
void ComputeRunWithFloat() override;
protected:
float lr, momentum;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
......@@ -16,8 +21,10 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
public:
RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
eps(param("eps")),
decay(param("decay")) {}
eps(Param("eps")),
decay(Param("decay")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
......@@ -16,7 +21,9 @@ class SGDUpdateOp final : public UpdateOpBase<Context> {
public:
SGDUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
momentum(param("momentum")) {}
momentum(Param("momentum")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
#define DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
......@@ -19,8 +24,9 @@ class UpdateOpBase : public Operator<Context> {
lr_mult(OperatorBase::GetSingleArg<float>("lr_mult", 1.0)),
decay_mult(OperatorBase::GetSingleArg<float>("decay_mult", 1.0)),
domain(OperatorBase::GetSingleArg<string>("domain", "_")) {}
USE_OPERATOR_FUNCTIONS(Context);
float param(const string& name) const;
float Param(const string& name) const;
void RunOnDevice() override;
template <typename T> void PreprocessRunWithType();
......@@ -33,6 +39,9 @@ class UpdateOpBase : public Operator<Context> {
string domain;
};
#define USE_UPDATER_FUNCTIONS(context) \
using UpdateOpBase<context>::Param
} // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
\ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_BILINEAR_RESIZE_OP_H_
#define DRAGON_OPERATORS_VISION_BILINEAR_RESIZE_OP_H_
......@@ -25,6 +30,8 @@ class BilinearResizeOp : public Operator<Context> {
else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -41,6 +48,7 @@ class BilinearResizeGradientOp : public Operator<Context> {
BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_CONV_OP_H_
#define DRAGON_OPERATORS_VISION_CONV_OP_H_
......@@ -19,9 +24,11 @@ class Conv2dOp : public ConvOpBase<Context> {
this->num_spatial_axes = 2;
Setup();
}
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
bool ReverseDimensions() override { return false; }
virtual bool HasBias() { return InputSize() > 2; }
bool HasBias() override { return InputSize() > 2; }
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -32,8 +39,10 @@ class Conv2dGradientOp : public Conv2dOp<Context> {
public:
Conv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
bool HasBias() override { return output(2)->name() != "ignore"; }
bool HasBias() override { return Output(2)->name() != "ignore"; }
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -70,6 +79,8 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dOp() {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc));
......@@ -124,6 +135,8 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dGradientOp() {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc));
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_
#define DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_
......@@ -29,6 +34,7 @@ class ConvOpBase : public Operator<Context> {
else LOG(FATAL) << "Unknown data format: " << data_format;
num_spatial_axes = -1; // unknown
}
USE_OPERATOR_FUNCTIONS(Context);
protected:
vector<TIndex> kernel_size, stride, pad, dilation;
......@@ -50,6 +56,7 @@ class ConvOpBase : public Operator<Context> {
void GradientReshape();
virtual void ComputeOutputShape();
virtual bool ReverseDimensions() = 0;
virtual bool HasBias() = 0;
template <typename T> void Wx(const T* x, const T* weights, T* y, bool skip_im2col = false);
template <typename T> void Pb(const T* bias, T* y);
......@@ -59,7 +66,7 @@ class ConvOpBase : public Operator<Context> {
private:
template <typename T> void Im2Col(const T* im, T* col) {
if (input(0).ndim() == 4) {
if (Input(0).ndim() == 4) {
kernel::Im2Col2d<T, Context>(conv_in_channels,
input_shape[0], input_shape[1],
output_shape[0], output_shape[1],
......@@ -73,7 +80,7 @@ class ConvOpBase : public Operator<Context> {
} else LOG(FATAL) << "ConvNd has not been implemented yet";
}
template <typename T> void Col2Im(const T* col, T* im) {
if (input(0).ndim() == 4) {
if (Input(0).ndim() == 4) {
kernel::Col2Im2d<T, Context>(conv_in_channels,
input_shape[0], input_shape[1],
output_shape[0], output_shape[1],
......@@ -90,6 +97,19 @@ class ConvOpBase : public Operator<Context> {
DEFINE_ARGUMENTS_WITH_DESC(int, ConvOpBase, output_dims);
#define USE_CONVOLUTION_FUNCTIONS(context) \
using ConvOpBase<context>::Setup; \
using ConvOpBase<context>::Reshape; \
using ConvOpBase<context>::GradientReshape; \
using ConvOpBase<context>::ComputeOutputShape; \
using ConvOpBase<Context>::ReverseDimensions; \
using ConvOpBase<Context>::HasBias; \
using ConvOpBase<context>::Wx; \
using ConvOpBase<context>::Pb; \
using ConvOpBase<context>::Dx; \
using ConvOpBase<context>::Dw; \
using ConvOpBase<context>::Db
} // namespace dragon
#endif // DRAGON_OPERATORS_VISION_CONV_OP_BASE_H_
\ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_
#define DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_
......@@ -19,9 +24,11 @@ class Conv2dTransposeOp: public ConvOpBase<Context> {
this->num_spatial_axes = 2;
Setup();
}
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
bool ReverseDimensions() override { return true; }
virtual bool HasBias() { return InputSize() > 2; }
bool HasBias() override { return InputSize() > 2; }
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -36,8 +43,10 @@ class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> {
public:
Conv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
bool HasBias() override { return output(2)->name() != "ignore"; }
bool HasBias() override { return Output(2)->name() != "ignore"; }
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -73,6 +82,8 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dTransposeOp() {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc));
......@@ -127,6 +138,8 @@ public:
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dTransposeGradientOp() {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc));
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_
#define DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_
......@@ -14,8 +19,9 @@ namespace dragon {
template <class Context>
class DenseConcatOp final : public ConcatOp<Context> {
public:
DenseConcatOp(const OperatorDef& op_def, Workspace* ws)
: ConcatOp<Context>(op_def, ws) {}
DenseConcatOp(const OperatorDef& op_def, Workspace* ws)
: ConcatOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
};
template <class Context>
......@@ -24,12 +30,13 @@ class DenseConcatGradientOp : public ConcatGradientOp<Context> {
DenseConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: ConcatGradientOp<Context>(op_def, ws),
growth_rate(OperatorBase::GetSingleArg<int>("growth_rate", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ElimateCorruption() override;
template <typename T> void RestoreX1();
protected:
TIndex growth_rate;
TIndex growth_rate;
};
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_LRN_OP_H_
#define DRAGON_OPERATORS_VISION_LRN_OP_H_
......@@ -24,6 +29,7 @@ class LRNOp : public Operator<Context> {
k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -54,6 +60,7 @@ class LRNGradientOp : public Operator<Context> {
k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -90,6 +97,7 @@ class CuDNNLRNOp : public LRNOp<Context> {
this->beta,
this->k));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNLRNOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -118,6 +126,7 @@ class CuDNNLRNGradientOp : public LRNGradientOp<Context > {
this->beta,
this->k));
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNLRNGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_NN_RESIZE_OP_H_
#define DRAGON_OPERATORS_VISION_NN_RESIZE_OP_H_
......@@ -25,6 +30,7 @@ class NNResizeOp : public Operator<Context> {
else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -42,6 +48,7 @@ class NNResizeGradientOp : public Operator<Context> {
NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_POOLING_OP_H_
#define DRAGON_OPERATORS_VISION_POOLING_OP_H_
......@@ -35,6 +40,7 @@ class Pooling2dOp: public Operator <Context> {
}
}
}
USE_OPERATOR_FUNCTIONS(Context);
void Reshape();
void RunOnDevice() override;
......@@ -73,6 +79,7 @@ class Pooling2dGradientOp: public Operator<Context> {
}
}
}
USE_OPERATOR_FUNCTIONS(Context);
void Reshape();
void RunOnDevice() override;
......@@ -107,6 +114,7 @@ class CuDNNPooling2dOp final : public Pooling2dOp<Context> {
pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
} else LOG(FATAL) << "Unsupported pooling mode: " << this->mode;
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNPooling2dOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -156,6 +164,7 @@ class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> {
this->stride[0], this->stride[1]));
#endif
}
USE_OPERATOR_FUNCTIONS(Context);
~CuDNNPooling2dGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_ROI_ALIGN_OP_H_
#define DRAGON_OPERATORS_VISION_ROI_ALIGN_OP_H_
......@@ -23,6 +28,7 @@ class ROIAlignOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0";
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -44,6 +50,7 @@ class ROIAlignGradientOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0";
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_ROI_POOLING_OP_H_
#define DRAGON_OPERATORS_VISION_ROI_POOLING_OP_H_
......@@ -22,6 +27,7 @@ class ROIPoolingOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0";
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -40,6 +46,7 @@ class ROIPoolingGradientOp final : public Operator<Context> {
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!