Commit 081258b9 by Ting PAN

Preliminary RNN & LSTM & GRU Support

1 parent fe161546
Showing with 943 additions and 529 deletions
@@ -27,7 +27,8 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel && \
     six \
     Pillow \
    matplotlib \
-    pyyaml
+    pyyaml \
+    cython
 RUN wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cpu-openblas/3rdparty.zip && \
     unzip 3rdparty.zip && rm 3rdparty.zip && cd 3rdparty && bash ./setup_mpi.sh && rm -f *.gz && \
......
 FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
-RUN apt-get update && apt-get install -y --no-install-recommends \
+RUN rm /etc/apt/sources.list.d/cuda.list && rm /etc/apt/sources.list.d/nvidia-ml.list && \
+    apt-get update && apt-get install -y --no-install-recommends \
     build-essential \
     cmake \
     git \
@@ -29,7 +30,8 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel && \
     six \
     Pillow \
     matplotlib \
-    pyyaml
+    pyyaml \
+    cython
 RUN wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cuda9.0-cudnn7/3rdparty.zip && \
     unzip 3rdparty.zip && rm 3rdparty.zip && cd 3rdparty && bash ./setup_mpi.sh && rm -f *.gz && \
......
@@ -171,7 +171,7 @@ endif()
 # ---[ Flags
 set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}")
 if(WIN32)
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /O2 /Oi /GL /Ot /Gy")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /O2")
     if (WITH_OMP)
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp")
     endif()
......
@@ -39,7 +39,6 @@ using std::vector;
 using std::pair;
 using std::set;
 using std::map;
-using std::mutex;
 using std::unique_ptr;
 using std::shared_ptr;
@@ -49,7 +48,7 @@ using Map = std::unordered_map<Key, Value>;
 template <typename Value>
 using Set = std::unordered_set<Value>;
-#define DRAGON_VERSION 2204
+#define DRAGON_VERSION 2205
 #define CONCATENATE_IMPL(s1, s2) s1##s2
 #define CONCATENATE(s1, s2) CONCATENATE_IMPL(s1,s2)
......
@@ -84,6 +84,9 @@ static inline std::mt19937* rand_generator() {
     return CPUContext::cpu_object_.rand_generator.get();
 }
+
+#define CPU_FP16_NOT_SUPPORTED \
+    LOG(FATAL) << "FP16 is unsupported for CPUContext.";
 } // namespace dragon
 #endif // DRAGON_CORE_CONTEXT_H_
\ No newline at end of file
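The new CPU_FP16_NOT_SUPPORTED macro gives CPU code paths a uniform way to reject half-precision inputs. A minimal sketch of how a kernel might use it; the kernel name and signature are hypothetical, not part of this commit:

    // Hypothetical float16 specialization: there is no CPU math path
    // for FP16, so the kernel aborts with the shared fatal message.
    template <> void Relu<float16, CPUContext>(
        const int count, const float16* x, float16* y) {
        CPU_FP16_NOT_SUPPORTED;
    }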
@@ -134,7 +134,8 @@ class CUDAContext {
         DeviceGuard gurad(gpu_id_);
         CUBLAS_CHECK(cublasCreate_v2(&handle));
 #if CUDA_VERSION >= 9000
-        if (TENSOR_CORE_AVAILABLE()) cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
+        if (TENSOR_CORE_AVAILABLE())
+            cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
 #endif
         return handle;
     }
@@ -165,6 +166,8 @@ class CUDAContext {
     }
 #endif
+    static std::mutex& mutex() { static std::mutex m; return m; }
+
     static CUDAObject cuda_object_;

 private:
......
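With the per-tensor locks removed from the Workspace (see below), CUDAContext now exposes a single process-wide mutex. A sketch of guarding lazy handle creation with it, assuming only the standard std::lock_guard idiom on top of the accessor shown above:

    // Serialize lazy cuBLAS handle creation across threads.
    {
        std::lock_guard<std::mutex> lock(CUDAContext::mutex());
        if (!handle) CUBLAS_CHECK(cublasCreate_v2(&handle));
    }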
@@ -154,12 +154,13 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
     using OperatorBase::InputSize; \
     using OperatorBase::OutputSize; \
     using OperatorBase::DebugString; \
-    using OperatorBase::DTypeHelper \
+    using OperatorBase::DTypeHelper; \
+    using OperatorBase::SwitchToPhase

-#define USE_OPERATOR_FUNCTIONS(context) \
+#define USE_OPERATOR_FUNCTIONS \
     USE_OPERATOR_BASE_FUNCTIONS; \
-    using Operator<context>::ctx; \
-    using Operator<context>::AllowRun
+    using Operator<Context>::ctx; \
+    using Operator<Context>::AllowRun

 DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*);
 DECLARE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
@@ -189,11 +190,19 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
         ptr_tensor = ws()->CreateTensor("/share/multiplier"); \
         if (size > ptr_tensor->count()) { \
             ptr_tensor->Reshape(vector<TIndex>(1, size)); \
-            math::Set<T, Context>(size, dragon_cast<T, float>(1.0f), \
+            math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \
                 ptr_tensor->template mutable_data<T, Context>()); \
         } \
     }

+#define DECLARE_MULTIPLIER(name, size) \
+    const T* name; \
+    { \
+        Tensor* _auto_multiplier_; \
+        INIT_MULTIPLIER(_auto_multiplier_, size); \
+        name = _auto_multiplier_->template data<T, Context>(); \
+    }
+
 #define DECLARE_ARGUMENT_WITH_DESC(type, argument) \
     type argument##_value; \
     string argument##_desc; \
......
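DECLARE_MULTIPLIER wraps INIT_MULTIPLIER so an operator can obtain a read-only vector of ones in one line. A sketch of a bias broadcast that could use it; the Gemm signature follows the caffe-style math::Gemm used in this codebase, and the dimension names are assumptions for illustration:

    template <typename T> void AddBias() {
        // `multiplier` ends up pointing at >= outer_dim ones from the
        // shared "/share/multiplier" tensor.
        DECLARE_MULTIPLIER(multiplier, outer_dim);
        // Rank-1 update broadcasts the bias row across outer_dim rows.
        math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans,
            outer_dim, bias_dim, 1,
            1.0, multiplier, Input(1).template data<T, Context>(),
            1.0, Output(0)->template mutable_data<T, Context>());
    }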
@@ -23,14 +23,15 @@ typedef size_t TSize;
 class Tensor {
  public:
     Tensor() {}
+    Tensor(const vector<TIndex>& dims) { Reshape(dims); }
     Tensor(const string& name) : name_(name) {}

     void Reshape(const vector<TIndex>& dims) {
         dims_ = dims;
         TIndex new_size = 1;
         for (auto d : dims_) {
-            CHECK_GT(d, 0);
-            new_size *= d;
+            CHECK_GE(d, 0);
+            if (d > 0) new_size *= d;
         }
         if (own_mem_) {
             if (size_ != new_size &&
......
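The relaxed check (CHECK_GE instead of CHECK_GT) makes zero-sized axes legal, which the preliminary RNN support needs for placeholder shapes. Note from the loop above that a zero axis is skipped when the element count is computed, so the count stays the product of the non-zero dims. A short illustration:

    Tensor t(vector<TIndex>({8, 0, 128}));   // new dims constructor
    // dims() == {8, 0, 128}, but the zero axis is ignored in the size:
    // new_size = 8 * 128, as if that axis were 1.
    t.Reshape(vector<TIndex>({8, 4, 128}));  // later gets the real shape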
@@ -18,26 +18,27 @@
 namespace dragon {

-#define WORKSPACE_COMMON_BUFFER_SIZE 2
 #define WORKSPACE_MAX_CORRUPTED_SIZE 2

 class Workspace {
  public:
     typedef Map<string, Workspace*> WorkspaceMap;
     typedef Map<string, unique_ptr<Tensor> > TensorMap;
-    typedef Map<string, stack<string> > BufferMap;
-    typedef Map<string, unique_ptr<mutex> > LockMap;
     typedef Map<string, unique_ptr<OperatorBase> > OperatorMap;
     typedef Map<string, unique_ptr<GraphBase> > GraphMap;
     typedef Map<string, TensorFiller> FillerMap;
     typedef Map<string, string> RenameMap;

-    Workspace(const string& name) : name_(name) { Init(); }
+    Workspace(const string& name) : name_(name) { InitWorkspace(); }
     ~Workspace();

-    void Init() {
+    inline const string& name() { return name_; }
+
+    /******************** Workspace ********************/
+
+    inline void InitWorkspace() {
         CreateTensor("ignore");
-        CreateBuffer("Common", WORKSPACE_COMMON_BUFFER_SIZE);
         Tensor* head = CreateTensor("/opt/mirror_stage/head");
         head->Reshape(vector<TIndex>(1, WORKSPACE_MAX_CORRUPTED_SIZE));
         Tensor* recompute_flag = CreateTensor("/opt/mirror_stage/recompute_flag");
@@ -50,10 +51,6 @@ class Workspace {
         }
     }

-    inline const string& name() { return name_; }
-
-    /******************** Workspace ********************/
-
     inline Workspace* MoveWorkspace(Workspace* ws) {
         CHECK(ws) << "The given Workspace is invalid.";
         if (workspace_map_.count(ws->name()))
@@ -62,11 +59,9 @@ class Workspace {
     }

     inline void ClearWorkspace() {
-        // clear tensors & buffers
+        // clear tensors & buffers & re-initialization
         for (auto& kv : tensor_map_) kv.second->Reset();
-        ResetBuffers("Common");
-        // Re-Initialization
-        Init();
+        InitWorkspace();
     }

     /******************** Tensor ********************/
@@ -112,20 +107,6 @@ class Workspace {
         return tensor;
     }

-    inline void LockTensor(const string& name) {
-        string query = GetTensorName(name);
-        if (!lock_map_.count(query))
-            lock_map_[query] = unique_ptr<mutex>(new mutex);
-        lock_map_[query]->lock();
-    }
-
-    inline void UnlockTensor(const string& name) {
-        string query = GetTensorName(name);
-        if (!lock_map_.count(query))
-            lock_map_[query] = unique_ptr<mutex>(new mutex);
-        lock_map_[query]->unlock();
-    }
-
     inline void ResetTensor(const string& name) {
         Tensor* tensor = TryGetTensor(name, false);
         CHECK(tensor) << "\nTensor(" << name << ") does not "
@@ -179,49 +160,32 @@ class Workspace {
         return nullptr;
     }

-    /******************** Buffer ********************/
+    /******************** Cache ********************/

-    inline void CreateBuffer(string category, int num) {
-        if (!buffer_map_.count(category))
-            buffer_map_[category] = stack<string>();
-        for (int i = 1; i <= num; i++) {
-            string name = "/share/buffer/" + category + "_" + dragon_cast<string, int>(i);
-            buffer_map_[category].push(name);
-            CreateTensor(name);
-        }
-    }
-
-    inline Tensor* GetBuffer(string category = "Common") {
-        if (!buffer_map_[category].empty()) {
-            string name = buffer_map_[category].top();
-            buffer_map_[category].pop();
-            return tensor_map_[name].get();
-        }
-        LOG(FATAL) << "Buffers of [" << category << "] "
-                   << "are not enough, add more if necessary.";
-        return nullptr;
-    }
-
-    inline void ReleaseBuffer(Tensor* tensor,
-                              string category = "Common",
-                              bool enforce = false) {
-        static Map<string, int> limits = {
-            { "Common", WORKSPACE_COMMON_BUFFER_SIZE }};
-        if (buffer_map_[category].size() >= limits[category] || enforce) {
-            ResetTensor(tensor->name());
-            if (buffer_map_[category].empty())
-                buffer_map_[category].push(tensor->name());
-        } else {
-            buffer_map_[category].push(tensor->name());
-        }
-    }
-
-    inline void ResetBuffers(string category) {
-        while (!buffer_map_[category].empty()) {
-            string name = buffer_map_[category].top();
-            buffer_map_[category].pop();
-            tensor_map_[name]->Reset();
-        }
-    }
+    template <class Context>
+    inline vector<void*> caches(const vector<size_t>& segments) {
+        TIndex total_size = 0;
+        for (auto& segment : segments) total_size += (TIndex)segment;
+        Tensor* cacheT = CreateTensor("/share/cache");
+        cacheT->Reshape(vector<TIndex>(1, total_size));
+        vector<void*> caches(segments.size());
+        caches[0] = cacheT->template mutable_data<uint8_t, Context>();
+        for (int i = 1; i < segments.size(); i++)
+            caches[i] = (uint8_t*)caches[i - 1] + segments[i - 1];
+        return caches;
+    }
+
+    template <typename T, class Context>
+    inline vector<T*> caches(const vector<TIndex>& segments) {
+        TIndex total_count = 0;
+        for (auto& segment : segments) total_count += segment;
+        Tensor* cacheT = CreateTensor("/share/cache");
+        cacheT->Reshape(vector<TIndex>(1, total_count));
+        vector<T*> caches(segments.size());
+        caches[0] = cacheT->template mutable_data<T, Context>();
+        for (int i = 1; i < segments.size(); i++)
+            caches[i] = caches[i - 1] + segments[i - 1];
+        return caches;
+    }

     /******************** Operator ********************/
@@ -297,8 +261,6 @@ class Workspace {
     string name_;
     WorkspaceMap workspace_map_;
     TensorMap tensor_map_;
-    BufferMap buffer_map_;
-    LockMap lock_map_;
     OperatorMap op_map_;
     GraphMap graph_map_;
     FillerMap filler_map_;
......
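The caches() templates replace the old named-buffer stacks: every request reshapes the single "/share/cache" tensor and carves it into consecutive segments. A sketch of how an operator might grab two typed scratch areas; the segment sizes are hypothetical, and note the segments are only valid until the next caches() call reshapes the shared tensor:

    // Two float segments backed by one shared tensor.
    auto scratch = ws()->template caches<float, Context>(
        vector<TIndex>({outer_dim, inner_dim}));
    float* sum_buf = scratch[0];    // outer_dim elements
    float* scale_buf = scratch[1];  // inner_dim elements, adjacent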
@@ -24,8 +24,9 @@ class DropoutOp final : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
         GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
+        SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -42,8 +43,9 @@ class DropoutGradientOp final : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
         GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
+        SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -57,6 +59,79 @@ class DropoutGradientOp final : public Operator<Context> {
 DEFINE_ARGUMENT_WITH_DESC(float, DropoutOp, prob);
 DEFINE_ARGUMENT_WITH_DESC(float, DropoutGradientOp, prob);

+#ifdef WITH_CUDNN
+
+#if CUDNN_VERSION_MIN(7, 0, 0)
+
+template <class Context>
+class CuDNNDropoutOp final : public Operator<Context> {
+ public:
+    CuDNNDropoutOp(const OperatorDef& op_def, Workspace* ws)
+        : Operator<Context>(op_def, ws), states_initialized(false),
+          use_scale(OperatorBase::GetSingleArg<bool>("scale", true)),
+          random_seed(op_def.device_option().random_seed()) {
+        GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
+        SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
+        CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
+        CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
+    }
+    USE_OPERATOR_FUNCTIONS;
+
+    ~CuDNNDropoutOp() {
+        CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
+        CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc));
+    }
+
+    void RunOnDevice() override;
+    template <typename T> void RunWithType();
+
+ protected:
+    DECLARE_ARGUMENT_WITH_DESC(float, prob);
+    bool use_scale, states_initialized;
+    cudnnTensorDescriptor_t input_desc;
+    cudnnDropoutDescriptor_t dropout_desc;
+    size_t states_size, reserve_space_size;
+    unsigned long long random_seed;
+};
+
+template <class Context>
+class CuDNNDropoutGradientOp final : public Operator<Context> {
+ public:
+    CuDNNDropoutGradientOp(const OperatorDef& op_def, Workspace* ws)
+        : Operator<Context>(op_def, ws), states_initialized(false),
+          use_scale(OperatorBase::GetSingleArg<bool>("scale", true)),
+          random_seed(op_def.device_option().random_seed()) {
+        GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
+        SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
+        CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
+        CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
+    }
+    USE_OPERATOR_FUNCTIONS;
+
+    ~CuDNNDropoutGradientOp() {
+        CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
+        CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc));
+    }
+
+    void RunOnDevice() override;
+    template <typename T> void RunWithType();
+
+ protected:
+    DECLARE_ARGUMENT_WITH_DESC(float, prob);
+    bool use_scale, states_initialized;
+    cudnnTensorDescriptor_t input_desc;
+    cudnnDropoutDescriptor_t dropout_desc;
+    size_t states_size, reserve_space_size;
+    unsigned long long random_seed;
+};
+
+DEFINE_ARGUMENT_WITH_DESC(float, CuDNNDropoutOp, prob);
+DEFINE_ARGUMENT_WITH_DESC(float, CuDNNDropoutGradientOp, prob);
+
+#endif
+
+#endif // WITH_CUDNN
+
 } // namespace dragon

 #endif // DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
\ No newline at end of file
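The states_initialized flag suggests the dropout RNG states are set up lazily on the first run. A sketch of that first-run path using the standard cuDNN 7 dropout API; the state-tensor name, the prob() accessor generated by the argument macros, and cudnn_handle() (however the op obtains its handle) are assumptions:

    if (!states_initialized) {
        states_initialized = true;
        CUDNN_CHECK(cudnnDropoutGetStatesSize(
            cudnn_handle(), &states_size));
        Tensor* states = ws()->CreateTensor("/share/cudnn/dropout/states");
        states->Reshape(vector<TIndex>(1, states_size));
        CUDNN_CHECK(cudnnSetDropoutDescriptor(dropout_desc, cudnn_handle(),
            prob(), states->template mutable_data<uint8_t, Context>(),
            states_size, random_seed));
    }
    CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(
        input_desc, &reserve_space_size));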
@@ -22,7 +22,7 @@ class EluOp : public Operator<Context> {
     EluOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -37,7 +37,7 @@ class EluGradientOp : public Operator<Context> {
     EluGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -61,7 +61,7 @@ public:
         CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
             CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     ~CuDNNEluOp() {
         CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
@@ -88,7 +88,7 @@ class CuDNNEluGradientOp final : public EluGradientOp<Context> {
         CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
             CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     ~CuDNNEluGradientOp() {
         CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
@@ -23,15 +23,14 @@ class PReluOp : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
           data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

  protected:
-    bool channel_shared;
+    TIndex channel_shared, channels, dim;
     string data_format;
-    TIndex channels, dim;
 };

 template <class Context>
@@ -41,16 +40,14 @@ class PReluGradientOp : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
           data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

  protected:
-    bool channel_shared;
+    TIndex channel_shared, channels, dim;
     string data_format;
-    TIndex channels, dim;
-    Tensor* bcast_dw, *multiplier;
 };

 } // namespace dragon
......
@@ -22,7 +22,7 @@ class ReluOp : public Operator<Context> {
     ReluOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -37,7 +37,7 @@ class ReluGradientOp : public Operator<Context> {
     ReluGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -59,7 +59,7 @@ public:
         CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
             CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     ~CuDNNReluOp() {
         CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
@@ -86,7 +86,7 @@ class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
         CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
             CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     ~CuDNNReluGradientOp() {
         CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
@@ -20,7 +20,7 @@ template <class Context>
 class SEluOp : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(SEluOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -30,7 +30,7 @@ template <class Context>
 class SEluGradientOp : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(SEluGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
......
@@ -20,7 +20,7 @@ template <class Context>
 class SigmoidOp : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(SigmoidOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -30,7 +30,7 @@ template <class Context>
 class SigmoidGradientOp : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(SigmoidGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -49,7 +49,7 @@ public:
         CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
             CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     ~CuDNNSigmoidOp() {
         CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
@@ -76,7 +76,7 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
         CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
             CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     ~CuDNNSigmoidGradientOp() {
         CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
@@ -22,15 +22,13 @@ class SoftmaxOp final : public Operator<Context> {
     SoftmaxOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

  protected:
-    int axis;
-    TIndex outer_dim, inner_dim;
-    Tensor* sum_multiplier, *scale;
+    TIndex axis, outer_dim, inner_dim;
 };

 template <class Context>
@@ -39,15 +37,13 @@ class SoftmaxGradientOp final : public Operator<Context> {
     SoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

  protected:
-    int axis;
-    TIndex outer_dim, inner_dim;
-    Tensor* sum_multiplier, *scale;
+    TIndex axis, outer_dim, inner_dim;
 };

 #ifdef WITH_CUDNN
@@ -63,7 +59,7 @@ class CuDNNSoftmaxOp final : public Operator<Context> {
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     ~CuDNNSoftmaxOp() {
         CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
@@ -88,7 +84,7 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> {
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
         CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     ~CuDNNSoftmaxGradientOp() {
         CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
@@ -20,7 +20,7 @@ template <class Context>
 class TanhOp : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(TanhOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -30,7 +30,7 @@ template <class Context>
 class TanhGradientOp : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(TanhGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -49,7 +49,7 @@ public:
         CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
             CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     ~CuDNNTanhOp() {
         CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
@@ -76,7 +76,7 @@ class CuDNNTanhGradientOp final : public TanhGradientOp<Context> {
         CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
             CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     ~CuDNNTanhGradientOp() {
         CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
@@ -20,56 +20,44 @@ template <class Context>
 class AddOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(AddOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 template <class Context>
 class AddGradientOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(AddGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 template <class Context>
 class RAddOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(RAddOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 template <class Context>
 class RAddGradientOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(RAddGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 } // namespace dragon
......
@@ -23,7 +23,7 @@ class AffineOp : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 1)),
           num_axes(OperatorBase::GetSingleArg<int>("num_axes", 1)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -31,7 +31,6 @@ class AffineOp : public Operator<Context> {
  protected:
     TIndex axis, start_axis, num_axes;
     TIndex outer_dim, scale_dim, inner_dim;
-    Tensor* bias_multiplier;
 };

 template <class Context>
@@ -41,7 +40,7 @@ class AffineGradientOp final : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 1)),
           num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void BiasRunWithType();
@@ -51,10 +50,109 @@ class AffineGradientOp final : public Operator<Context> {
  protected:
     TIndex axis, start_axis, num_axes;
     TIndex outer_dim, inner_dim, scale_dim, sum_dim, dim;
-    Tensor* bias_multiplier, *sum_multiplier;
     Tensor sum_result;
 };

+#ifdef WITH_CUDNN
+
+#include "utils/cudnn_device.h"
+
+template <class Context>
+class CuDNNAffineOpBase : public Operator<Context> {
+ public:
+    CuDNNAffineOpBase(const OperatorDef& op_def, Workspace* ws)
+        : Operator<Context>(op_def, ws),
+          axis(OperatorBase::GetSingleArg<int>("axis", 1)),
+          num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {
+        CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
+        CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc));
+        CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&mul_desc));
+        CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&add_desc));
+        CUDNN_CHECK(cudnnCreateReduceTensorDescriptor(&reduce_desc));
+    }
+    USE_OPERATOR_FUNCTIONS;
+
+    virtual ~CuDNNAffineOpBase() {
+        CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
+        CUDNN_CHECK(cudnnDestroyTensorDescriptor(param_desc));
+        CUDNN_CHECK(cudnnDestroyOpTensorDescriptor(mul_desc));
+        CUDNN_CHECK(cudnnDestroyReduceTensorDescriptor(reduce_desc));
+    }
+
+    template <typename T>
+    void ResetDesc() {
+        // determine the range of affine
+        start_axis = axis;
+        if (start_axis < 0) start_axis += (int)Input(0).ndim();
+        if (num_axes == -1) num_axes = (int)Input(0).ndim() - start_axis;
+        else if (num_axes == 0) num_axes = 1;
+        end_axis = start_axis + num_axes;
+        CHECK_LT(start_axis, (int)Input(0).ndim());
+        CHECK_LE(start_axis + num_axes, (int)Input(0).ndim());
+        // determine the input desc
+        vector<TIndex> input_dims = Input(0).dims();
+        // cuDNN requires the number of dimensions in the range [4, 5]
+        if (input_dims.size() < 4) input_dims.resize(4, 1);
+        else if (input_dims.size() > 5)
+            LOG(FATAL) << "CuDNN Affine supports dimensions up to 5 only.";
+        cudnnSetTensorDesc<T>(&input_desc, input_dims);
+        // determine the scale desc
+        vector<TIndex> param_dims(input_dims.size(), 1);
+        for (int i = start_axis; i < end_axis; i++) param_dims[i] = input_dims[i];
+        cudnnSetTensorDesc<T>(&param_desc, param_dims);
+    }
+
+    TIndex axis, start_axis, end_axis, num_axes;
+    cudnnTensorDescriptor_t input_desc, param_desc;
+    cudnnOpTensorDescriptor_t mul_desc, add_desc;
+    cudnnReduceTensorDescriptor_t reduce_desc;
+};
+
+#define USE_CUDNN_AFFINE_FUCNTIONS \
+    USE_OPERATOR_FUNCTIONS; \
+    using CuDNNAffineOpBase<Context>::start_axis; \
+    using CuDNNAffineOpBase<Context>::num_axes; \
+    using CuDNNAffineOpBase<Context>::input_desc; \
+    using CuDNNAffineOpBase<Context>::param_desc; \
+    using CuDNNAffineOpBase<Context>::mul_desc; \
+    using CuDNNAffineOpBase<Context>::add_desc; \
+    using CuDNNAffineOpBase<Context>::reduce_desc
+
+template <class Context>
+class CuDNNAffineOp : public CuDNNAffineOpBase<Context> {
+ public:
+    CuDNNAffineOp(const OperatorDef& op_def, Workspace* ws)
+        : CuDNNAffineOpBase<Context>(op_def, ws) {}
+
+    void RunOnDevice() override;
+    template <typename T> void RunWithType();
+
+ protected:
+    USE_CUDNN_AFFINE_FUCNTIONS;
+};
+
+template <class Context>
+class CuDNNAffineGradientOp : public CuDNNAffineOpBase<Context> {
+ public:
+    CuDNNAffineGradientOp(const OperatorDef& op_def, Workspace* ws)
+        : CuDNNAffineOpBase<Context>(op_def, ws) {}
+
+    void RunOnDevice() override;
+    template <typename T> void ComputeScaleGradient(T* dYxX, T* dA);
+    template <typename T> void ComputeScaleGradient_v2(T* dYxX, T* dA);
+    template <typename T> void ComputeBiasGradient(const T* dY, T* dB);
+    template <typename T> void ComputeBiasGradient_v2(const T* dY, T* dB);
+    template <typename T> void RunWithType();
+
+ protected:
+    USE_CUDNN_AFFINE_FUCNTIONS;
+    TIndex outer_dim, inner_dim, scale_dim, sum_dim, dim;
+    Tensor sum_result;
+};
+
+#endif // WITH_CUDNN
+
 } // namespace dragon

 #endif // DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
\ No newline at end of file
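A worked example of ResetDesc: for an NCHW input with axis=1 and num_axes=1, every axis outside [start_axis, end_axis) collapses to 1 in the parameter descriptor, so cudnnOpTensor can broadcast the multiply. The CUDNNType constant wrappers are assumed from utils/cudnn_device.h, mul_desc is assumed to have been set to CUDNN_OP_TENSOR_MUL, and cudnn_handle() stands for however the op obtains its handle:

    // input_dims = {N, C, H, W}  ->  input_desc: N x C x H x W
    // param_dims = {1, C, 1, 1}  ->  param_desc broadcasts over N, H, W
    // Forward pass as one broadcasted multiply: Y = X * A
    CUDNN_CHECK(cudnnOpTensor(cudnn_handle(), mul_desc,
        CUDNNType<T>::one, input_desc, x,        // alpha1 * X
        CUDNNType<T>::one, param_desc, alpha,    // alpha2 * A
        CUDNNType<T>::zero, input_desc, y));     // beta * Y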
@@ -24,27 +24,23 @@ class ClipOp final : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           low(OperatorBase::GetSingleArg<float>("low", -FLT_MAX)),
           high(OperatorBase::GetSingleArg<float>("high", FLT_MAX)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

  protected:
     float low, high;
-    Tensor* mask;
 };

 template <class Context>
 class ClipGradientOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(ClipGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
-
- protected:
-    Tensor* mask;
 };

 } // namespace dragon
......
@@ -20,56 +20,44 @@ template <class Context>
 class DivOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(DivOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 template <class Context>
 class DivGradientOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(DivGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 template <class Context>
 class RDivOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(RDivOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 template <class Context>
 class RDivGradientOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(RDivGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 } // namespace dragon
......
@@ -21,9 +21,9 @@ class DotOp final : public Operator<Context> {
  public:
     DotOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
-          transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
-          transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+          TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
+          TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void DotRunWithType();
@@ -31,8 +31,7 @@ class DotOp final : public Operator<Context> {
     template <typename T> void GemvRunWithType();

  protected:
-    bool transA, transB;
-    TIndex M, K1, K2, N1, N2;
+    TIndex TransA, TransB, M, K1, K2, N1, N2;
 };

 template <class Context>
@@ -40,9 +39,9 @@ class DotGradientOp final : public Operator<Context> {
  public:
     DotGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
-          transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
-          transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+          TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
+          TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void DotRunWithType();
@@ -50,8 +49,7 @@ class DotGradientOp final : public Operator<Context> {
     template <typename T> void GemvRunWithType();

  protected:
-    bool transA, transB;
-    TIndex M, K1, K2, N1, N2;
+    TIndex TransA, TransB, M, K1, K2, N1, N2;
 };

 } // namespace dragon
......
@@ -29,7 +29,7 @@ class EltwiseOp final : public Operator<Context> {
                << "but provided " << coeffs.size() << " coeffs.";
         } else coeffs.resize(InputSize(), float(1));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void SumRunWithType();
@@ -53,7 +53,7 @@ class EltwiseGradientOp final : public Operator<Context> {
                << "but provided " << coeffs.size() << " coeffs.";
         } else coeffs.resize(InputSize(), float(1));
     }
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void SumRunWithType();
......
@@ -20,7 +20,7 @@ template <class Context>
 class ExpOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(ExpOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -30,7 +30,7 @@ template <class Context>
 class ExpGradientOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(ExpGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
......
@@ -22,7 +22,7 @@ class GramMatrixOp final : public Operator<Context> {
     GramMatrixOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -38,7 +38,7 @@ class GramMatrixGradientOp final : public Operator<Context> {
     GramMatrixGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
......
@@ -23,17 +23,15 @@ class InnerProductOp: public Operator<Context> {
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 1)),
           num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
-          transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+          TransW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice();
     template <typename T> void TransRunWithType();
     template <typename T> void NoTransRunWithType();

  protected:
-    TIndex axis, num_output, M, K;
-    bool transW;
-    Tensor* bias_multiplier;
+    TIndex axis, num_output, TransW, M, K;
 };

 template <class Context>
@@ -43,16 +41,14 @@ class InnerProductGradientOp final : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 1)),
           num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
-          transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+          TransW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

  protected:
-    TIndex axis, num_output, M, K;
-    bool transW;
-    Tensor* bias_multiplier;
+    TIndex axis, num_output, TransW, M, K;
 };

 } // namespace dragon
......
@@ -20,7 +20,7 @@ template <class Context>
 class LogOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(LogOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -30,7 +30,7 @@ template <class Context>
 class LogGradientOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(LogGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();
......
@@ -21,17 +21,16 @@ class MatmulOp final : public Operator<Context> {
  public:
     MatmulOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
-          transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
-          transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+          TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
+          TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

  protected:
-    bool transA, transB;
+    TIndex TransA, TransB, M, K1, K2, N;
     TIndex n, x1_offset, x2_offset, y_offset;
-    TIndex M, K1, K2, N;
 };

 template <class Context>
@@ -39,17 +38,16 @@ class MatmulGradientOp final : public Operator<Context> {
  public:
     MatmulGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
-          transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
-          transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
-    USE_OPERATOR_FUNCTIONS(Context);
+          TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
+          TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void RunWithType();

  protected:
-    bool transA, transB;
+    TIndex TransA, TransB, M, K1, K2, N;
     TIndex n, x1_offset, x2_offset, y_offset;
-    TIndex M, K1, K2, N;
 };

 } // namespace dragon
......
@@ -20,56 +20,44 @@ template <class Context>
 class MulOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(MulOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 template <class Context>
 class MulGradientOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(MulGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 template <class Context>
 class RMulOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(RMulOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 template <class Context>
 class RMulGradientOp final : public Operator<Context> {
  public:
     USE_SIMPLE_CTOR_DTOR(RMulGradientOp);
-    USE_OPERATOR_FUNCTIONS(Context);
+    USE_OPERATOR_FUNCTIONS;

     void RunOnDevice() override;
     template <typename T> void EltwiseRunWithType();
     template <typename T> void BroadcastRunWithType(int type);
-
- protected:
-    Tensor* bcast_multiplier;
 };

 } // namespace dragon
......
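All four Mul variants lose their protected Tensor* bcast_multiplier. That member cached the ones tensor used to expand the smaller operand inside BroadcastRunWithType(); dropping it suggests the multiplier is now obtained from a shared workspace on demand, though that replacement is an assumption since the .cc files are not part of this excerpt. What the broadcast path must compute is unchanged; for example, for an (N, C) left operand and a (C,) right operand:

    // Illustrative only: the effect of a row-wise broadcast multiply,
    // regardless of whether a ones-multiplier GEMM or a plain loop is
    // used underneath.
    #include <vector>

    std::vector<float> BroadcastMul(const std::vector<float>& x1,  // N*C
                                    const std::vector<float>& x2,  // C
                                    int N, int C) {
        std::vector<float> y(N * C);
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < C; ++j)
                y[i * C + j] = x1[i * C + j] * x2[j];  // x2 repeats per row
        return y;
    }

The Sub family further down receives the identical treatment.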
...@@ -26,7 +26,7 @@ class PowOp: public Operator<Context> { ...@@ -26,7 +26,7 @@ class PowOp: public Operator<Context> {
power(OperatorBase::GetSingleArg<float>("power", 1.0)) { power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
power_scale = power * scale; power_scale = power * scale;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -45,7 +45,7 @@ class PowGradientOp final : public Operator<Context> { ...@@ -45,7 +45,7 @@ class PowGradientOp final : public Operator<Context> {
power(OperatorBase::GetSingleArg<float>("power", 1.0)) { power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
power_scale = power * scale; power_scale = power * scale;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
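Both Pow hunks precompute power_scale = power * scale in the constructor. Assuming the usual definition y = (scale * x + shift)^power (shift is declared in the elided part of the header), power_scale is exactly the constant factor produced by the chain rule, so the gradient op can reuse it:

    \[
    y = (s\,x + b)^{p}
    \quad\Rightarrow\quad
    \frac{dy}{dx} = \underbrace{p\,s}_{\texttt{power\_scale}}\,(s\,x + b)^{\,p-1},
    \qquad s = \texttt{scale},\; b = \texttt{shift},\; p = \texttt{power}.
    \]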
...@@ -20,7 +20,7 @@ template <class Context> ...@@ -20,7 +20,7 @@ template <class Context>
class SquareOp final : public Operator<Context> { class SquareOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SquareOp); USE_SIMPLE_CTOR_DTOR(SquareOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -30,7 +30,7 @@ template <class Context> ...@@ -30,7 +30,7 @@ template <class Context>
class SquareGradientOp final : public Operator<Context> { class SquareGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SquareGradientOp); USE_SIMPLE_CTOR_DTOR(SquareGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -20,56 +20,44 @@ template <class Context> ...@@ -20,56 +20,44 @@ template <class Context>
class SubOp final : public Operator<Context> { class SubOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SubOp); USE_SIMPLE_CTOR_DTOR(SubOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
}; };
template <class Context> template <class Context>
class SubGradientOp final : public Operator<Context> { class SubGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SubGradientOp); USE_SIMPLE_CTOR_DTOR(SubGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
}; };
template <class Context> template <class Context>
class RSubOp final : public Operator<Context> { class RSubOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(RSubOp); USE_SIMPLE_CTOR_DTOR(RSubOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
}; };
template <class Context> template <class Context>
class RSubGradientOp final : public Operator<Context> { class RSubGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(RSubGradientOp); USE_SIMPLE_CTOR_DTOR(RSubGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,7 +22,7 @@ class CompareOp final : public Operator<Context> { ...@@ -22,7 +22,7 @@ class CompareOp final : public Operator<Context> {
CompareOp(const OperatorDef& op_def, Workspace* ws) CompareOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {} operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EqualRunWithType(); template <typename T> void EqualRunWithType();
......
...@@ -20,7 +20,7 @@ template <class Context> ...@@ -20,7 +20,7 @@ template <class Context>
class CopyOp final : public Operator<Context> { class CopyOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(CopyOp); USE_SIMPLE_CTOR_DTOR(CopyOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -31,7 +31,7 @@ class ScanOp final: public Operator<Context> { ...@@ -31,7 +31,7 @@ class ScanOp final: public Operator<Context> {
debug_mode(OperatorBase::GetSingleArg<bool>("debug_mode", false)) { debug_mode(OperatorBase::GetSingleArg<bool>("debug_mode", false)) {
InitTemplate(); InitTemplate();
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
void InitTemplate(); void InitTemplate();
...@@ -68,7 +68,7 @@ class ScanGradientOp final: public Operator<Context> { ...@@ -68,7 +68,7 @@ class ScanGradientOp final: public Operator<Context> {
for (int i = 0; i < forward_inputs.size(); i++) for (int i = 0; i < forward_inputs.size(); i++)
terms[forward_inputs[i] + "_grad"] = Output(i)->name(); terms[forward_inputs[i] + "_grad"] = Output(i)->name();
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
void MakeOps(const GraphDef& forward_def, GraphDef& new_def); void MakeOps(const GraphDef& forward_def, GraphDef& new_def);
......
...@@ -22,7 +22,7 @@ class L1LossOp : public Operator<Context> { ...@@ -22,7 +22,7 @@ class L1LossOp : public Operator<Context> {
L1LossOp(const OperatorDef& op_def, Workspace* ws) L1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -38,7 +38,7 @@ class L1LossGradientOp final : public Operator<Context> { ...@@ -38,7 +38,7 @@ class L1LossGradientOp final : public Operator<Context> {
L1LossGradientOp(const OperatorDef& op_def, Workspace* ws) L1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -22,7 +22,7 @@ class L2LossOp : public Operator<Context> { ...@@ -22,7 +22,7 @@ class L2LossOp : public Operator<Context> {
L2LossOp(const OperatorDef& op_def, Workspace* ws) L2LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -38,7 +38,7 @@ class L2LossGradientOp final : public Operator<Context> { ...@@ -38,7 +38,7 @@ class L2LossGradientOp final : public Operator<Context> {
L2LossGradientOp(const OperatorDef& op_def, Workspace* ws) L2LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
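The loss headers in this commit consistently take a "normalization" string: the L1/L2 and SmoothL1 losses default to "BATCH_SIZE", SigmoidCrossEntropy (next hunk) to "VALID", and SoftmaxCrossEntropy to "FULL". The divisor logic lives in the implementation files, so the following mapping is an inferred sketch rather than the verified behavior; the fall-through for unknown strings is likewise an assumption:

    #include <string>

    // Inferred scaling of a summed loss by the "normalization" argument.
    float Normalize(float loss_sum, const std::string& normalization,
                    float batch_size, float element_count, float valid_count) {
        if (normalization == "BATCH_SIZE") return loss_sum / batch_size;
        if (normalization == "FULL")       return loss_sum / element_count;
        if (normalization == "VALID")      return loss_sum / valid_count;
        return loss_sum;  // assumed: anything else means no scaling
    }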
...@@ -22,7 +22,7 @@ class SigmoidCrossEntropyOp final : public Operator<Context> { ...@@ -22,7 +22,7 @@ class SigmoidCrossEntropyOp final : public Operator<Context> {
SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws) SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -38,7 +38,7 @@ class SigmoidCrossEntropyGradientOp final : public Operator<Context> { ...@@ -38,7 +38,7 @@ class SigmoidCrossEntropyGradientOp final : public Operator<Context> {
SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws) SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -23,7 +23,7 @@ class SmoothL1LossOp final : public Operator<Context> { ...@@ -23,7 +23,7 @@ class SmoothL1LossOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
beta(OperatorBase::GetSingleArg<float>("beta", 1.0)), beta(OperatorBase::GetSingleArg<float>("beta", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -41,7 +41,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> { ...@@ -41,7 +41,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
beta(OperatorBase::GetSingleArg<float>("beta", 1.0)), beta(OperatorBase::GetSingleArg<float>("beta", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -24,7 +24,7 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> { ...@@ -24,7 +24,7 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) { normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void SoftmaxRun(); void SoftmaxRun();
void RunOnDevice() override; void RunOnDevice() override;
...@@ -45,7 +45,7 @@ class SoftmaxCrossEntropyGradientOp final : public Operator<Context> { ...@@ -45,7 +45,7 @@ class SoftmaxCrossEntropyGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -30,7 +30,7 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> { ...@@ -30,7 +30,7 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i]; for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
} }
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void SoftmaxRun(); void SoftmaxRun();
void SoftmaxRunFP16(); void SoftmaxRunFP16();
...@@ -60,7 +60,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> { ...@@ -60,7 +60,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i]; for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
} }
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunWithType();
......
...@@ -29,7 +29,7 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex ...@@ -29,7 +29,7 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
pos_alpha = alpha * 2.0; pos_alpha = alpha * 2.0;
neg_alpha = (1 - alpha) * 2.0; neg_alpha = (1 - alpha) * 2.0;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -53,7 +53,7 @@ class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyG ...@@ -53,7 +53,7 @@ class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyG
gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)), gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {} neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
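The SparseSoftmaxFocalLoss constructor doubles alpha into pos_alpha = 2 * alpha and neg_alpha = 2 * (1 - alpha). These are the class-dependent weights of the standard focal loss (Lin et al., 2017), with eps = 1e-10 presumably guarding the logarithm and neg_id marking which class takes the negative weight; the precise use of those two arguments is an inference from the names:

    \[
    \mathrm{FL}(p_t) = -\,\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t + \varepsilon),
    \qquad
    \alpha_t =
    \begin{cases}
        2\alpha, & \text{positive classes} \\
        2(1 - \alpha), & \text{class } \texttt{neg\_id}
    \end{cases}
    \]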
...@@ -30,7 +30,7 @@ class AccuracyOp final: public Operator<Context> { ...@@ -30,7 +30,7 @@ class AccuracyOp final: public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i]; for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
} }
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunWithType();
......
...@@ -23,7 +23,7 @@ class AsTypeOp final : public Operator<Context> { ...@@ -23,7 +23,7 @@ class AsTypeOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
dtype(OperatorBase::GetSingleArg<string>("dtype", "float32")), dtype(OperatorBase::GetSingleArg<string>("dtype", "float32")),
inplace(OperatorBase::GetSingleArg<bool>("inplace", false)) {} inplace(OperatorBase::GetSingleArg<bool>("inplace", false)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -25,7 +25,7 @@ class GradientGenerateOp final: public Operator<Context> { ...@@ -25,7 +25,7 @@ class GradientGenerateOp final: public Operator<Context> {
CHECK_EQ(InputSize(), OutputSize()); CHECK_EQ(InputSize(), OutputSize());
CHECK_EQ(defaults.size(), OutputSize()); CHECK_EQ(defaults.size(), OutputSize());
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -42,7 +42,7 @@ class GradientGatherOp final : public Operator<Context> { ...@@ -42,7 +42,7 @@ class GradientGatherOp final : public Operator<Context> {
for (int i = 0; i < InputSize(); i++) for (int i = 0; i < InputSize(); i++)
if (Input(i).name() != "ignore") indices.push_back(i); if (Input(i).name() != "ignore") indices.push_back(i);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -55,7 +55,7 @@ template <class Context> ...@@ -55,7 +55,7 @@ template <class Context>
class StopGradientOp final : public Operator<Context> { class StopGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(StopGradientOp); USE_SIMPLE_CTOR_DTOR(StopGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -40,7 +40,7 @@ class ImageDataOp final : public Operator<Context> { ...@@ -40,7 +40,7 @@ class ImageDataOp final : public Operator<Context> {
std.mutable_data<float, CPUContext>()[i] = std_values[i]; std.mutable_data<float, CPUContext>()[i] = std_values[i];
} }
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType(); template <typename Tx, typename Ty> void RunWithType();
......
...@@ -25,7 +25,7 @@ class InitializeOp: public Operator<Context> { ...@@ -25,7 +25,7 @@ class InitializeOp: public Operator<Context> {
shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) { shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) {
GET_ARGUMENTS_WITH_DESC(int, dims); GET_ARGUMENTS_WITH_DESC(int, dims);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
this->filler.set_type("constant"); this->filler.set_type("constant");
this->filler.set_value(OperatorBase::GetSingleArg<float>("value", 0.0)); this->filler.set_value(OperatorBase::GetSingleArg<float>("value", 0.0));
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
...@@ -56,7 +56,7 @@ public: ...@@ -56,7 +56,7 @@ public:
this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -1.0)); this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -1.0));
this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 1.0)); this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 1.0));
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
...@@ -68,7 +68,7 @@ public: ...@@ -68,7 +68,7 @@ public:
this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0)); this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0));
this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0)); this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0));
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
...@@ -84,7 +84,7 @@ public: ...@@ -84,7 +84,7 @@ public:
this->filler.set_low(mu - 2 * sigma); this->filler.set_low(mu - 2 * sigma);
this->filler.set_high(mu + 2 * sigma); this->filler.set_high(mu + 2 * sigma);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
...@@ -105,7 +105,7 @@ public: ...@@ -105,7 +105,7 @@ public:
} }
this->filler.set_scale(scale); this->filler.set_scale(scale);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
...@@ -126,7 +126,7 @@ public: ...@@ -126,7 +126,7 @@ public:
} }
this->filler.set_scale(scale); this->filler.set_scale(scale);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, InitializeOp, dims); DEFINE_ARGUMENTS_WITH_DESC(int, InitializeOp, dims);
......
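The TruncatedNormal hunk above pins the filler's support to [mu - 2 * sigma, mu + 2 * sigma] via set_low/set_high. A rejection-sampling sketch of that contract (whether the real filler resamples or clips is not visible in this excerpt):

    #include <random>

    // Draw from N(mu, sigma^2) restricted to mu +/- 2*sigma by resampling.
    double TruncatedNormal(std::mt19937& rng, double mu, double sigma) {
        std::normal_distribution<double> gauss(mu, sigma);
        double v;
        do {
            v = gauss(rng);
        } while (v < mu - 2 * sigma || v > mu + 2 * sigma);
        return v;
    }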
...@@ -24,7 +24,7 @@ template <class Context> ...@@ -24,7 +24,7 @@ template <class Context>
class RunOp : public Operator<Context> { class RunOp : public Operator<Context> {
public: public:
RunOp(const OperatorDef& op_def, Workspace* ws); RunOp(const OperatorDef& op_def, Workspace* ws);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -38,7 +38,7 @@ class TemplateOp : public RunOp<Context> { ...@@ -38,7 +38,7 @@ class TemplateOp : public RunOp<Context> {
public: public:
TemplateOp(const OperatorDef& op_def, Workspace* ws) TemplateOp(const OperatorDef& op_def, Workspace* ws)
: RunOp<Context>(op_def, ws) {} : RunOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
...@@ -46,7 +46,7 @@ class TemplateGradientOp : public TemplateOp<Context> { ...@@ -46,7 +46,7 @@ class TemplateGradientOp : public TemplateOp<Context> {
public: public:
TemplateGradientOp(const OperatorDef& op_def, Workspace* ws) TemplateGradientOp(const OperatorDef& op_def, Workspace* ws)
: TemplateOp<Context>(op_def, ws) {} : TemplateOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -48,7 +48,7 @@ class ModelMPIBase : public Operator<Context> { ...@@ -48,7 +48,7 @@ class ModelMPIBase : public Operator<Context> {
return MPI_DATATYPE_NULL; return MPI_DATATYPE_NULL;
} }
protected: public:
MPI_Comm comm; MPI_Comm comm;
MPI_Group group; MPI_Group group;
int comm_size, comm_rank, comm_root; int comm_size, comm_rank, comm_root;
...@@ -57,7 +57,12 @@ class ModelMPIBase : public Operator<Context> { ...@@ -57,7 +57,12 @@ class ModelMPIBase : public Operator<Context> {
}; };
#define USE_MPIMODEL_FUNCTIONS(context) \ #define USE_MPIMODEL_FUNCTIONS(context) \
using ModelMPIBase<context>::mpi_dtype using ModelMPIBase<context>::comm; \
using ModelMPIBase<context>::mpi_dtype; \
using ModelMPIBase<context>::comm_size; \
using ModelMPIBase<context>::comm_rank; \
using ModelMPIBase<context>::comm_root; \
using ModelMPIBase<context>::dtype
} // namespace dragon } // namespace dragon
......
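ModelMPIBase's members move from protected to public, and USE_MPIMODEL_FUNCTIONS grows from re-exporting mpi_dtype alone to the full set (comm, comm_size, comm_rank, comm_root, dtype). The using-declarations are what let the derived class templates in the next files refer to those names without this->: members of a dependent base are invisible to unqualified name lookup inside a template. A self-contained illustration of the rule, not Dragon's actual MPI code:

    template <class Context>
    struct Base { int comm_rank = 0; };

    template <class Context>
    struct Derived : Base<Context> {
        using Base<Context>::comm_rank;   // what the macro generates
        bool IsRoot() const { return comm_rank == 0; }  // OK without this->
    };

    int main() { return Derived<int>().IsRoot() ? 0 : 1; }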
...@@ -23,7 +23,7 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> { ...@@ -23,7 +23,7 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> {
public: public:
MPIBroadcastOp(const OperatorDef& op_def, Workspace* ws) MPIBroadcastOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
...@@ -35,7 +35,7 @@ class MPIBroadcastGradientOp final : public ModelMPIBase<Context> { ...@@ -35,7 +35,7 @@ class MPIBroadcastGradientOp final : public ModelMPIBase<Context> {
public: public:
MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws) MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
...@@ -47,6 +47,3 @@ public: ...@@ -47,6 +47,3 @@ public:
#endif // WITH_MPI #endif // WITH_MPI
#endif //DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_ #endif //DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
\ No newline at end of file
...@@ -23,7 +23,7 @@ class MPIGatherOp final : public ModelMPIBase<Context> { ...@@ -23,7 +23,7 @@ class MPIGatherOp final : public ModelMPIBase<Context> {
public: public:
MPIGatherOp(const OperatorDef& op_def, Workspace *ws) MPIGatherOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
...@@ -35,7 +35,7 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> { ...@@ -35,7 +35,7 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> {
public: public:
MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws) MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context); USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -26,7 +26,7 @@ class ArangeOp final : public Operator<Context> { ...@@ -26,7 +26,7 @@ class ArangeOp final : public Operator<Context> {
GET_ARGUMENT_WITH_DESC(int, stop, 0); GET_ARGUMENT_WITH_DESC(int, stop, 0);
GET_ARGUMENT_WITH_DESC(int, step, 1); GET_ARGUMENT_WITH_DESC(int, step, 1);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
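ArangeOp's visible arguments are stop (default 0) and step (default 1); start is handled in the elided part of the constructor. The implied semantics are the usual half-open range:

    #include <vector>

    // [start, stop) with stride `step`; negative steps count downward.
    // Assumes step != 0.
    std::vector<int> Arange(int start, int stop, int step = 1) {
        std::vector<int> out;
        for (int v = start; step > 0 ? v < stop : v > stop; v += step)
            out.push_back(v);
        return out;
    }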
...@@ -25,7 +25,7 @@ class ArgReduceOp final : public Operator<Context> { ...@@ -25,7 +25,7 @@ class ArgReduceOp final : public Operator<Context> {
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")), operation(OperatorBase::GetSingleArg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)), keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {} top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -22,7 +22,7 @@ class ConcatOp : public Operator<Context> { ...@@ -22,7 +22,7 @@ class ConcatOp : public Operator<Context> {
ConcatOp(const OperatorDef& op_def, Workspace* ws) ConcatOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -39,7 +39,7 @@ class ConcatGradientOp : public Operator<Context> { ...@@ -39,7 +39,7 @@ class ConcatGradientOp : public Operator<Context> {
ConcatGradientOp(const OperatorDef& op_def, Workspace* ws) ConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -28,7 +28,7 @@ class CropOp: public Operator<Context> { ...@@ -28,7 +28,7 @@ class CropOp: public Operator<Context> {
GET_ARGUMENTS_WITH_DESC(int, starts); GET_ARGUMENTS_WITH_DESC(int, starts);
GET_ARGUMENTS_WITH_DESC(int, ends); GET_ARGUMENTS_WITH_DESC(int, ends);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
void RunOnDevice() override; void RunOnDevice() override;
...@@ -37,32 +37,33 @@ class CropOp: public Operator<Context> { ...@@ -37,32 +37,33 @@ class CropOp: public Operator<Context> {
protected: protected:
TIndex start_axis; TIndex start_axis;
string shape_like; string shape_like;
vector<int> st, ed, offsets, shape, keep_dims; vector<int> offsets, shape;
vector<TIndex> st, ed, keep_dims;
DECLARE_ARGUMENTS_WITH_DESC(int, starts); DECLARE_ARGUMENTS_WITH_DESC(int, starts);
DECLARE_ARGUMENTS_WITH_DESC(int, ends); DECLARE_ARGUMENTS_WITH_DESC(int, ends);
vector< pair<int, int> > process_axes; vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim; TIndex axis, inner_dim, dim;
Tensor* dest, *source; Tensor* dest, *source, navigator;
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, CropOp, starts); DEFINE_ARGUMENTS_WITH_DESC(int, CropOp, starts);
DEFINE_ARGUMENTS_WITH_DESC(int, CropOp, ends); DEFINE_ARGUMENTS_WITH_DESC(int, CropOp, ends);
template <class Context> template <class Context>
class CropGradientOp final : public Operator<Context > { class CropGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(CropGradientOp); USE_SIMPLE_CTOR_DTOR(CropGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
vector<int> st, ed, offsets, keep_dims; vector<TIndex> st, ed, offsets, keep_dims;
vector< pair<int, int> > process_axes; vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim; TIndex axis, inner_dim, dim;
Tensor* dest, *source; Tensor* dest, *source, navigator;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -22,7 +22,7 @@ class ExpandDimsOp final : public Operator<Context> { ...@@ -22,7 +22,7 @@ class ExpandDimsOp final : public Operator<Context> {
ExpandDimsOp(const OperatorDef& op_def, Workspace* ws) ExpandDimsOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {} axis(OperatorBase::GetSingleArg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -34,7 +34,7 @@ template <class Context> ...@@ -34,7 +34,7 @@ template <class Context>
class ExpandDimsGradientOp final : public Operator<Context> { class ExpandDimsGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ExpandDimsGradientOp); USE_SIMPLE_CTOR_DTOR(ExpandDimsGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -24,7 +24,7 @@ class FlattenOp final : public Operator<Context> { ...@@ -24,7 +24,7 @@ class FlattenOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)), num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
keep_axes(OperatorBase::GetSingleArg<int>("keep_axes", INT_MAX)) {} keep_axes(OperatorBase::GetSingleArg<int>("keep_axes", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
void SqueezeRun(); void SqueezeRun();
...@@ -38,7 +38,7 @@ template <class Context> ...@@ -38,7 +38,7 @@ template <class Context>
class FlattenGradientOp final : public Operator<Context> { class FlattenGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(FlattenGradientOp); USE_SIMPLE_CTOR_DTOR(FlattenGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -22,7 +22,7 @@ class GatherOp final : public Operator<Context> { ...@@ -22,7 +22,7 @@ class GatherOp final : public Operator<Context> {
GatherOp(const OperatorDef& op_def, Workspace* ws) GatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {} axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -39,7 +39,7 @@ class GatherGradientOp final : public Operator<Context> { ...@@ -39,7 +39,7 @@ class GatherGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {} acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -24,7 +24,7 @@ class OneHotOp final : public Operator < Context > { ...@@ -24,7 +24,7 @@ class OneHotOp final : public Operator < Context > {
depth(OperatorBase::GetSingleArg<int>("depth", -1)), depth(OperatorBase::GetSingleArg<int>("depth", -1)),
on_value(OperatorBase::GetSingleArg<int>("on_value", 1)), on_value(OperatorBase::GetSingleArg<int>("on_value", 1)),
off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {} off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
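OneHotOp is fully specified by the three arguments shown: depth (required, default -1), on_value (1) and off_value (0). The operation those names describe:

    #include <vector>

    // Expand integer labels to (num, depth) one-hot rows.
    // Assumes every label lies in [0, depth).
    std::vector<int> OneHot(const std::vector<int>& labels, int depth,
                            int on_value = 1, int off_value = 0) {
        std::vector<int> y(labels.size() * depth, off_value);
        for (size_t i = 0; i < labels.size(); ++i)
            y[i * depth + labels[i]] = on_value;
        return y;
    }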
...@@ -35,7 +35,7 @@ class PadOp final : public Operator<Context> { ...@@ -35,7 +35,7 @@ class PadOp final : public Operator<Context> {
} }
std::sort(process_axes.begin(), process_axes.end()); std::sort(process_axes.begin(), process_axes.end());
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void ConstRunWithType(); template <typename T> void ConstRunWithType();
...@@ -48,7 +48,7 @@ class PadOp final : public Operator<Context> { ...@@ -48,7 +48,7 @@ class PadOp final : public Operator<Context> {
float value; float value;
vector< pair<int, int> > process_axes; vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim; TIndex axis, inner_dim, dim;
Tensor* dest, *source; Tensor* dest, *source, navigator;
}; };
template <class Context> template <class Context>
...@@ -70,7 +70,7 @@ class PadGradientOp final : public Operator<Context> { ...@@ -70,7 +70,7 @@ class PadGradientOp final : public Operator<Context> {
std::sort(process_axes.begin(), process_axes.end()); std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end()); std::reverse(process_axes.begin(), process_axes.end());
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void ConstRunWithType(); template <typename T> void ConstRunWithType();
...@@ -82,7 +82,7 @@ class PadGradientOp final : public Operator<Context> { ...@@ -82,7 +82,7 @@ class PadGradientOp final : public Operator<Context> {
string mode; string mode;
vector< pair<int, int> > process_axes; vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim; TIndex axis, inner_dim, dim;
Tensor* dest, *source; Tensor* dest, *source, navigator;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -23,7 +23,7 @@ class RandomPickOp : public Operator<Context> { ...@@ -23,7 +23,7 @@ class RandomPickOp : public Operator<Context> {
Operator<Context>(op_def, ws), Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
max_samples(OperatorBase::GetSingleArg<int>("max_samples", 1)) {} max_samples(OperatorBase::GetSingleArg<int>("max_samples", 1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -41,7 +41,7 @@ public: ...@@ -41,7 +41,7 @@ public:
RandomPickGradientOp(const OperatorDef& op_def, Workspace* ws) RandomPickGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {} axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -24,17 +24,15 @@ class ReduceOp final : public Operator<Context> { ...@@ -24,17 +24,15 @@ class ReduceOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")), operation(OperatorBase::GetSingleArg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)) {} keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void SumRunWithType(); template <typename T> void SumRunWithType();
template <typename T> void MeanRunWithType(); template <typename T> void MeanRunWithType();
protected: protected:
bool keep_dims; TIndex axis, keep_dims, axis_dim, count, inner_dim;
string operation; string operation;
TIndex axis, axis_dim, count, inner_dim;
Tensor* multiplier;
}; };
template <class Context> template <class Context>
...@@ -44,15 +42,15 @@ class ReduceGradientOp final : public Operator<Context> { ...@@ -44,15 +42,15 @@ class ReduceGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {} operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void SumRunWithType(); template <typename T> void SumRunWithType();
template <typename T> void MeanRunWithType(); template <typename T> void MeanRunWithType();
protected: protected:
string operation;
TIndex axis, axis_dim, count, inner_dim; TIndex axis, axis_dim, count, inner_dim;
string operation;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -24,7 +24,7 @@ class RepeatOp : public Operator<Context> { ...@@ -24,7 +24,7 @@ class RepeatOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)) { axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1); GET_ARGUMENT_WITH_DESC(int, repeats, 1);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void RunWithType(); template<typename T> void RunWithType();
...@@ -42,7 +42,7 @@ class RepeatGradientOp : public Operator<Context> { ...@@ -42,7 +42,7 @@ class RepeatGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)) { axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1); GET_ARGUMENT_WITH_DESC(int, repeats, 1);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void RunWithType(); template<typename T> void RunWithType();
......
...@@ -24,7 +24,7 @@ class ReshapeOp final : public Operator<Context> { ...@@ -24,7 +24,7 @@ class ReshapeOp final : public Operator<Context> {
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) { shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, shape); GET_ARGUMENTS_WITH_DESC(int, shape);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
...@@ -38,7 +38,7 @@ template <class Context> ...@@ -38,7 +38,7 @@ template <class Context>
class ReshapeGradientOp final : public Operator<Context> { class ReshapeGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ReshapeGradientOp); USE_SIMPLE_CTOR_DTOR(ReshapeGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -20,7 +20,7 @@ template <class Context> ...@@ -20,7 +20,7 @@ template <class Context>
class ShapeOp final : public Operator<Context> { class ShapeOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ShapeOp); USE_SIMPLE_CTOR_DTOR(ShapeOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -23,7 +23,7 @@ class SliceOp : public Operator<Context> { ...@@ -23,7 +23,7 @@ class SliceOp : public Operator<Context> {
Operator<Context>(op_def, ws), Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {} nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -42,7 +42,7 @@ class SliceGradientOp final : public Operator<Context> { ...@@ -42,7 +42,7 @@ class SliceGradientOp final : public Operator<Context> {
Operator<Context>(op_def, ws), Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {} nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -23,7 +23,7 @@ class StackOp : public Operator<Context> { ...@@ -23,7 +23,7 @@ class StackOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {} nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -41,7 +41,7 @@ class StackGradientOp : public Operator<Context> { ...@@ -41,7 +41,7 @@ class StackGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {} nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -23,7 +23,7 @@ class TileOp : public Operator<Context> { ...@@ -23,7 +23,7 @@ class TileOp : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples); GET_ARGUMENTS_WITH_DESC(int, multiples);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void TileRunWithType(); template<typename T> void TileRunWithType();
...@@ -31,7 +31,7 @@ class TileOp : public Operator<Context> { ...@@ -31,7 +31,7 @@ class TileOp : public Operator<Context> {
protected: protected:
DECLARE_ARGUMENTS_WITH_DESC(int, multiples); DECLARE_ARGUMENTS_WITH_DESC(int, multiples);
TIndex axis, multiple, outer_dim, ex_inner_dim; TIndex axis, multiple, outer_dim, ex_inner_dim;
Tensor* dest, *source; Tensor* dest, *source, navigator;
}; };
template <class Context> template <class Context>
...@@ -41,7 +41,7 @@ class TileGradientOp : public Operator<Context> { ...@@ -41,7 +41,7 @@ class TileGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples); GET_ARGUMENTS_WITH_DESC(int, multiples);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void TileRunWithType(); template<typename T> void TileRunWithType();
...@@ -49,7 +49,7 @@ class TileGradientOp : public Operator<Context> { ...@@ -49,7 +49,7 @@ class TileGradientOp : public Operator<Context> {
protected: protected:
DECLARE_ARGUMENTS_WITH_DESC(int, multiples); DECLARE_ARGUMENTS_WITH_DESC(int, multiples);
TIndex axis, multiple, outer_dim, ex_inner_dim; TIndex axis, multiple, outer_dim, ex_inner_dim;
Tensor* dest, *source; Tensor* dest, *source, navigator;
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, TileOp, multiples); DEFINE_ARGUMENTS_WITH_DESC(int, TileOp, multiples);
......
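TileOp and TileGradientOp both carry the axis/multiple/outer_dim/ex_inner_dim quartet and, after this commit, a navigator tensor (also added to Crop and Pad above); navigator apparently stages data between the per-axis passes, but its exact role is an assumption. The member names suggest the classic one-axis-at-a-time tiling loop:

    #include <vector>

    // Tile one axis: view x as (outer_dim, inner_dim) and repeat each
    // inner block `multiple` times; applying this per axis tiles N-D input.
    std::vector<float> TileAxis(const std::vector<float>& x,
                                int outer_dim, int inner_dim, int multiple) {
        std::vector<float> y;
        y.reserve(x.size() * multiple);
        for (int o = 0; o < outer_dim; ++o)
            for (int m = 0; m < multiple; ++m)
                for (int i = 0; i < inner_dim; ++i)
                    y.push_back(x[o * inner_dim + i]);
        return y;
    }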
...@@ -25,7 +25,7 @@ class TransposeOp final: public Operator<Context> { ...@@ -25,7 +25,7 @@ class TransposeOp final: public Operator<Context> {
if (perms.size() > 0) reverse_dims = false; if (perms.size() > 0) reverse_dims = false;
else reverse_dims = true; else reverse_dims = true;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -42,7 +42,7 @@ class TransposeGradientOp final : public Operator<Context> { ...@@ -42,7 +42,7 @@ class TransposeGradientOp final : public Operator<Context> {
public: public:
TransposeGradientOp(const OperatorDef& op_def, Workspace* ws) TransposeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -32,7 +32,7 @@ class BatchNormOp : public Operator<Context> { ...@@ -32,7 +32,7 @@ class BatchNormOp : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -42,12 +42,9 @@ class BatchNormOp : public Operator<Context> { ...@@ -42,12 +42,9 @@ class BatchNormOp : public Operator<Context> {
protected: protected:
float momentum, eps; float momentum, eps;
Tensor mean, num_by_chans; Tensor nc, mean, *var;
Tensor* multiplier, *num_multiplier, *spatial_multiplier; TIndex axis, use_stats, N, C, S, NC, NS;
Tensor* stddev, *var;
TIndex axis, N, C, S, NC, NS;
string data_format, mode; string data_format, mode;
int use_stats;
bool use_global_stats, is_recomputing; bool use_global_stats, is_recomputing;
}; };
...@@ -62,7 +59,7 @@ class BatchNormGradientOp final : public Operator<Context> { ...@@ -62,7 +59,7 @@ class BatchNormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -71,12 +68,9 @@ class BatchNormGradientOp final : public Operator<Context> { ...@@ -71,12 +68,9 @@ class BatchNormGradientOp final : public Operator<Context> {
template <typename T> void InferenceRunWithType(); template <typename T> void InferenceRunWithType();
protected: protected:
Tensor num_by_chans; TIndex axis, use_stats, N, C, S, NC, NS;
Tensor* multiplier, *num_multiplier, *spatial_multiplier; Tensor nc, *var;
Tensor* stddev, *var;
TIndex axis, N, C, S, NC, NS;
string data_format; string data_format;
int use_stats;
bool use_global_stats; bool use_global_stats;
}; };
...@@ -89,7 +83,7 @@ class FusedBatchNormOp : public Operator<Context> { ...@@ -89,7 +83,7 @@ class FusedBatchNormOp : public Operator<Context> {
momentum(OperatorBase::GetSingleArg<float>("momentum", 0.9f)), momentum(OperatorBase::GetSingleArg<float>("momentum", 0.9f)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)), eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -98,13 +92,10 @@ class FusedBatchNormOp : public Operator<Context> { ...@@ -98,13 +92,10 @@ class FusedBatchNormOp : public Operator<Context> {
template <typename T> void InferenceRunWithType(); template <typename T> void InferenceRunWithType();
protected: protected:
TIndex axis, use_stats, N, C, S, NC, NS;
float momentum, eps; float momentum, eps;
Tensor num_by_chans; Tensor nc, *mean, *var, *x_norm;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex axis, N, C, S, NC, NS;
string data_format; string data_format;
int use_stats;
bool use_global_stats, is_recomputing; bool use_global_stats, is_recomputing;
}; };
...@@ -116,7 +107,7 @@ class FusedBatchNormGradientOp : public Operator<Context> { ...@@ -116,7 +107,7 @@ class FusedBatchNormGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)), eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -125,13 +116,10 @@ class FusedBatchNormGradientOp : public Operator<Context> { ...@@ -125,13 +116,10 @@ class FusedBatchNormGradientOp : public Operator<Context> {
template <typename T> void InferenceRunWithType(); template <typename T> void InferenceRunWithType();
protected: protected:
TIndex axis, use_stats, N, C, S, NC, NS;
float eps; float eps;
Tensor num_by_chans; Tensor nc, *mean, *var, *x_norm;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex axis, N, C, S, NC, NS;
string data_format; string data_format;
int use_stats;
bool use_global_stats; bool use_global_stats;
}; };
...@@ -156,7 +144,7 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> { ...@@ -156,7 +144,7 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
<< "CUDNN_BN_MIN_EPSILON instead."; << "CUDNN_BN_MIN_EPSILON instead.";
eps64 = std::max(eps64, CUDNN_BN_MIN_EPSILON); eps64 = std::max(eps64, CUDNN_BN_MIN_EPSILON);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
~CuDNNBatchNormOp() { ~CuDNNBatchNormOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -170,12 +158,12 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> { ...@@ -170,12 +158,12 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
TIndex N, C;
double eps64; double eps64;
Tensor* mean, *var;
cudnnTensorDescriptor_t input_desc, output_desc, bn_desc; cudnnTensorDescriptor_t input_desc, output_desc, bn_desc;
cudnnBatchNormMode_t bn_mode; cudnnBatchNormMode_t bn_mode;
TIndex N, C;
string data_format; string data_format;
Tensor* mean, *var;
}; };
template <class Context> template <class Context>
...@@ -193,7 +181,7 @@ class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context> ...@@ -193,7 +181,7 @@ class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context>
<< "CUDNN_BN_MIN_EPSILON instead."; << "CUDNN_BN_MIN_EPSILON instead.";
eps64 = std::max(eps64, CUDNN_BN_MIN_EPSILON); eps64 = std::max(eps64, CUDNN_BN_MIN_EPSILON);
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
~CuDNNBatchNormGradientOp() { ~CuDNNBatchNormGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -208,14 +196,12 @@ class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context> ...@@ -208,14 +196,12 @@ class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context>
template <typename T> void InferenceRunWithType(); template <typename T> void InferenceRunWithType();
protected: protected:
TIndex N, C, S, NC, NS;
double eps64; double eps64;
Tensor nc, *mean, *var;
cudnnTensorDescriptor_t input_desc, output_desc, bn_desc; cudnnTensorDescriptor_t input_desc, output_desc, bn_desc;
cudnnBatchNormMode_t bn_mode; cudnnBatchNormMode_t bn_mode;
TIndex N, C, S, NC, NS;
string data_format; string data_format;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* mean, *var, *stddev;
}; };
#endif #endif
......
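Across every BatchNorm variant the scratch members collapse: mean/num_by_chans plus the multiplier, num_multiplier, spatial_multiplier and stddev pointers become a single nc tensor, and use_stats widens from int to TIndex alongside the shape sizes. Given the N/C/S/NC/NS naming, nc plausibly holds the intermediate (N*C) partial sums of a two-pass channel reduction; that is an inference, since the .cc files are outside this excerpt:

    #include <vector>

    // Per-channel mean of an (N, C, S) tensor in two reductions,
    // staging the (N*C) partial sums in `nc`.
    std::vector<float> ChannelMean(const std::vector<float>& x,
                                   int N, int C, int S) {
        std::vector<float> nc(N * C, 0.f), mean(C, 0.f);
        for (int i = 0; i < N * C; ++i)        // pass 1: reduce spatial dim S
            for (int s = 0; s < S; ++s)
                nc[i] += x[i * S + s];
        for (int n = 0; n < N; ++n)            // pass 2: reduce batch dim N
            for (int c = 0; c < C; ++c)
                mean[c] += nc[n * C + c];
        for (int c = 0; c < C; ++c)
            mean[c] /= float(N * S);
        return mean;
    }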
...@@ -34,7 +34,7 @@ class BatchRenormOp : public Operator<Context> { ...@@ -34,7 +34,7 @@ class BatchRenormOp : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -43,14 +43,12 @@ class BatchRenormOp : public Operator<Context> { ...@@ -43,14 +43,12 @@ class BatchRenormOp : public Operator<Context> {
template <typename T> void InferenceRunWithType(); template <typename T> void InferenceRunWithType();
protected: protected:
TIndex axis, use_stats, N, C, S, NC, NS;
float momentum, eps, r_max, d_max, t_delta; float momentum, eps, r_max, d_max, t_delta;
float t_r_max, t_d_max, t_val; float t_r_max, t_d_max, t_val;
Tensor mean, d, t_h_mean, t_h_var, num_by_chans; Tensor nc, mean, d, t_h_mean, t_h_var;
Tensor* multiplier, *num_multiplier, *spatial_multiplier; Tensor* r, *var, *x_norm;
Tensor* stddev, *r, *var, *x_norm;
TIndex axis, N, C, S, NC, NS;
string data_format, mode; string data_format, mode;
int use_stats;
bool use_global_stats, is_recomputing; bool use_global_stats, is_recomputing;
}; };
...@@ -65,7 +63,7 @@ class BatchRenormGradientOp final : public Operator<Context> { ...@@ -65,7 +63,7 @@ class BatchRenormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -76,12 +74,9 @@ class BatchRenormGradientOp final : public Operator<Context> { ...@@ -76,12 +74,9 @@ class BatchRenormGradientOp final : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
Tensor mean, num_by_chans; TIndex axis, use_stats, N, C, S, NC, NS;
Tensor* multiplier, *num_multiplier, *spatial_multiplier; Tensor nc, mean, *r, *var, *x_norm;
Tensor* stddev, *r, *var, *x_norm;
TIndex axis, N, C, S, NC, NS;
string data_format; string data_format;
int use_stats;
bool use_global_stats; bool use_global_stats;
}; };
......
...@@ -28,7 +28,7 @@ class GroupNormOp : public Operator<Context> { ...@@ -28,7 +28,7 @@ class GroupNormOp : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -37,10 +37,8 @@ class GroupNormOp : public Operator<Context> { ...@@ -37,10 +37,8 @@ class GroupNormOp : public Operator<Context> {
protected: protected:
float eps; float eps;
Tensor mean, num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS; TIndex group, axis, N, C, S, NG, NC, NS, CGS;
Tensor nc, mean, *var;
string data_format; string data_format;
}; };
...@@ -55,7 +53,7 @@ class GroupNormGradientOp final : public Operator<Context> { ...@@ -55,7 +53,7 @@ class GroupNormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -63,10 +61,8 @@ class GroupNormGradientOp final : public Operator<Context> { ...@@ -63,10 +61,8 @@ class GroupNormGradientOp final : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS; TIndex group, axis, N, C, S, NG, NC, NS, CGS;
Tensor nc, *var;
string data_format; string data_format;
}; };
...@@ -78,7 +74,7 @@ class FusedGroupNormOp : public Operator<Context> { ...@@ -78,7 +74,7 @@ class FusedGroupNormOp : public Operator<Context> {
group(OperatorBase::GetSingleArg<int>("group", 32)), group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {} eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -86,11 +82,9 @@ class FusedGroupNormOp : public Operator<Context> { ...@@ -86,11 +82,9 @@ class FusedGroupNormOp : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
float eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS; TIndex group, axis, N, C, S, NG, NC, NS, CGS;
float eps;
Tensor nc, *mean, *var, *x_norm;
string data_format; string data_format;
}; };
...@@ -101,7 +95,7 @@ class FusedGroupNormGradientOp : public Operator<Context> { ...@@ -101,7 +95,7 @@ class FusedGroupNormGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)), group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {} axis(OperatorBase::GetSingleArg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -109,10 +103,8 @@ class FusedGroupNormGradientOp : public Operator<Context> { ...@@ -109,10 +103,8 @@ class FusedGroupNormGradientOp : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS; TIndex group, axis, N, C, S, NG, NC, NS, CGS;
Tensor nc, *mean, *var, *x_norm;
string data_format; string data_format;
}; };
......
...@@ -24,10 +24,9 @@ class InstanceNormOp : public Operator<Context> { ...@@ -24,10 +24,9 @@ class InstanceNormOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) { eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1) << "\nThe axis can only be set to 1.";
<< "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -35,10 +34,9 @@ class InstanceNormOp : public Operator<Context> { ...@@ -35,10 +34,9 @@ class InstanceNormOp : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
float eps;
Tensor mean;
Tensor* spatial_multiplier, *stddev, *var;
TIndex axis, N, C, S, NC, CS; TIndex axis, N, C, S, NC, CS;
float eps;
Tensor mean, *var;
string data_format; string data_format;
}; };
...@@ -49,10 +47,9 @@ class InstanceNormGradientOp final : public Operator<Context> { ...@@ -49,10 +47,9 @@ class InstanceNormGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) { axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1) << "\nThe axis can only be set to 1.";
<< "\nThe axis can only be set to 1.";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Setup(); void Setup();
...@@ -60,8 +57,8 @@ class InstanceNormGradientOp final : public Operator<Context> { ...@@ -60,8 +57,8 @@ class InstanceNormGradientOp final : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
Tensor* spatial_multiplier, *stddev, *var;
TIndex axis, N, C, S, NC, CS; TIndex axis, N, C, S, NC, CS;
Tensor *var;
string data_format; string data_format;
}; };
......
...@@ -25,17 +25,17 @@ class L2NormOp final : public Operator<Context> { ...@@ -25,17 +25,17 @@ class L2NormOp final : public Operator<Context> {
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)), num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-5f)), eps(OperatorBase::GetSingleArg<float>("eps", 1e-5f)),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {} mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
float eps;
TIndex axis, num_axes, end_axis; TIndex axis, num_axes, end_axis;
float eps;
string mode; string mode;
bool across_inner; bool across_inner;
Tensor* norm, *buffer, *multiplier; Tensor* norm, buffer;
TIndex outer_dim, dim, inner_dim, spatial_dim; TIndex outer_dim, dim, inner_dim, spatial_dim;
}; };
...@@ -47,7 +47,7 @@ class L2NormGradientOp final : public Operator<Context> { ...@@ -47,7 +47,7 @@ class L2NormGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)), num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {} mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -56,7 +56,7 @@ class L2NormGradientOp final : public Operator<Context> { ...@@ -56,7 +56,7 @@ class L2NormGradientOp final : public Operator<Context> {
TIndex axis, num_axes, end_axis; TIndex axis, num_axes, end_axis;
string mode; string mode;
bool across_inner; bool across_inner;
Tensor* norm, *multiplier, *buffer, *buffer_inner; Tensor* norm, buffer, buffer_inner;
TIndex outer_dim, dim, inner_dim; TIndex outer_dim, dim, inner_dim;
}; };
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, see,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
#include "core/operator.h"
namespace dragon {
#ifdef WITH_CUDNN
#if CUDNN_VERSION_MIN(5, 0, 0)
#include "utils/cudnn_device.h"
class cudnnTensorDescriptors {
public:
cudnnTensorDescriptors(const int num_descs) {
descs_.resize(num_descs);
for (int i = 0; i < num_descs; ++i)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&descs_[i]));
}
~cudnnTensorDescriptors() {
for (auto desc : descs_)
cudnnDestroyTensorDescriptor(desc);
}
template <typename T>
void Set(const vector<TIndex>& dims, const vector<TIndex>& strides) {
CHECK_EQ(dims.size(), strides.size());
for (auto desc : descs_) cudnnSetTensorDesc<T>(&desc, dims, strides);
}
const cudnnTensorDescriptor_t* descs() const { return descs_.data(); }
protected:
vector<cudnnTensorDescriptor_t> descs_;
};
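As a usage sketch (hypothetical helper, not part of this commit): the cuDNN RNN API expects one 3D descriptor per time step, so a fully-packed input of shape (seq_len, batch, input_size) is described by seq_len identical descriptors with dims {batch, input_size, 1} and packed strides.
// Hypothetical helper, shown only to illustrate the class above; it describes
// a fully-packed (seq_len, batch, input_size) input, one descriptor per step.
template <typename T>
void SetPackedStepDescs(cudnnTensorDescriptors* xs_desc,
                        TIndex batch, TIndex input_size) {
    vector<TIndex> dims({ batch, input_size, 1 });
    vector<TIndex> strides({ input_size, 1, 1 });  // row-major, fully packed
    xs_desc->Set<T>(dims, strides);  // applied to every per-step descriptor
}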
template <class Context>
class CuDNNRecurrentOpBase : public Operator<Context> {
public:
CuDNNRecurrentOpBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), states_initialized(false),
hidden_size(OperatorBase::GetSingleArg<int>("hidden_size", 0)),
num_layers(OperatorBase::GetSingleArg<int>("num_layers", 1)),
bidirectional(OperatorBase::GetSingleArg<bool>("bidirectional", false)),
dropout_ratio(OperatorBase::GetSingleArg<float>("dropout_ratio", 1.0)),
random_seed(op_def.device_option().random_seed()) {
// determine the rnn direction
rnn_direction = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
// determine the rnn mode
const string mode = OperatorBase::GetSingleArg<string>("rnn_mode", "");
if (mode == "rnn_tanh") rnn_mode = CUDNN_RNN_TANH;
else if (mode == "rnn_relu") rnn_mode = CUDNN_RNN_RELU;
else if (mode == "lstm") rnn_mode = CUDNN_LSTM;
else if (mode == "gru") rnn_mode = CUDNN_GRU;
else LOG(FATAL) << "Unsupported rnn mode: " << mode;
// determine the rnn input mode
const string input_mode = OperatorBase::GetSingleArg<string>("rnn_input_mode", "linear");
if (input_mode == "skip") rnn_input_mode = CUDNN_SKIP_INPUT;
else if (input_mode == "linear") rnn_input_mode = CUDNN_LINEAR_INPUT;
else LOG(FATAL) << "Unsupported rnn input mode: " << input_mode;
// override the running phase
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc));
CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
CUDNN_CHECK(cudnnCreateFilterDescriptor(&w_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc));
}
USE_OPERATOR_FUNCTIONS;
virtual ~CuDNNRecurrentOpBase() {
CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc));
CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc));
CUDNN_CHECK(cudnnDestroyFilterDescriptor(w_desc));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc));
}
template <typename T> void ResetDesc(Tensor* X, Tensor* Hx, Tensor* Cx,
Tensor* Y, Tensor* Hy, Tensor* Cy);
public:
TIndex hidden_size, num_layers;
bool bidirectional, states_initialized;
float dropout_ratio;
unsigned long long random_seed;
cudnnRNNDescriptor_t rnn_desc;
cudnnDropoutDescriptor_t dropout_desc;
cudnnDirectionMode_t rnn_direction;
cudnnRNNMode_t rnn_mode;
cudnnRNNInputMode_t rnn_input_mode;
cudnnFilterDescriptor_t w_desc;
cudnnTensorDescriptor_t hx_desc, cx_desc;
cudnnTensorDescriptor_t hy_desc, cy_desc;
vector<TIndex> input_dims;
size_t workspace_size, reserve_size, states_size;
std::unique_ptr<cudnnTensorDescriptors> xs_desc;
std::unique_ptr<cudnnTensorDescriptors> ys_desc;
};
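ResetDesc is implemented in the corresponding source file, which this page does not show. Under the cuDNN 5.x signatures (later releases add a handle and an algo argument to cudnnSetRNNDescriptor), it would typically wire the members above together roughly as below; this is a sketch with assumed locals handle and states_ptr, not the actual implementation.
// Sketch only (cuDNN 5.x signatures); 'handle' and 'states_ptr' are assumed.
CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &states_size));
CUDNN_CHECK(cudnnSetDropoutDescriptor(dropout_desc, handle, dropout_ratio,
                                      states_ptr, states_size, random_seed));
CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc, hidden_size, num_layers,
    dropout_desc, rnn_input_mode, rnn_direction, rnn_mode, CUDNN_DATA_FLOAT));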
#define USE_CUDNN_RECURRENT_FUNCTIONS \
USE_OPERATOR_FUNCTIONS; \
using CuDNNRecurrentOpBase<Context>::dropout_desc; \
using CuDNNRecurrentOpBase<Context>::rnn_desc; \
using CuDNNRecurrentOpBase<Context>::w_desc; \
using CuDNNRecurrentOpBase<Context>::hx_desc; \
using CuDNNRecurrentOpBase<Context>::cx_desc; \
using CuDNNRecurrentOpBase<Context>::hy_desc; \
using CuDNNRecurrentOpBase<Context>::cy_desc; \
using CuDNNRecurrentOpBase<Context>::xs_desc; \
using CuDNNRecurrentOpBase<Context>::ys_desc; \
using CuDNNRecurrentOpBase<Context>::input_dims; \
using CuDNNRecurrentOpBase<Context>::workspace_size; \
using CuDNNRecurrentOpBase<Context>::reserve_size
template <class Context>
class CuDNNRecurrentOp : public CuDNNRecurrentOpBase<Context> {
public:
CuDNNRecurrentOp(const OperatorDef& op_def, Workspace* ws)
: CuDNNRecurrentOpBase<Context>(op_def, ws) {}
USE_CUDNN_RECURRENT_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class CuDNNRecurrentGradientOp : public CuDNNRecurrentOpBase<Context> {
public:
CuDNNRecurrentGradientOp(const OperatorDef& op_def, Workspace* ws)
: CuDNNRecurrentOpBase<Context>(op_def, ws) {}
USE_CUDNN_RECURRENT_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
};
#endif // CUDNN_VERSION_MIN(5, 0, 0)
#endif // WITH_CUDNN
} // namespace dragon
#endif // DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
\ No newline at end of file
...@@ -9,45 +9,36 @@ ...@@ -9,45 +9,36 @@
// //
// ------------------------------------------------------------- // -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_ #ifndef DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
#define DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_ #define DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
#include "core/operator.h" #include "core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class LSTMUnitOp : public Operator<Context> { class LSTMCellOp : public Operator<Context> {
public: public:
LSTMUnitOp(const OperatorDef& op_def, Workspace* ws) LSTMCellOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws) {}
has_cont(OperatorBase::GetSingleArg<string>("cont_t", "")) {} USE_OPERATOR_FUNCTIONS;
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void ResetDesc();
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected:
TIndex num, channels;
string has_cont;
Tensor* cont_t;
}; };
template <class Context> template <class Context>
class LSTMUnitGradientOp : public Operator<Context> { class LSTMCellGradientOp : public Operator<Context> {
public: public:
LSTMUnitGradientOp(const OperatorDef& op_def, Workspace* ws) LSTMCellGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected:
TIndex num, channels;
Tensor* zeros;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_ #endif // DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
\ No newline at end of file \ No newline at end of file
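For reference, the computation behind the renamed LSTMCellOp is the standard LSTM cell; the gate pre-activations z_i, z_f, z_o, z_g come from the fused input projection, and the exact gate ordering is a kernel-level detail not visible in this header:
$$ i = \sigma(z_i), \quad f = \sigma(z_f), \quad o = \sigma(z_o), \quad g = \tanh(z_g) $$
$$ c_t = f \odot c_{t-1} + i \odot g, \qquad h_t = o \odot \tanh(c_t) $$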
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, see,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_RECURRENT_RECURRENT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_RECURRENT_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class RecurrentOp : public Operator<Context> {
public:
RecurrentOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
LOG(FATAL) << "RNN Operators require CuDNN support.";
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {}
};
template <class Context>
class RecurrentGradientOp : public Operator<Context> {
public:
RecurrentGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
LOG(FATAL) << "RNN Operators require CuDNN support.";
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override {}
};
} // namespace dragon
#endif // DRAGON_OPERATORS_RECURRENT_RECURRENT_OP_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, see,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
#define DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class RNNParamSetOp : public Operator<Context> {
public:
RNNParamSetOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
param_type(OperatorBase::GetSingleArg<string>("param_type", "matrix")),
rnn_mode(OperatorBase::GetSingleArg<string>("rnn_mode", "rnn_tanh")),
num_layers(OperatorBase::GetSingleArg<int>("num_layers", 1)),
num_directions(OperatorBase::GetSingleArg<int>("num_directions", 1)),
input_size(OperatorBase::GetSingleArg<int>("input_size", 0)),
hidden_size(OperatorBase::GetSingleArg<int>("hidden_size", 0)),
layer_id(OperatorBase::GetSingleArg<int>("layer_id", 0)),
param_id(OperatorBase::GetSingleArg<int>("param_id", 0)) {
if (rnn_mode == "rnn_tanh") { num_params = 2; spliter = 1; }
else if (rnn_mode == "rnn_relu") { num_params = 2; spliter = 1; }
else if (rnn_mode == "lstm") { num_params = 8; spliter = 4; }
else if (rnn_mode == "gru") { num_params = 6; spliter = 3; }
else LOG(FATAL) << "Unsupported rnn mode: " << rnn_mode;
input_ex_size = hidden_size * num_directions;
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
string param_type, rnn_mode;
TIndex num_layers, num_directions, num_params, spliter;
TIndex input_size, input_ex_size, hidden_size;
TIndex layer_id, param_id;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
\ No newline at end of file
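The fields above pin down per-parameter shapes in cuDNN's canonical layout: num_params counts matrices plus biases per layer (4 + 4 for LSTM, 3 + 3 for GRU), and the first spliter matrices multiply the layer input while the rest multiply the recurrent state. A hedged sketch of the resulting matrix shape (illustrative helper; the exact enumeration of bidirectional pseudo-layers may differ):
// Illustrative only: weight-matrix shape for (layer_id, param_id), assuming
// layers 0..num_directions-1 read the external input and deeper layers read
// hidden_size * num_directions units, as in cuDNN's canonical layout.
vector<TIndex> MatrixShape(TIndex layer_id, TIndex param_id,
                           TIndex input_size, TIndex hidden_size,
                           TIndex num_directions, TIndex spliter) {
    TIndex cols = (param_id < spliter)    // input-to-hidden half?
        ? (layer_id < num_directions ? input_size
                                     : hidden_size * num_directions)
        : hidden_size;                    // hidden-to-hidden half
    return vector<TIndex>({ hidden_size, cols });
}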
...@@ -21,7 +21,7 @@ class AdamUpdateOp final : public UpdateOpBase<Context> { ...@@ -21,7 +21,7 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
public: public:
AdamUpdateOp(const OperatorDef& op_def, Workspace* ws) AdamUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws), t(0) {} : UpdateOpBase<Context>(op_def, ws), t(0) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat() override;
......
...@@ -27,7 +27,7 @@ class CollectiveUpdateOp : public Operator<Context> { ...@@ -27,7 +27,7 @@ class CollectiveUpdateOp : public Operator<Context> {
InitMPI(); InitMPI();
if (mode.find("NCCL") != string::npos) InitNCCL(); if (mode.find("NCCL") != string::npos) InitNCCL();
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void InitMPI(); void InitMPI();
void InitNCCL(); void InitNCCL();
...@@ -41,7 +41,6 @@ class CollectiveUpdateOp : public Operator<Context> { ...@@ -41,7 +41,6 @@ class CollectiveUpdateOp : public Operator<Context> {
protected: protected:
int comm_size, comm_rank, comm_root; int comm_size, comm_rank, comm_root;
int world_size, world_rank; int world_size, world_rank;
Tensor* buffer;
string mode; string mode;
MPI_Comm comm; MPI_Comm comm;
......
...@@ -22,7 +22,7 @@ class MovingAverageOp final : public Operator<Context> { ...@@ -22,7 +22,7 @@ class MovingAverageOp final : public Operator<Context> {
MovingAverageOp(const OperatorDef& op_def, Workspace* ws) MovingAverageOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
decay(OperatorBase::GetSingleArg<float>("decay", 1.0)) {} decay(OperatorBase::GetSingleArg<float>("decay", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -21,7 +21,7 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> { ...@@ -21,7 +21,7 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> {
public: public:
NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws) NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws) {} : UpdateOpBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat() override;
......
...@@ -21,7 +21,7 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> { ...@@ -21,7 +21,7 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
public: public:
RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws) RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws) {} : UpdateOpBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat() override;
......
...@@ -22,7 +22,7 @@ class SGDUpdateOp final : public UpdateOpBase<Context> { ...@@ -22,7 +22,7 @@ class SGDUpdateOp final : public UpdateOpBase<Context> {
SGDUpdateOp(const OperatorDef& op_def, Workspace* ws) SGDUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws), : UpdateOpBase<Context>(op_def, ws),
old_lr(-1.f), correction(1.f) {} old_lr(-1.f), correction(1.f) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context); USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override; void ComputeRunWithFloat() override;
......
...@@ -27,7 +27,7 @@ class UpdateOpBase : public Operator<Context> { ...@@ -27,7 +27,7 @@ class UpdateOpBase : public Operator<Context> {
zero_grad(OperatorBase::GetSingleArg<bool>("zero_grad", true)) { zero_grad(OperatorBase::GetSingleArg<bool>("zero_grad", true)) {
CHECK(!slot.empty()) << "\nRequired a non-empty slot"; CHECK(!slot.empty()) << "\nRequired a non-empty slot";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
float Param(const string& name) const; float Param(const string& name) const;
string Slot(); string Slot();
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
// //
// ------------------------------------------------------------ // ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_ #ifndef DRAGON_OPERATORS_VISION_BIAS_ADD_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_ #define DRAGON_OPERATORS_VISION_BIAS_ADD_OP_H_
#include "core/operator.h" #include "core/operator.h"
...@@ -22,7 +22,7 @@ class BiasAddOp : public Operator<Context> { ...@@ -22,7 +22,7 @@ class BiasAddOp : public Operator<Context> {
BiasAddOp(const OperatorDef& op_def, Workspace* ws) BiasAddOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -30,7 +30,6 @@ class BiasAddOp : public Operator<Context> { ...@@ -30,7 +30,6 @@ class BiasAddOp : public Operator<Context> {
protected: protected:
TIndex outer_dim, dim, inner_dim; TIndex outer_dim, dim, inner_dim;
string data_format; string data_format;
Tensor* bias_multiplier;
}; };
template <class Context> template <class Context>
...@@ -39,7 +38,7 @@ class BiasAddGradientOp final : public Operator<Context> { ...@@ -39,7 +38,7 @@ class BiasAddGradientOp final : public Operator<Context> {
BiasAddGradientOp(const OperatorDef& op_def, Workspace* ws) BiasAddGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -47,9 +46,8 @@ class BiasAddGradientOp final : public Operator<Context> { ...@@ -47,9 +46,8 @@ class BiasAddGradientOp final : public Operator<Context> {
protected: protected:
int outer_dim, dim, inner_dim; int outer_dim, dim, inner_dim;
string data_format; string data_format;
Tensor* bias_multiplier;
}; };
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_ #endif // DRAGON_OPERATORS_VISION_BIAS_ADD_OP_H_
\ No newline at end of file \ No newline at end of file
...@@ -30,7 +30,7 @@ class BilinearResizeOp : public Operator<Context> { ...@@ -30,7 +30,7 @@ class BilinearResizeOp : public Operator<Context> {
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -48,7 +48,7 @@ class BilinearResizeGradientOp : public Operator<Context> { ...@@ -48,7 +48,7 @@ class BilinearResizeGradientOp : public Operator<Context> {
BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws) BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -24,7 +24,7 @@ class Conv2dOp : public ConvOpBase<Context> { ...@@ -24,7 +24,7 @@ class Conv2dOp : public ConvOpBase<Context> {
this->num_spatial_axes = 2; this->num_spatial_axes = 2;
Setup(); Setup();
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context); USE_CONVOLUTION_FUNCTIONS(Context);
bool ReverseDimensions() override { return false; } bool ReverseDimensions() override { return false; }
...@@ -39,7 +39,7 @@ class Conv2dGradientOp : public Conv2dOp<Context> { ...@@ -39,7 +39,7 @@ class Conv2dGradientOp : public Conv2dOp<Context> {
public: public:
Conv2dGradientOp(const OperatorDef& def, Workspace* ws) Conv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws) {} : Conv2dOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context); USE_CONVOLUTION_FUNCTIONS(Context);
bool HasBias() override { return Output(2)->name() != "ignore"; } bool HasBias() override { return Output(2)->name() != "ignore"; }
...@@ -61,7 +61,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> { ...@@ -61,7 +61,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
cudnn_group = 1; cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE(); enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else #else
cudnn_group = this->group; cudnn_group = group;
enable_tensor_core = false; enable_tensor_core = false;
#endif #endif
handle = new cudnnHandle_t[cudnn_group]; handle = new cudnnHandle_t[cudnn_group];
...@@ -77,11 +77,11 @@ class CuDNNConv2dOp : public Conv2dOp<Context> { ...@@ -77,11 +77,11 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc)); if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW; if (data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC; else if (data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context); USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dOp() { ~CuDNNConv2dOp() {
...@@ -109,7 +109,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> { ...@@ -109,7 +109,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc; cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc; cudnnFilterDescriptor_t filter_desc;
size_t workspace_fwd_data_size; size_t fwd_data_size;
TIndex bias_offset, cudnn_group; TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims; vector<TIndex> input_dims;
bool enable_tensor_core; bool enable_tensor_core;
...@@ -124,7 +124,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> { ...@@ -124,7 +124,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
cudnn_group = 1; cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE(); enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else #else
cudnn_group = this->group; cudnn_group = group;
enable_tensor_core = false; enable_tensor_core = false;
#endif #endif
handle = new cudnnHandle_t[cudnn_group * 3]; handle = new cudnnHandle_t[cudnn_group * 3];
...@@ -139,11 +139,11 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> { ...@@ -139,11 +139,11 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc)); if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW; if (data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC; else if (data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context); USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dGradientOp() { ~CuDNNConv2dGradientOp() {
...@@ -172,7 +172,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> { ...@@ -172,7 +172,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc; cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc; cudnnFilterDescriptor_t filter_desc;
size_t workspace_bwd_filter_size, workspace_bwd_data_size; size_t bwd_filter_size, bwd_data_size;
TIndex bias_offset, cudnn_group; TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims; vector<TIndex> input_dims;
bool enable_tensor_core; bool enable_tensor_core;
......
...@@ -34,19 +34,18 @@ class ConvOpBase : public Operator<Context> { ...@@ -34,19 +34,18 @@ class ConvOpBase : public Operator<Context> {
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
num_spatial_axes = -1; // unknown num_spatial_axes = -1; // unknown
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
protected: public:
vector<TIndex> kernel_size, stride, pad, dilation; vector<TIndex> kernel_size, stride, pad, dilation;
string data_format, padding; string data_format, padding;
vector<TIndex> input_shape, output_shape, bottom_shape, top_shape, col_shape; vector<TIndex> input_shape, output_shape, bottom_shape, top_shape;
vector<TIndex> weight_shape, bias_shape; vector<TIndex> weight_shape, bias_shape;
Tensor* col_buffer, *bias_multiplier;
TIndex num_output, group; TIndex num_output, group;
TIndex spatial_axis, num_spatial_axes; TIndex spatial_axis, num_spatial_axes;
TIndex channels, out_spatial_dim; TIndex channels, out_spatial_dim;
TIndex conv_in_channels, conv_out_channels; TIndex conv_in_channels, conv_out_channels;
TIndex conv_out_spatial_dim, kernel_dim; TIndex conv_out_spatial_dim, kernel_dim, col_dim;
TIndex col_offset, output_offset, weight_offset, x_offset, y_offset; TIndex col_offset, output_offset, weight_offset, x_offset, y_offset;
DECLARE_ARGUMENTS_WITH_DESC(int, output_dims); DECLARE_ARGUMENTS_WITH_DESC(int, output_dims);
bool is_1x1; bool is_1x1;
...@@ -58,10 +57,15 @@ class ConvOpBase : public Operator<Context> { ...@@ -58,10 +57,15 @@ class ConvOpBase : public Operator<Context> {
virtual bool ReverseDimensions() = 0; virtual bool ReverseDimensions() = 0;
virtual bool HasBias() = 0; virtual bool HasBias() = 0;
template <typename T> void Wx(const T* x, const T* weights, T* y, bool skip_im2col = false); template <typename T> void Wx(const T* x,
const T* weights, T* y, bool skip_im2col = false);
template <typename T> void Pb(const T* bias, T* y); template <typename T> void Pb(const T* bias, T* y);
template <typename T> void Dx(const T* dy, const T* weights, T* dx); template <typename T> void Dx(const T* dy, const T* weights, T* dx);
template <typename T> void Dw(const T* dy, const T* x, T *dw); template <typename T> void Dw(const T* dy, const T* x, T* dw);
template <typename T> void Db(const T* dy, T* db); template <typename T> void Db(const T* dy, T* db);
private: private:
...@@ -108,7 +112,20 @@ DEFINE_ARGUMENTS_WITH_DESC(int, ConvOpBase, output_dims); ...@@ -108,7 +112,20 @@ DEFINE_ARGUMENTS_WITH_DESC(int, ConvOpBase, output_dims);
using ConvOpBase<context>::Pb; \ using ConvOpBase<context>::Pb; \
using ConvOpBase<context>::Dx; \ using ConvOpBase<context>::Dx; \
using ConvOpBase<context>::Dw; \ using ConvOpBase<context>::Dw; \
using ConvOpBase<context>::Db using ConvOpBase<context>::Db; \
using ConvOpBase<context>::kernel_size; \
using ConvOpBase<context>::stride; \
using ConvOpBase<context>::pad; \
using ConvOpBase<context>::dilation; \
using ConvOpBase<context>::group; \
using ConvOpBase<context>::channels; \
using ConvOpBase<context>::num_output; \
using ConvOpBase<context>::data_format; \
using ConvOpBase<context>::x_offset; \
using ConvOpBase<context>::y_offset; \
using ConvOpBase<context>::weight_offset; \
using ConvOpBase<context>::weight_shape; \
using ConvOpBase<context>::bias_shape
} // namespace dragon } // namespace dragon
......
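The using-declarations appended to USE_CONVOLUTION_FUNCTIONS are not cosmetic: inside a class template, names inherited from a dependent base such as ConvOpBase<Context> are invisible to unqualified lookup, which is why the cuDNN hunks above can now write group and data_format instead of this->group and this->data_format. A minimal standalone illustration:
// Members of a dependent base are not found by unqualified name lookup
// inside a class template; a using-declaration re-introduces the name.
template <class Context> struct Base { int group = 1; };

template <class Context>
struct Derived : Base<Context> {
    using Base<Context>::group;            // without this, 'group' below
    int Doubled() { return group * 2; }    // would fail to compile
};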
...@@ -24,7 +24,7 @@ class Conv2dTransposeOp: public ConvOpBase<Context> { ...@@ -24,7 +24,7 @@ class Conv2dTransposeOp: public ConvOpBase<Context> {
this->num_spatial_axes = 2; this->num_spatial_axes = 2;
Setup(); Setup();
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context); USE_CONVOLUTION_FUNCTIONS(Context);
bool ReverseDimensions() override { return true; } bool ReverseDimensions() override { return true; }
...@@ -43,7 +43,7 @@ class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> { ...@@ -43,7 +43,7 @@ class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> {
public: public:
Conv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws) Conv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeOp<Context>(def, ws) {} : Conv2dTransposeOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context); USE_CONVOLUTION_FUNCTIONS(Context);
bool HasBias() override { return Output(2)->name() != "ignore"; } bool HasBias() override { return Output(2)->name() != "ignore"; }
...@@ -65,12 +65,12 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { ...@@ -65,12 +65,12 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
cudnn_group = 1; cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE(); enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else #else
cudnn_group = this->group; cudnn_group = group;
enable_tensor_core = false; enable_tensor_core = false;
#endif #endif
handle = new cudnnHandle_t[cudnn_group]; handle = new cudnnHandle_t[cudnn_group];
stream = new cudaStream_t[cudnn_group]; stream = new cudaStream_t[cudnn_group];
for (int g = 0; g < this->group; g++) { for (int g = 0; g < cudnn_group; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g])); CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g])); CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g])); CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
...@@ -80,11 +80,11 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { ...@@ -80,11 +80,11 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc)); if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW; if (data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC; else if (data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context); USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dTransposeOp() { ~CuDNNConv2dTransposeOp() {
...@@ -112,7 +112,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> { ...@@ -112,7 +112,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc; cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc; cudnnFilterDescriptor_t filter_desc;
size_t workspace_fwd_data_size; size_t fwd_data_size;
TIndex bias_offset, cudnn_group; TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims; vector<TIndex> input_dims;
bool enable_tensor_core; bool enable_tensor_core;
...@@ -127,7 +127,7 @@ public: ...@@ -127,7 +127,7 @@ public:
cudnn_group = 1; cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE(); enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else #else
cudnn_group = this->group; cudnn_group = group;
enable_tensor_core = false; enable_tensor_core = false;
#endif #endif
handle = new cudnnHandle_t[cudnn_group * 3]; handle = new cudnnHandle_t[cudnn_group * 3];
...@@ -142,11 +142,11 @@ public: ...@@ -142,11 +142,11 @@ public:
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc)); if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW; if (data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC; else if (data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context); USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dTransposeGradientOp() { ~CuDNNConv2dTransposeGradientOp() {
...@@ -175,7 +175,7 @@ public: ...@@ -175,7 +175,7 @@ public:
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc; cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc; cudnnFilterDescriptor_t filter_desc;
size_t workspace_bwd_filter_size, workspace_bwd_data_size; size_t bwd_filter_size, bwd_data_size;
TIndex bias_offset, cudnn_group; TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims; vector<TIndex> input_dims;
bool enable_tensor_core; bool enable_tensor_core;
......
...@@ -21,7 +21,7 @@ class DenseConcatOp final : public ConcatOp<Context> { ...@@ -21,7 +21,7 @@ class DenseConcatOp final : public ConcatOp<Context> {
public: public:
DenseConcatOp(const OperatorDef& op_def, Workspace* ws) DenseConcatOp(const OperatorDef& op_def, Workspace* ws)
: ConcatOp<Context>(op_def, ws) {} : ConcatOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
}; };
template <class Context> template <class Context>
...@@ -30,7 +30,7 @@ class DenseConcatGradientOp : public ConcatGradientOp<Context> { ...@@ -30,7 +30,7 @@ class DenseConcatGradientOp : public ConcatGradientOp<Context> {
DenseConcatGradientOp(const OperatorDef& op_def, Workspace* ws) DenseConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: ConcatGradientOp<Context>(op_def, ws), : ConcatGradientOp<Context>(op_def, ws),
growth_rate(OperatorBase::GetSingleArg<int>("growth_rate", 0)) {} growth_rate(OperatorBase::GetSingleArg<int>("growth_rate", 0)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void ElimateCorruption() override; void ElimateCorruption() override;
template <typename T> void RestoreX1(); template <typename T> void RestoreX1();
......
...@@ -29,7 +29,7 @@ class LRNOp : public Operator<Context> { ...@@ -29,7 +29,7 @@ class LRNOp : public Operator<Context> {
k(OperatorBase::GetSingleArg<float>("k", float(2.0))), k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")), mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -60,7 +60,7 @@ class LRNGradientOp : public Operator<Context> { ...@@ -60,7 +60,7 @@ class LRNGradientOp : public Operator<Context> {
k(OperatorBase::GetSingleArg<float>("k", float(2.0))), k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")), mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -97,7 +97,7 @@ class CuDNNLRNOp : public LRNOp<Context> { ...@@ -97,7 +97,7 @@ class CuDNNLRNOp : public LRNOp<Context> {
this->beta, this->beta,
this->k)); this->k));
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
~CuDNNLRNOp() { ~CuDNNLRNOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -126,7 +126,7 @@ class CuDNNLRNGradientOp : public LRNGradientOp<Context > { ...@@ -126,7 +126,7 @@ class CuDNNLRNGradientOp : public LRNGradientOp<Context > {
this->beta, this->beta,
this->k)); this->k));
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
~CuDNNLRNGradientOp() { ~CuDNNLRNGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
...@@ -30,7 +30,7 @@ class NNResizeOp : public Operator<Context> { ...@@ -30,7 +30,7 @@ class NNResizeOp : public Operator<Context> {
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -48,7 +48,7 @@ class NNResizeGradientOp : public Operator<Context> { ...@@ -48,7 +48,7 @@ class NNResizeGradientOp : public Operator<Context> {
NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws) NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -41,7 +41,7 @@ class Pooling2dOp: public Operator <Context> { ...@@ -41,7 +41,7 @@ class Pooling2dOp: public Operator <Context> {
} }
} }
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Reshape(); void Reshape();
void RunOnDevice() override; void RunOnDevice() override;
...@@ -81,7 +81,7 @@ class Pooling2dGradientOp: public Operator<Context> { ...@@ -81,7 +81,7 @@ class Pooling2dGradientOp: public Operator<Context> {
} }
} }
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void Reshape(); void Reshape();
void RunOnDevice() override; void RunOnDevice() override;
...@@ -116,7 +116,7 @@ class CuDNNPooling2dOp final : public Pooling2dOp<Context> { ...@@ -116,7 +116,7 @@ class CuDNNPooling2dOp final : public Pooling2dOp<Context> {
pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
} else LOG(FATAL) << "Unsupported pooling mode: " << this->mode; } else LOG(FATAL) << "Unsupported pooling mode: " << this->mode;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
~CuDNNPooling2dOp() { ~CuDNNPooling2dOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
...@@ -151,7 +151,7 @@ class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> { ...@@ -151,7 +151,7 @@ class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> {
pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
} else LOG(FATAL) << "Unsupported pooling mode: " << this->mode; } else LOG(FATAL) << "Unsupported pooling mode: " << this->mode;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
~CuDNNPooling2dGradientOp() { ~CuDNNPooling2dGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
...@@ -28,7 +28,7 @@ class ROIAlignOp : public Operator<Context> { ...@@ -28,7 +28,7 @@ class ROIAlignOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -50,7 +50,7 @@ class ROIAlignGradientOp : public Operator<Context> { ...@@ -50,7 +50,7 @@ class ROIAlignGradientOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -27,7 +27,7 @@ class ROIPoolingOp : public Operator<Context> { ...@@ -27,7 +27,7 @@ class ROIPoolingOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -35,7 +35,6 @@ class ROIPoolingOp : public Operator<Context> { ...@@ -35,7 +35,6 @@ class ROIPoolingOp : public Operator<Context> {
protected: protected:
int pool_h, pool_w; int pool_h, pool_w;
float spatial_scale; float spatial_scale;
Tensor* mask;
}; };
template <class Context> template <class Context>
...@@ -46,7 +45,7 @@ class ROIPoolingGradientOp final : public Operator<Context> { ...@@ -46,7 +45,7 @@ class ROIPoolingGradientOp final : public Operator<Context> {
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)), pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)), pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {} spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -54,7 +53,6 @@ class ROIPoolingGradientOp final : public Operator<Context> { ...@@ -54,7 +53,6 @@ class ROIPoolingGradientOp final : public Operator<Context> {
protected: protected:
int pool_h, pool_w; int pool_h, pool_w;
float spatial_scale; float spatial_scale;
Tensor* mask;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -33,16 +33,28 @@ namespace math { ...@@ -33,16 +33,28 @@ namespace math {
/******************** Level-0 ********************/ /******************** Level-0 ********************/
template <typename T, class Context> template <typename T, class Context>
void Set(const int n, const T alpha, T* x); void Set(
const int n,
const T alpha,
T* x);
template <typename T, class Context> template <typename T, class Context>
void RandomUniform(const int n, const float low, const float high, T *x); void RandomUniform(
const int n,
const float low,
const float high,
T* x);
template <typename T, class Context> template <typename T, class Context>
void RandomNormal(const int n, const float mu, const float sigma, T* x); void RandomNormal(
const int n,
const float mu,
const float sigma,
T* x);
template <typename T, class Context> template <typename T, class Context>
void RandomTruncatedNormal(const int n, void RandomTruncatedNormal(
const int n,
const float mu, const float mu,
const float sigma, const float sigma,
const float low, const float low,
...@@ -50,81 +62,153 @@ void RandomTruncatedNormal(const int n, ...@@ -50,81 +62,153 @@ void RandomTruncatedNormal(const int n,
T* x); T* x);
template <typename T, class Context> template <typename T, class Context>
void RandomBernoulli(const int n, const float p, uint32_t* x); void RandomBernoulli(
const int n,
const float p,
uint32_t* x);
/******************** Level-1 ********************/ /******************** Level-1 ********************/
template <typename T, class Context> template <typename T, class Context>
void Add(const int n, const T* a, const T* b, T* y); void Add(
const int n,
const T* a,
const T* b,
T* y);
template <typename T, class Context> template <typename T, class Context>
void Sub(const int n, const T* a, const T* b, T* y); void Sub(
const int n,
const T* a,
const T* b,
T* y);
template <typename T, class Context> template <typename T, class Context>
void Mul(const int n, const T* a, const T* b, T* y); void Mul(
const int n,
const T* a,
const T* b,
T* y);
template <typename T, class Context> template <typename T, class Context>
void Div(const int n, const T* a, const T* b, T* y); void Div(
const int n,
const T* a,
const T* b,
T* y);
template <typename T, class Context> template <typename T, class Context>
void Clip(const int n, const float low, const float high, T* x); void Clip(
const int n,
const float low,
const float high,
T* x);
template <typename T, class Context> template <typename T, class Context>
void Exp(const int n, const T* x, T* y); void Exp(
const int n,
const T* x,
T* y);
template <typename T, class Context> template <typename T, class Context>
void Log(const int n, const T* x, T* y); void Log(
const int n,
const T* x,
T* y);
template <typename T, class Context> template <typename T, class Context>
void Square(const int n, const T* x, T* y); void Square(
const int n,
const T* x,
T* y);
template <typename T, class Context> template <typename T, class Context>
void Sqrt(const int n, const T* x, T* y); void Sqrt(
const int n,
const T* x,
T* y);
template <typename T, class Context> template <typename T, class Context>
void Pow(const int n, const float alpha, const T* x, T* y); void Pow(
const int n,
const float alpha,
const T* x,
T* y);
template <typename T, class Context> template <typename T, class Context>
void Inv(const int n, const float numerator, const T* x, T* y); void Inv(
const int n,
const float numerator,
const T* x,
T* y);
/******************** Level-2 ********************/ /******************** Level-2 ********************/
template <typename T, class Context> template <typename T, class Context>
void Scal(const int n, const float alpha, T* y); void Scal(
const int n,
const float alpha,
T* y);
template <typename T, class Context> template <typename T, class Context>
void Scale(const int n, const float alpha, const T* x, T* y); void Scale(
const int n,
const float alpha,
const T* x,
T* y);
template <typename T, class Context> template <typename T, class Context>
T StridedDot(const int n, T StridedDot(
const int n,
const T* a, const T* a,
const int incx, const int incx,
const T* b, const T* b,
const int incy); const int incy);
template <typename T, class Context> template <typename T, class Context>
float Dot(const int n, const T* a, const T* b); float Dot(
const int n,
const T* a,
const T* b);
template<typename T, class Context> template<typename T, class Context>
float ASum(const int n, const T *x); float ASum(
const int n,
const T* x);
template<typename T, class Context> template<typename T, class Context>
void AddScalar(const int n, const float alpha, T* y); void AddScalar(
const int n,
const float alpha,
T* y);
template<typename T, class Context> template<typename T, class Context>
void MulScalar(const int n, const float alpha, T* y); void MulScalar(
const int n,
const float alpha,
T* y);
template<typename T, class Context> template<typename T, class Context>
void Axpy(const int n, float alpha, const T* x, T *y); void Axpy(
const int n,
float alpha,
const T* x,
T* y);
template<typename T, class Context> template<typename T, class Context>
void Axpby(const int n, float alpha, const T* x, float beta, T *y); void Axpby(
const int n,
float alpha,
const T* x,
float beta,
T* y);
/******************** Level-3 ********************/ /******************** Level-3 ********************/
template <typename T, class Context> template <typename T, class Context>
void Gemm(const CBLAS_TRANSPOSE transA, void Gemm(
const CBLAS_TRANSPOSE transB, const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int M, const int M,
const int N, const int N,
const int K, const int K,
...@@ -136,7 +220,8 @@ void Gemm(const CBLAS_TRANSPOSE transA, ...@@ -136,7 +220,8 @@ void Gemm(const CBLAS_TRANSPOSE transA,
TensorProto_DataType math_type = TensorProto_DataType_FLOAT); TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
template<typename T, class Context> template<typename T, class Context>
void Gemv(const CBLAS_TRANSPOSE transA, void Gemv(
const CBLAS_TRANSPOSE TransA,
const int M, const int M,
const int N, const int N,
const float alpha, const float alpha,
......
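These declarations follow BLAS naming; Axpby, for instance, computes y = alpha * x + beta * y element-wise. A minimal CPU-side usage sketch (assuming this header lives at utils/math_functions.h and the CPUContext specialization is linked in):
#include "utils/math_functions.h"

void AxpbyExample() {
    float x[4] = { 1.f, 2.f, 3.f, 4.f };
    float y[4] = { 10.f, 10.f, 10.f, 10.f };
    // y <- 2 * x + 1 * y  =>  { 12, 14, 16, 18 }
    dragon::math::Axpby<float, dragon::CPUContext>(4, 2.f, x, 1.f, y);
}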
...@@ -35,11 +35,13 @@ inline std::vector<std::string> SplitString(const std::string& str, ...@@ -35,11 +35,13 @@ inline std::vector<std::string> SplitString(const std::string& str,
return ret; return ret;
} }
template<> inline std::string dragon_cast<std::string, int>(int val) { #define DEFINE_NUMBER2STRING(T) \
std::stringstream ss; template<> inline std::string dragon_cast<std::string, T>(T val) { \
ss << val; std::stringstream ss; ss << val; return ss.str(); \
return ss.str(); }
}
DEFINE_NUMBER2STRING(int);
DEFINE_NUMBER2STRING(unsigned long long);
template<> inline int dragon_cast<int, std::string>(std::string val) { template<> inline int dragon_cast<int, std::string>(std::string val) {
return atoi(val.c_str()); return atoi(val.c_str());
......
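The new macro stamps out number-to-string specializations of dragon_cast; DEFINE_NUMBER2STRING(unsigned long long), presumably added for the RNN code's 64-bit random seed, expands (reformatted) to:
// Expansion of DEFINE_NUMBER2STRING(unsigned long long), reformatted:
template<> inline std::string
dragon_cast<std::string, unsigned long long>(unsigned long long val) {
    std::stringstream ss;
    ss << val;
    return ss.str();
}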
...@@ -24,7 +24,7 @@ Workspace* CreateWorkspace(const std::string& name){ ...@@ -24,7 +24,7 @@ Workspace* CreateWorkspace(const std::string& name){
unique_ptr<Workspace> new_workspace(new Workspace(name)); unique_ptr<Workspace> new_workspace(new Workspace(name));
g_workspaces[name] = std::move(new_workspace); g_workspaces[name] = std::move(new_workspace);
sub_workspaces[name] = vector<string>(); sub_workspaces[name] = vector<string>();
return new_workspace.get(); return g_workspaces[name].get();
} }
Workspace* ResetWorkspace(const std::string& name) { Workspace* ResetWorkspace(const std::string& name) {
......
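The one-line change in CreateWorkspace fixes a use-after-move: once new_workspace has been moved into g_workspaces, the local unique_ptr holds nullptr, so the old return new_workspace.get() always returned a null pointer. In miniature:
// Miniature reproduction of the bug this hunk fixes.
#include <map>
#include <memory>
#include <string>

int main() {
    std::unique_ptr<int> p(new int(7));
    std::map<std::string, std::unique_ptr<int>> owners;
    owners["w"] = std::move(p);
    // p.get() == nullptr here; owners["w"].get() is the live pointer.
    return owners["w"] ? 0 : 1;
}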
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.