Commit 081258b9 by Ting PAN

Preliminary RNN & LSTM & GRU Support

1 parent fe161546
Showing with 990 additions and 575 deletions
......@@ -11,7 +11,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libnuma-dev \
libprotobuf-dev \
protobuf-compiler \
libopenblas-dev \
libopenblas-dev \
python3-pip \
python3-dev \
python3-pyqt4 \
......@@ -21,13 +21,14 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
RUN pip3 install --no-cache-dir --upgrade setuptools wheel && \
pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
numpy \
protobuf \
lmdb \
protobuf \
lmdb \
opencv-python \
six \
Pillow \
matplotlib \
pyyaml
pyyaml \
cython
RUN wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cpu-openblas/3rdparty.zip && \
unzip 3rdparty.zip && rm 3rdparty.zip && cd 3rdparty && bash ./setup_mpi.sh && rm -f *.gz && \
......
FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
RUN apt-get update && apt-get install -y --no-install-recommends \
RUN rm /etc/apt/sources.list.d/cuda.list && rm /etc/apt/sources.list.d/nvidia-ml.list && \
apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
......@@ -11,8 +12,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libnuma-dev \
libprotobuf-dev \
protobuf-compiler \
libopenblas-dev \
libnccl2 \
libopenblas-dev \
libnccl2 \
libnccl-dev \
python3-pip \
python3-dev \
......@@ -29,7 +30,8 @@ RUN pip3 install --no-cache-dir --upgrade setuptools wheel && \
six \
Pillow \
matplotlib \
pyyaml
pyyaml \
cython
RUN wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cuda9.0-cudnn7/3rdparty.zip && \
unzip 3rdparty.zip && rm 3rdparty.zip && cd 3rdparty && bash ./setup_mpi.sh && rm -f *.gz && \
......
......@@ -171,7 +171,7 @@ endif()
# ---[ Flags
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}")
if(WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /O2 /Oi /GL /Ot /Gy")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /O2")
if (WITH_OMP)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp")
endif()
......
......@@ -39,7 +39,6 @@ using std::vector;
using std::pair;
using std::set;
using std::map;
using std::mutex;
using std::unique_ptr;
using std::shared_ptr;
......@@ -49,7 +48,7 @@ using Map = std::unordered_map<Key, Value>;
template <typename Value>
using Set = std::unordered_set<Value> ;
#define DRAGON_VERSION 2204
#define DRAGON_VERSION 2205
#define CONCATENATE_IMPL(s1, s2) s1##s2
#define CONCATENATE(s1, s2) CONCATENATE_IMPL(s1,s2)
......
......@@ -84,6 +84,9 @@ static inline std::mt19937* rand_generator() {
return CPUContext::cpu_object_.rand_generator.get();
}
// Expands to a LOG(FATAL) statement reporting that FP16 is not available
// on CPUContext (LOG(FATAL) is presumably the logging macro that aborts;
// its definition is not visible here).
// The expansion already ends with ';', so a use site that appends its own
// ';' produces an additional empty statement.
#define CPU_FP16_NOT_SUPPORTED \
LOG(FATAL) << "FP16 is unsupported for CPUContext.";
} // namespace dragon
#endif // DRAGON_CORE_CONTEXT_H_
\ No newline at end of file
......@@ -134,7 +134,8 @@ class CUDAContext {
DeviceGuard gurad(gpu_id_);
CUBLAS_CHECK(cublasCreate_v2(&handle));
#if CUDA_VERSION >= 9000
if (TENSOR_CORE_AVAILABLE()) cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
if (TENSOR_CORE_AVAILABLE())
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
#endif
return handle;
}
......@@ -165,6 +166,8 @@ class CUDAContext {
}
#endif
// Returns a reference to one function-local static std::mutex; every call
// yields the same object, which is constructed on first use.
static std::mutex& mutex() {
    static std::mutex shared_lock;
    return shared_lock;
}
static CUDAObject cuda_object_;
private:
......
......@@ -154,18 +154,19 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
using OperatorBase::InputSize; \
using OperatorBase::OutputSize; \
using OperatorBase::DebugString; \
using OperatorBase::DTypeHelper \
using OperatorBase::DTypeHelper; \
using OperatorBase::SwitchToPhase
#define USE_OPERATOR_FUNCTIONS(context) \
#define USE_OPERATOR_FUNCTIONS \
USE_OPERATOR_BASE_FUNCTIONS; \
using Operator<context>::ctx; \
using Operator<context>::AllowRun
using Operator<Context>::ctx; \
using Operator<Context>::AllowRun
DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*);
DECLARE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
#define TENSOR_FILL(tensor, shape) \
#define TENSOR_FILL(tensor, shape) \
if (tensor.count() == 0) { \
CHECK(ws()->GetFiller(tensor.name())) \
<< "\nTensor(" << tensor.name() << ") is empty. \n" \
......@@ -189,11 +190,19 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
ptr_tensor = ws()->CreateTensor("/share/multiplier"); \
if (size > ptr_tensor->count()) { \
ptr_tensor->Reshape(vector<TIndex>(1, size)); \
math::Set<T, Context>(size, dragon_cast<T, float>(1.0f), \
math::Set<T, Context>(size, dragon_cast<T, float>(1.f), \
ptr_tensor->template mutable_data<T, Context>()); \
} \
}
// Declares `const T* name` in the enclosing scope and points it at the
// data of the tensor prepared by INIT_MULTIPLIER(..., size).
// The temporary Tensor* lives inside its own braces, so only `name`
// leaks into the caller's scope and the macro can be used repeatedly.
// Requires `T` and `Context` to be visible at the expansion site.
// NOTE(review): INIT_MULTIPLIER appears to hand back a shared buffer
// ("/share/multiplier") filled with ones of at least `size` elements --
// confirm against its full definition; if so, `name` may be invalidated
// by a later expansion that grows that buffer.
#define DECLARE_MULTIPLIER(name, size) \
const T* name; \
{ \
Tensor* _auto_multiplier_; \
INIT_MULTIPLIER(_auto_multiplier_, size); \
name = _auto_multiplier_->template data<T, Context>(); \
}
#define DECLARE_ARGUMENT_WITH_DESC(type, argument) \
type argument##_value; \
string argument##_desc; \
......
......@@ -23,14 +23,15 @@ typedef size_t TSize;
class Tensor {
public:
Tensor() {}
Tensor(const vector<TIndex>& dims) { Reshape(dims); }
Tensor(const string& name) : name_(name) {}
void Reshape(const vector<TIndex>& dims) {
dims_ = dims;
TIndex new_size = 1;
for (auto d : dims_) {
CHECK_GT(d, 0);
new_size *= d;
CHECK_GE(d, 0);
if (d > 0) new_size *= d;
}
if (own_mem_) {
if (size_ != new_size &&
......
......@@ -18,26 +18,27 @@
namespace dragon {
#define WORKSPACE_COMMON_BUFFER_SIZE 2
#define WORKSPACE_MAX_CORRUPTED_SIZE 2
class Workspace {
public:
typedef Map<string, Workspace*> WorkspaceMap;
typedef Map<string, unique_ptr<Tensor> > TensorMap;
typedef Map<string, stack<string> > BufferMap;
typedef Map<string, unique_ptr<mutex> > LockMap;
typedef Map<string, unique_ptr<OperatorBase> > OperatorMap;
typedef Map<string, unique_ptr<GraphBase> > GraphMap;
typedef Map<string, TensorFiller> FillerMap;
typedef Map<string, string> RenameMap;
Workspace(const string& name) : name_(name) { Init(); }
Workspace(const string& name) : name_(name) { InitWorkspace(); }
~Workspace();
void Init() {
inline const string& name() { return name_; }
/******************** Workspace ********************/
inline void InitWorkspace() {
CreateTensor("ignore");
CreateBuffer("Common", WORKSPACE_COMMON_BUFFER_SIZE);
Tensor* head = CreateTensor("/opt/mirror_stage/head");
head->Reshape(vector<TIndex>(1, WORKSPACE_MAX_CORRUPTED_SIZE));
Tensor* recompute_flag = CreateTensor("/opt/mirror_stage/recompute_flag");
......@@ -50,10 +51,6 @@ class Workspace {
}
}
inline const string& name() { return name_; }
/******************** Workspace ********************/
inline Workspace* MoveWorkspace(Workspace* ws) {
CHECK(ws) << "The given Workspace is invalid.";
if (workspace_map_.count(ws->name()))
......@@ -62,11 +59,9 @@ class Workspace {
}
inline void ClearWorkspace() {
// clear tensors & buffers
// clear tensors & buffers & re-initialization
for (auto& kv : tensor_map_) kv.second->Reset();
ResetBuffers("Common");
// Re-Initialization
Init();
InitWorkspace();
}
/******************** Tensor ********************/
......@@ -112,20 +107,6 @@ class Workspace {
return tensor;
}
inline void LockTensor(const string& name) {
string query = GetTensorName(name);
if (!lock_map_.count(query))
lock_map_[query] = unique_ptr<mutex>(new mutex);
lock_map_[query]->lock();
}
inline void UnlockTensor(const string& name) {
string query = GetTensorName(name);
if (!lock_map_.count(query))
lock_map_[query] = unique_ptr<mutex>(new mutex);
lock_map_[query]->unlock();
}
inline void ResetTensor(const string& name) {
Tensor* tensor = TryGetTensor(name, false);
CHECK(tensor) << "\nTensor(" << name << ") does not "
......@@ -179,49 +160,32 @@ class Workspace {
return nullptr;
}
/******************** Buffer ********************/
inline void CreateBuffer(string category, int num) {
if (!buffer_map_.count(category))
buffer_map_[category] = stack<string>();
for (int i = 1; i <= num; i++) {
string name = "/share/buffer/" + category + "_" + dragon_cast<string, int>(i);
buffer_map_[category].push(name);
CreateTensor(name);
}
/******************** Cache ********************/
/*!
 * Carve the shared "/share/cache" tensor into consecutive byte segments.
 *
 * \param segments  Size of each requested segment, in bytes.
 * \return One pointer per segment; pointer i starts where segment i-1 ends.
 *         An empty request returns an empty vector and leaves the cache
 *         tensor untouched (previously this wrote to caches[0] of an
 *         empty vector).
 *
 * The tensor is reshaped to the total size on every call, so pointers
 * obtained from an earlier call should not be assumed to stay valid.
 */
template <class Context>
inline vector<void*> caches(const vector<size_t>& segments) {
    vector<void*> ptrs(segments.size());
    if (segments.empty()) return ptrs;
    TIndex total_size = 0;
    for (const auto& segment : segments) total_size += (TIndex)segment;
    Tensor* cacheT = CreateTensor("/share/cache");
    cacheT->Reshape(vector<TIndex>(1, total_size));
    ptrs[0] = cacheT->template mutable_data<uint8_t, Context>();
    // size_t index: avoids the signed/unsigned comparison against size()
    for (size_t i = 1; i < segments.size(); i++)
        ptrs[i] = (uint8_t*)ptrs[i - 1] + segments[i - 1];
    return ptrs;
}
inline Tensor* GetBuffer(string category = "Common") {
if (!buffer_map_[category].empty()) {
string name = buffer_map_[category].top();
buffer_map_[category].pop();
return tensor_map_[name].get();
}
LOG(FATAL) << "Buffers of [" << category << "] "
<< "are not enough, add more if necessary.";
return nullptr;
}
inline void ReleaseBuffer(Tensor* tensor,
string category = "Common",
bool enforce = false) {
static Map<string, int> limits = {
{ "Common", WORKSPACE_COMMON_BUFFER_SIZE }};
if (buffer_map_[category].size() >= limits[category] || enforce) {
ResetTensor(tensor->name());
if (buffer_map_[category].empty())
buffer_map_[category].push(tensor->name());
} else {
buffer_map_[category].push(tensor->name());
}
}
inline void ResetBuffers(string category) {
while (!buffer_map_[category].empty()) {
string name = buffer_map_[category].top();
buffer_map_[category].pop();
tensor_map_[name]->Reset();
}
/*!
 * Typed variant: carve the shared "/share/cache" tensor into consecutive
 * segments of T.
 *
 * \param segments  Element count (not bytes) of each requested segment.
 * \return One T* per segment; pointer i starts where segment i-1 ends.
 *         An empty request returns an empty vector and leaves the cache
 *         tensor untouched (previously this wrote to caches[0] of an
 *         empty vector).
 *
 * The tensor is reshaped to the total count on every call, so pointers
 * obtained from an earlier call should not be assumed to stay valid.
 */
template <typename T, class Context>
inline vector<T*> caches(const vector<TIndex>& segments) {
    vector<T*> ptrs(segments.size());
    if (segments.empty()) return ptrs;
    TIndex total_count = 0;
    for (const auto& segment : segments) total_count += segment;
    Tensor* cacheT = CreateTensor("/share/cache");
    cacheT->Reshape(vector<TIndex>(1, total_count));
    ptrs[0] = cacheT->template mutable_data<T, Context>();
    // size_t index: avoids the signed/unsigned comparison against size()
    for (size_t i = 1; i < segments.size(); i++)
        ptrs[i] = ptrs[i - 1] + segments[i - 1];
    return ptrs;
}
/******************** Operator ********************/
......@@ -297,8 +261,6 @@ class Workspace {
string name_;
WorkspaceMap workspace_map_;
TensorMap tensor_map_;
BufferMap buffer_map_;
LockMap lock_map_;
OperatorMap op_map_;
GraphMap graph_map_;
FillerMap filler_map_;
......
......@@ -24,8 +24,9 @@ class DropoutOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -42,8 +43,9 @@ class DropoutGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -57,6 +59,79 @@ class DropoutGradientOp final : public Operator<Context> {
DEFINE_ARGUMENT_WITH_DESC(float, DropoutOp, prob);
DEFINE_ARGUMENT_WITH_DESC(float, DropoutGradientOp, prob);
#ifdef WITH_CUDNN
#if CUDNN_VERSION_MIN(7, 0, 0)
/*!
 * Dropout operator that holds a cuDNN tensor descriptor and a cuDNN
 * dropout descriptor for its whole lifetime.
 *
 * Arguments read at construction:
 *   - "scale" (bool, default true)  -> use_scale
 *   - "prob"  (float, default 0.5)  -> via GET_ARGUMENT_WITH_DESC
 *   - "phase" (string, default "")  -> forwarded to SwitchToPhase
 * The random seed is taken from op_def.device_option().random_seed().
 *
 * The computation itself lives in RunOnDevice()/RunWithType<T>(), which
 * are only declared here.
 */
template <class Context>
class CuDNNDropoutOp final : public Operator<Context> {
 public:
    CuDNNDropoutOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          // listed in declaration order, which is the order C++ applies
          use_scale(OperatorBase::GetSingleArg<bool>("scale", true)),
          states_initialized(false),
          // sizes start at zero instead of holding indeterminate values
          states_size(0), reserve_space_size(0),
          random_seed(op_def.device_option().random_seed()) {
        GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
        SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
        CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
    }
    USE_OPERATOR_FUNCTIONS;

    // Releases the two descriptors created by the constructor.
    ~CuDNNDropoutOp() {
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
        CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc));
    }

    void RunOnDevice() override;
    template <typename T> void RunWithType();

 protected:
    DECLARE_ARGUMENT_WITH_DESC(float, prob);
    // states_initialized starts false; presumably set once the dropout
    // states have been prepared (implementation not visible here).
    bool use_scale, states_initialized;
    cudnnTensorDescriptor_t input_desc;
    cudnnDropoutDescriptor_t dropout_desc;
    size_t states_size, reserve_space_size;
    unsigned long long random_seed;
};
/*!
 * Gradient counterpart of the cuDNN dropout operator. Holds a cuDNN
 * tensor descriptor and a cuDNN dropout descriptor for its whole lifetime.
 *
 * Arguments read at construction:
 *   - "scale" (bool, default true)  -> use_scale
 *   - "prob"  (float, default 0.5)  -> via GET_ARGUMENT_WITH_DESC
 *   - "phase" (string, default "")  -> forwarded to SwitchToPhase
 * The random seed is taken from op_def.device_option().random_seed().
 *
 * The computation itself lives in RunOnDevice()/RunWithType<T>(), which
 * are only declared here.
 */
template <class Context>
class CuDNNDropoutGradientOp final : public Operator<Context> {
 public:
    CuDNNDropoutGradientOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          // listed in declaration order, which is the order C++ applies
          use_scale(OperatorBase::GetSingleArg<bool>("scale", true)),
          states_initialized(false),
          // sizes start at zero instead of holding indeterminate values
          states_size(0), reserve_space_size(0),
          random_seed(op_def.device_option().random_seed()) {
        GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
        SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
        CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
    }
    USE_OPERATOR_FUNCTIONS;

    // Releases the two descriptors created by the constructor.
    ~CuDNNDropoutGradientOp() {
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
        CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc));
    }

    void RunOnDevice() override;
    template <typename T> void RunWithType();

 protected:
    DECLARE_ARGUMENT_WITH_DESC(float, prob);
    // states_initialized starts false; presumably set once the dropout
    // states have been prepared (implementation not visible here).
    bool use_scale, states_initialized;
    cudnnTensorDescriptor_t input_desc;
    cudnnDropoutDescriptor_t dropout_desc;
    size_t states_size, reserve_space_size;
    unsigned long long random_seed;
};
DEFINE_ARGUMENT_WITH_DESC(float, CuDNNDropoutOp, prob);
DEFINE_ARGUMENT_WITH_DESC(float, CuDNNDropoutGradientOp, prob);
#endif
#endif // WITH_CUDNN
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
\ No newline at end of file
......@@ -22,7 +22,7 @@ class EluOp : public Operator<Context> {
EluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -37,7 +37,7 @@ class EluGradientOp : public Operator<Context> {
EluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -61,7 +61,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNEluOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -88,7 +88,7 @@ class CuDNNEluGradientOp final : public EluGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_ELU, CUDNN_PROPAGATE_NAN, this->alpha));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNEluGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
......@@ -23,15 +23,14 @@ class PReluOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
bool channel_shared;
TIndex channel_shared, channels, dim;
string data_format;
TIndex channels, dim;
};
template <class Context>
......@@ -41,16 +40,14 @@ class PReluGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
bool channel_shared;
TIndex channel_shared, channels, dim;
string data_format;
TIndex channels, dim;
Tensor* bcast_dw, *multiplier;
};
} // namespace dragon
......
......@@ -22,7 +22,7 @@ class ReluOp : public Operator<Context> {
ReluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -37,7 +37,7 @@ class ReluGradientOp : public Operator<Context> {
ReluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -59,7 +59,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNReluOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -86,7 +86,7 @@ class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNReluGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
......@@ -20,7 +20,7 @@ template <class Context>
class SEluOp : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SEluOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -30,7 +30,7 @@ template <class Context>
class SEluGradientOp : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SEluGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -20,7 +20,7 @@ template <class Context>
class SigmoidOp : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SigmoidOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -30,7 +30,7 @@ template <class Context>
class SigmoidGradientOp : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SigmoidGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -49,7 +49,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNSigmoidOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -76,7 +76,7 @@ class CuDNNSigmoidGradientOp final : public SigmoidGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNSigmoidGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
......@@ -22,15 +22,13 @@ class SoftmaxOp final : public Operator<Context> {
SoftmaxOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
int axis;
TIndex outer_dim, inner_dim;
Tensor* sum_multiplier, *scale;
TIndex axis, outer_dim, inner_dim;
};
template <class Context>
......@@ -39,15 +37,13 @@ class SoftmaxGradientOp final : public Operator<Context> {
SoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
int axis;
TIndex outer_dim, inner_dim;
Tensor* sum_multiplier, *scale;
TIndex axis, outer_dim, inner_dim;
};
#ifdef WITH_CUDNN
......@@ -63,7 +59,7 @@ class CuDNNSoftmaxOp final : public Operator<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNSoftmaxOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -88,7 +84,7 @@ class CuDNNSoftmaxGradientOp final : public Operator<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNSoftmaxGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
......@@ -20,7 +20,7 @@ template <class Context>
class TanhOp : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(TanhOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -30,7 +30,7 @@ template <class Context>
class TanhGradientOp : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(TanhGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -49,7 +49,7 @@ public:
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNTanhOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -76,7 +76,7 @@ class CuDNNTanhGradientOp final : public TanhGradientOp<Context> {
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNTanhGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
......@@ -20,56 +20,44 @@ template <class Context>
class AddOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(AddOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class AddGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(AddGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class RAddOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RAddOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class RAddGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RAddGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
} // namespace dragon
......
......@@ -23,7 +23,7 @@ class AffineOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -31,7 +31,6 @@ class AffineOp : public Operator<Context> {
protected:
TIndex axis, start_axis, num_axes;
TIndex outer_dim, scale_dim, inner_dim;
Tensor* bias_multiplier;
};
template <class Context>
......@@ -41,7 +40,7 @@ class AffineGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void BiasRunWithType();
......@@ -51,10 +50,109 @@ class AffineGradientOp final : public Operator<Context> {
protected:
TIndex axis, start_axis, num_axes;
TIndex outer_dim, inner_dim, scale_dim, sum_dim, dim;
Tensor* bias_multiplier, *sum_multiplier;
Tensor sum_result;
};
#ifdef WITH_CUDNN
#include "utils/cudnn_device.h"
/*!
 * Shared base for the cuDNN affine forward/gradient operators.
 *
 * Owns five cuDNN descriptors (input, param, mul, add, reduce) that are
 * created in the constructor and released in the destructor, and offers
 * ResetDesc<T>() to (re)configure the input/param tensor descriptors
 * from the shape of Input(0).
 *
 * Arguments read at construction:
 *   - "axis"     (int, default 1)
 *   - "num_axes" (int, default -1: extend to the last axis)
 */
template <class Context>
class CuDNNAffineOpBase : public Operator<Context> {
 public:
    CuDNNAffineOpBase(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          axis(OperatorBase::GetSingleArg<int>("axis", 1)),
          num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc));
        CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&mul_desc));
        CUDNN_CHECK(cudnnCreateOpTensorDescriptor(&add_desc));
        CUDNN_CHECK(cudnnCreateReduceTensorDescriptor(&reduce_desc));
    }
    USE_OPERATOR_FUNCTIONS;

    // Releases every descriptor created by the constructor.
    virtual ~CuDNNAffineOpBase() {
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(param_desc));
        CUDNN_CHECK(cudnnDestroyOpTensorDescriptor(mul_desc));
        // add_desc was created above but previously never destroyed
        CUDNN_CHECK(cudnnDestroyOpTensorDescriptor(add_desc));
        CUDNN_CHECK(cudnnDestroyReduceTensorDescriptor(reduce_desc));
    }

    /*!
     * Recompute start_axis/end_axis from Input(0) and configure
     * input_desc and param_desc for element type T.
     *
     * Note: a num_axes of -1 (or 0) is overwritten with the resolved
     * value, so later calls reuse that resolved value.
     */
    template <typename T>
    void ResetDesc() {
        // determine the range of affine
        start_axis = axis;
        if (start_axis < 0) start_axis += (int)Input(0).ndim();
        if (num_axes == -1) num_axes = (int)Input(0).ndim() - start_axis;
        else if (num_axes == 0) num_axes = 1;
        end_axis = start_axis + num_axes;
        // a still-negative axis would index param_dims out of bounds below
        CHECK_LE(0, start_axis);
        CHECK_LT(start_axis, (int)Input(0).ndim());
        CHECK_LE(start_axis + num_axes, (int)Input(0).ndim());
        // determine the input desc
        vector<TIndex> input_dims = Input(0).dims();
        // pad to 4 dimensions with trailing 1s; more than 5 is rejected
        if (input_dims.size() < 4) input_dims.resize(4, 1);
        else if (input_dims.size() > 5)
            LOG(FATAL) << "CuDNN Affine supports the dimensions up to 5.";
        cudnnSetTensorDesc<T>(&input_desc, input_dims);
        // determine the scale desc: 1 everywhere except the affine axes
        vector<TIndex> param_dims(input_dims.size(), 1);
        for (int i = start_axis; i < end_axis; i++) param_dims[i] = input_dims[i];
        cudnnSetTensorDesc<T>(&param_desc, param_dims);
    }

    TIndex axis, start_axis, end_axis, num_axes;
    cudnnTensorDescriptor_t input_desc, param_desc;
    cudnnOpTensorDescriptor_t mul_desc, add_desc;
    cudnnReduceTensorDescriptor_t reduce_desc;
};
// Using-declarations for classes derived from CuDNNAffineOpBase<Context>:
// members of a dependent base class are not found by unqualified lookup,
// so each one is named explicitly. Also expands USE_OPERATOR_FUNCTIONS.
// `axis` and `end_axis` are not included in the list.
// The expansion has no trailing ';' -- the use site supplies it.
// NOTE(review): the macro name misspells "FUNCTIONS"; it is kept as-is
// because the derived classes refer to it by this spelling.
#define USE_CUDNN_AFFINE_FUCNTIONS \
USE_OPERATOR_FUNCTIONS; \
using CuDNNAffineOpBase<Context>::start_axis; \
using CuDNNAffineOpBase<Context>::num_axes; \
using CuDNNAffineOpBase<Context>::input_desc; \
using CuDNNAffineOpBase<Context>::param_desc; \
using CuDNNAffineOpBase<Context>::mul_desc; \
using CuDNNAffineOpBase<Context>::add_desc; \
using CuDNNAffineOpBase<Context>::reduce_desc
// Forward affine operator built on CuDNNAffineOpBase; the constructor only
// forwards to the base, which reads the arguments and owns the descriptors.
// RunOnDevice()/RunWithType<T>() are declared here and defined elsewhere.
template <class Context>
class CuDNNAffineOp : public CuDNNAffineOpBase<Context> {
public:
CuDNNAffineOp(const OperatorDef& op_def, Workspace* ws)
: CuDNNAffineOpBase<Context>(op_def, ws) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
// brings the base-class axes and descriptors into scope
USE_CUDNN_AFFINE_FUCNTIONS;
};
// Gradient affine operator built on CuDNNAffineOpBase; the constructor only
// forwards to the base, which reads the arguments and owns the descriptors.
// All compute methods are declared here and defined elsewhere.
template <class Context>
class CuDNNAffineGradientOp : public CuDNNAffineOpBase<Context> {
public:
CuDNNAffineGradientOp(const OperatorDef& op_def, Workspace* ws)
: CuDNNAffineOpBase<Context>(op_def, ws) {}
void RunOnDevice() override;
// Two variants each for the scale (dA) and bias (dB) gradients.
// NOTE(review): how the plain and _v2 variants differ, and which one is
// selected when, is not visible from this header -- confirm in the source.
template <typename T> void ComputeScaleGradient(T* dYxX, T* dA);
template <typename T> void ComputeScaleGradient_v2(T* dYxX, T* dA);
template <typename T> void ComputeBiasGradient(const T* dY, T* dB);
template <typename T> void ComputeBiasGradient_v2(const T* dY, T* dB);
template <typename T> void RunWithType();
protected:
// brings the base-class axes and descriptors into scope
USE_CUDNN_AFFINE_FUCNTIONS;
// dimension bookkeeping and a tensor owned by this operator; presumably
// scratch state for the reductions above -- TODO confirm
TIndex outer_dim, inner_dim, scale_dim, sum_dim, dim;
Tensor sum_result;
};
#endif // WITH_CUDNN
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_AFFINE_OP_H_
\ No newline at end of file
......@@ -24,27 +24,23 @@ class ClipOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
low(OperatorBase::GetSingleArg<float>("low", -FLT_MAX)),
high(OperatorBase::GetSingleArg<float>("high", FLT_MAX)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float low, high;
Tensor* mask;
};
template <class Context>
class ClipGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ClipGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
Tensor* mask;
};
} // namespace dragon
......
......@@ -20,56 +20,44 @@ template <class Context>
class DivOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(DivOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class DivGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(DivGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class RDivOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RDivOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class RDivGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RDivGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
} // namespace dragon
......
......@@ -19,11 +19,11 @@ namespace dragon {
template <class Context>
class DotOp final : public Operator<Context> {
public:
DotOp(const OperatorDef& op_def, Workspace* ws)
DotOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void DotRunWithType();
......@@ -31,18 +31,17 @@ class DotOp final : public Operator<Context> {
template <typename T> void GemvRunWithType();
protected:
bool transA, transB;
TIndex M, K1, K2, N1, N2;
TIndex TransA, TransB, M, K1, K2, N1, N2;
};
template <class Context>
class DotGradientOp final : public Operator<Context> {
public:
DotGradientOp(const OperatorDef& op_def, Workspace* ws)
DotGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void DotRunWithType();
......@@ -50,8 +49,7 @@ class DotGradientOp final : public Operator<Context> {
template <typename T> void GemvRunWithType();
protected:
bool transA, transB;
TIndex M, K1, K2, N1, N2;
TIndex TransA, TransB, M, K1, K2, N1, N2;
};
} // namespace dragon
......
......@@ -29,7 +29,7 @@ class EltwiseOp final : public Operator<Context> {
<< "but provided " << coeffs.size() << " coeffs.";
} else coeffs.resize(InputSize(), float(1));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void SumRunWithType();
......@@ -53,7 +53,7 @@ class EltwiseGradientOp final : public Operator<Context> {
<< "but provided " << coeffs.size() << " coeffs.";
} else coeffs.resize(InputSize(), float(1));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void SumRunWithType();
......
......@@ -20,7 +20,7 @@ template <class Context>
class ExpOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ExpOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -30,7 +30,7 @@ template <class Context>
class ExpGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ExpGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -22,7 +22,7 @@ class GramMatrixOp final : public Operator<Context> {
GramMatrixOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -38,7 +38,7 @@ class GramMatrixGradientOp final : public Operator<Context> {
GramMatrixGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -23,17 +23,15 @@ class InnerProductOp: public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS(Context);
TransW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice();
template <typename T> void TransRunWithType();
template <typename T> void NoTransRunWithType();
protected:
TIndex axis, num_output, M, K;
bool transW;
Tensor* bias_multiplier;
TIndex axis, num_output, TransW, M, K;
};
template <class Context>
......@@ -43,16 +41,14 @@ class InnerProductGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS(Context);
TransW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, num_output, M, K;
bool transW;
Tensor* bias_multiplier;
TIndex axis, num_output, TransW, M, K;
};
} // namespace dragon
......
......@@ -20,7 +20,7 @@ template <class Context>
class LogOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(LogOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -30,7 +30,7 @@ template <class Context>
class LogGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(LogGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -19,37 +19,35 @@ namespace dragon {
template <class Context>
class MatmulOp final : public Operator<Context> {
public:
MatmulOp(const OperatorDef& op_def, Workspace* ws)
MatmulOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
bool transA, transB;
TIndex TransA, TransB, M, K1, K2, N;
TIndex n, x1_offset, x2_offset, y_offset;
TIndex M, K1, K2, N;
};
template <class Context>
class MatmulGradientOp final : public Operator<Context> {
public:
MatmulGradientOp(const OperatorDef& op_def, Workspace* ws)
MatmulGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
TransA(OperatorBase::GetSingleArg<bool>("TransA", false)),
TransB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
bool transA, transB;
TIndex TransA, TransB, M, K1, K2, N;
TIndex n, x1_offset, x2_offset, y_offset;
TIndex M, K1, K2, N;
};
} // namespace dragon
......
......@@ -20,56 +20,44 @@ template <class Context>
class MulOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(MulOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class MulGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(MulGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class RMulOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RMulOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class RMulGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RMulGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
} // namespace dragon
......
......@@ -26,7 +26,7 @@ class PowOp: public Operator<Context> {
power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
power_scale = power * scale;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -45,7 +45,7 @@ class PowGradientOp final : public Operator<Context> {
power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
power_scale = power * scale;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -20,7 +20,7 @@ template <class Context>
class SquareOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SquareOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -30,7 +30,7 @@ template <class Context>
class SquareGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SquareGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -20,56 +20,44 @@ template <class Context>
class SubOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SubOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class SubGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SubGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class RSubOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RSubOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class RSubGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(RSubGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
} // namespace dragon
......
......@@ -22,7 +22,7 @@ class CompareOp final : public Operator<Context> {
CompareOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void EqualRunWithType();
......
......@@ -20,7 +20,7 @@ template <class Context>
class CopyOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(CopyOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -31,7 +31,7 @@ class ScanOp final: public Operator<Context> {
debug_mode(OperatorBase::GetSingleArg<bool>("debug_mode", false)) {
InitTemplate();
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void InitTemplate();
......@@ -68,7 +68,7 @@ class ScanGradientOp final: public Operator<Context> {
for (int i = 0; i < forward_inputs.size(); i++)
terms[forward_inputs[i] + "_grad"] = Output(i)->name();
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void MakeOps(const GraphDef& forward_def, GraphDef& new_def);
......
......@@ -22,7 +22,7 @@ class L1LossOp : public Operator<Context> {
L1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -38,7 +38,7 @@ class L1LossGradientOp final : public Operator<Context> {
L1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -22,7 +22,7 @@ class L2LossOp : public Operator<Context> {
L2LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -38,7 +38,7 @@ class L2LossGradientOp final : public Operator<Context> {
L2LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -22,7 +22,7 @@ class SigmoidCrossEntropyOp final : public Operator<Context> {
SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -38,7 +38,7 @@ class SigmoidCrossEntropyGradientOp final : public Operator<Context> {
SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -23,7 +23,7 @@ class SmoothL1LossOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
beta(OperatorBase::GetSingleArg<float>("beta", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -41,7 +41,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
beta(OperatorBase::GetSingleArg<float>("beta", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -24,7 +24,7 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void SoftmaxRun();
void RunOnDevice() override;
......@@ -45,7 +45,7 @@ class SoftmaxCrossEntropyGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -30,7 +30,7 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
}
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void SoftmaxRun();
void SoftmaxRunFP16();
......@@ -60,7 +60,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
}
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType();
......
......@@ -29,13 +29,13 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
pos_alpha = alpha * 2.0;
neg_alpha = (1 - alpha) * 2.0;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float alpha, gamma;
float alpha, gamma;
int neg_id;
float pos_alpha, neg_alpha;
TIndex axis, outer_dim, inner_dim;
......@@ -53,7 +53,7 @@ class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyG
gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -30,7 +30,7 @@ class AccuracyOp final: public Operator<Context> {
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
}
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType();
......
......@@ -23,13 +23,13 @@ class AsTypeOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
dtype(OperatorBase::GetSingleArg<string>("dtype", "float32")),
inplace(OperatorBase::GetSingleArg<bool>("inplace", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
protected:
string dtype;
bool inplace;
string dtype;
bool inplace;
};
} // namespace dragon
......
......@@ -25,7 +25,7 @@ class GradientGenerateOp final: public Operator<Context> {
CHECK_EQ(InputSize(), OutputSize());
CHECK_EQ(defaults.size(), OutputSize());
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -42,7 +42,7 @@ class GradientGatherOp final : public Operator<Context> {
for (int i = 0; i < InputSize(); i++)
if (Input(i).name() != "ignore") indices.push_back(i);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -55,7 +55,7 @@ template <class Context>
class StopGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(StopGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
};
......
......@@ -40,7 +40,7 @@ class ImageDataOp final : public Operator<Context> {
std.mutable_data<float, CPUContext>()[i] = std_values[i];
}
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType();
......
......@@ -25,7 +25,7 @@ class InitializeOp: public Operator<Context> {
shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) {
GET_ARGUMENTS_WITH_DESC(int, dims);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -44,7 +44,7 @@ public:
this->filler.set_type("constant");
this->filler.set_value(OperatorBase::GetSingleArg<float>("value", 0.0));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
};
template <class Context>
......@@ -56,7 +56,7 @@ public:
this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -1.0));
this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 1.0));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
};
template <class Context>
......@@ -68,7 +68,7 @@ public:
this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0));
this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
};
template <class Context>
......@@ -84,7 +84,7 @@ public:
this->filler.set_low(mu - 2 * sigma);
this->filler.set_high(mu + 2 * sigma);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
};
template <class Context>
......@@ -105,7 +105,7 @@ public:
}
this->filler.set_scale(scale);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
};
template <class Context>
......@@ -126,7 +126,7 @@ public:
}
this->filler.set_scale(scale);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
};
DEFINE_ARGUMENTS_WITH_DESC(int, InitializeOp, dims);
......
......@@ -24,7 +24,7 @@ template <class Context>
class RunOp : public Operator<Context> {
public:
RunOp(const OperatorDef& op_def, Workspace* ws);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -38,7 +38,7 @@ class TemplateOp : public RunOp<Context> {
public:
TemplateOp(const OperatorDef& op_def, Workspace* ws)
: RunOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
};
template <class Context>
......@@ -46,7 +46,7 @@ class TemplateGradientOp : public TemplateOp<Context> {
public:
TemplateGradientOp(const OperatorDef& op_def, Workspace* ws)
: TemplateOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
};
......
......@@ -48,7 +48,7 @@ class ModelMPIBase : public Operator<Context> {
return MPI_DATATYPE_NULL;
}
protected:
public:
MPI_Comm comm;
MPI_Group group;
int comm_size, comm_rank, comm_root;
......@@ -57,7 +57,12 @@ class ModelMPIBase : public Operator<Context> {
};
#define USE_MPIMODEL_FUNCTIONS(context) \
using ModelMPIBase<context>::mpi_dtype
using ModelMPIBase<context>::comm; \
using ModelMPIBase<context>::mpi_dtype; \
using ModelMPIBase<context>::comm_size; \
using ModelMPIBase<context>::comm_rank; \
using ModelMPIBase<context>::comm_root; \
using ModelMPIBase<context>::dtype
} // namespace dragon
......
......@@ -23,7 +23,7 @@ class MPIBroadcastOp final : public ModelMPIBase<Context> {
public:
MPIBroadcastOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override;
......@@ -35,7 +35,7 @@ class MPIBroadcastGradientOp final : public ModelMPIBase<Context> {
public:
MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override;
......@@ -46,7 +46,4 @@ public:
#endif // WITH_MPI
#endif //DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
#endif //DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
\ No newline at end of file
......@@ -23,7 +23,7 @@ class MPIGatherOp final : public ModelMPIBase<Context> {
public:
MPIGatherOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override;
......@@ -35,7 +35,7 @@ class MPIGatherGradientOp final : public ModelMPIBase<Context> {
public:
MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_MPIMODEL_FUNCTIONS(Context);
void RunOnDevice() override;
......
......@@ -26,7 +26,7 @@ class ArangeOp final : public Operator<Context> {
GET_ARGUMENT_WITH_DESC(int, stop, 0);
GET_ARGUMENT_WITH_DESC(int, step, 1);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -25,7 +25,7 @@ class ArgReduceOp final : public Operator<Context> {
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -22,7 +22,7 @@ class ConcatOp : public Operator<Context> {
ConcatOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -39,7 +39,7 @@ class ConcatGradientOp : public Operator<Context> {
ConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -28,7 +28,7 @@ class CropOp: public Operator<Context> {
GET_ARGUMENTS_WITH_DESC(int, starts);
GET_ARGUMENTS_WITH_DESC(int, ends);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
void RunOnDevice() override;
......@@ -37,32 +37,33 @@ class CropOp: public Operator<Context> {
protected:
TIndex start_axis;
string shape_like;
vector<int> st, ed, offsets, shape, keep_dims;
vector<int> offsets, shape;
vector<TIndex> st, ed, keep_dims;
DECLARE_ARGUMENTS_WITH_DESC(int, starts);
DECLARE_ARGUMENTS_WITH_DESC(int, ends);
vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim;
Tensor* dest, *source;
Tensor* dest, *source, navigator;
};
DEFINE_ARGUMENTS_WITH_DESC(int, CropOp, starts);
DEFINE_ARGUMENTS_WITH_DESC(int, CropOp, ends);
template <class Context>
class CropGradientOp final : public Operator<Context > {
class CropGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(CropGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
vector<int> st, ed, offsets, keep_dims;
vector<TIndex> st, ed, offsets, keep_dims;
vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim;
Tensor* dest, *source;
Tensor* dest, *source, navigator;
};
} // namespace dragon
......
......@@ -22,7 +22,7 @@ class ExpandDimsOp final : public Operator<Context> {
ExpandDimsOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -34,7 +34,7 @@ template <class Context>
class ExpandDimsGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ExpandDimsGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
};
......
......@@ -24,7 +24,7 @@ class FlattenOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
keep_axes(OperatorBase::GetSingleArg<int>("keep_axes", INT_MAX)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
void SqueezeRun();
......@@ -38,7 +38,7 @@ template <class Context>
class FlattenGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(FlattenGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
};
......
......@@ -22,7 +22,7 @@ class GatherOp final : public Operator<Context> {
GatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -39,7 +39,7 @@ class GatherGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -24,7 +24,7 @@ class OneHotOp final : public Operator < Context > {
depth(OperatorBase::GetSingleArg<int>("depth", -1)),
on_value(OperatorBase::GetSingleArg<int>("on_value", 1)),
off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -35,7 +35,7 @@ class PadOp final : public Operator<Context> {
}
std::sort(process_axes.begin(), process_axes.end());
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void ConstRunWithType();
......@@ -48,7 +48,7 @@ class PadOp final : public Operator<Context> {
float value;
vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim;
Tensor* dest, *source;
Tensor* dest, *source, navigator;
};
template <class Context>
......@@ -70,7 +70,7 @@ class PadGradientOp final : public Operator<Context> {
std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end());
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void ConstRunWithType();
......@@ -82,7 +82,7 @@ class PadGradientOp final : public Operator<Context> {
string mode;
vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim;
Tensor* dest, *source;
Tensor* dest, *source, navigator;
};
} // namespace dragon
......
......@@ -23,7 +23,7 @@ class RandomPickOp : public Operator<Context> {
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
max_samples(OperatorBase::GetSingleArg<int>("max_samples", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -41,7 +41,7 @@ public:
RandomPickGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -24,35 +24,33 @@ class ReduceOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void SumRunWithType();
template <typename T> void MeanRunWithType();
protected:
bool keep_dims;
TIndex axis, keep_dims, axis_dim, count, inner_dim;
string operation;
TIndex axis, axis_dim, count, inner_dim;
Tensor* multiplier;
};
template <class Context>
class ReduceGradientOp final : public Operator<Context> {
public:
ReduceGradientOp(const OperatorDef& op_def, Workspace* ws)
ReduceGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void SumRunWithType();
template <typename T> void MeanRunWithType();
protected:
string operation;
TIndex axis, axis_dim, count, inner_dim;
string operation;
};
} // namespace dragon
......
......@@ -24,7 +24,7 @@ class RepeatOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template<typename T> void RunWithType();
......@@ -42,7 +42,7 @@ class RepeatGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
GET_ARGUMENT_WITH_DESC(int, repeats, 1);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template<typename T> void RunWithType();
......
......@@ -24,7 +24,7 @@ class ReshapeOp final : public Operator<Context> {
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, shape);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -38,7 +38,7 @@ template <class Context>
class ReshapeGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ReshapeGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
};
......
......@@ -20,7 +20,7 @@ template <class Context>
class ShapeOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ShapeOp);
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
};
......
......@@ -23,7 +23,7 @@ class SliceOp : public Operator<Context> {
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -42,7 +42,7 @@ class SliceGradientOp final : public Operator<Context> {
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -23,7 +23,7 @@ class StackOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -41,7 +41,7 @@ class StackGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -23,7 +23,7 @@ class TileOp : public Operator<Context> {
: Operator<Context>(op_def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template<typename T> void TileRunWithType();
......@@ -31,7 +31,7 @@ class TileOp : public Operator<Context> {
protected:
DECLARE_ARGUMENTS_WITH_DESC(int, multiples);
TIndex axis, multiple, outer_dim, ex_inner_dim;
Tensor* dest, *source;
Tensor* dest, *source, navigator;
};
template <class Context>
......@@ -41,7 +41,7 @@ class TileGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template<typename T> void TileRunWithType();
......@@ -49,7 +49,7 @@ class TileGradientOp : public Operator<Context> {
protected:
DECLARE_ARGUMENTS_WITH_DESC(int, multiples);
TIndex axis, multiple, outer_dim, ex_inner_dim;
Tensor* dest, *source;
Tensor* dest, *source, navigator;
};
DEFINE_ARGUMENTS_WITH_DESC(int, TileOp, multiples);
......
......@@ -25,7 +25,7 @@ class TransposeOp final: public Operator<Context> {
if (perms.size() > 0) reverse_dims = false;
else reverse_dims = true;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -42,7 +42,7 @@ class TransposeGradientOp final : public Operator<Context> {
public:
TransposeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -32,7 +32,7 @@ class BatchNormOp : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -42,12 +42,9 @@ class BatchNormOp : public Operator<Context> {
protected:
float momentum, eps;
Tensor mean, num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* stddev, *var;
TIndex axis, N, C, S, NC, NS;
Tensor nc, mean, *var;
TIndex axis, use_stats, N, C, S, NC, NS;
string data_format, mode;
int use_stats;
bool use_global_stats, is_recomputing;
};
......@@ -62,7 +59,7 @@ class BatchNormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -71,12 +68,9 @@ class BatchNormGradientOp final : public Operator<Context> {
template <typename T> void InferenceRunWithType();
protected:
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* stddev, *var;
TIndex axis, N, C, S, NC, NS;
TIndex axis, use_stats, N, C, S, NC, NS;
Tensor nc, *var;
string data_format;
int use_stats;
bool use_global_stats;
};
......@@ -89,7 +83,7 @@ class FusedBatchNormOp : public Operator<Context> {
momentum(OperatorBase::GetSingleArg<float>("momentum", 0.9f)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -98,13 +92,10 @@ class FusedBatchNormOp : public Operator<Context> {
template <typename T> void InferenceRunWithType();
protected:
TIndex axis, use_stats, N, C, S, NC, NS;
float momentum, eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex axis, N, C, S, NC, NS;
Tensor nc, *mean, *var, *x_norm;
string data_format;
int use_stats;
bool use_global_stats, is_recomputing;
};
......@@ -116,7 +107,7 @@ class FusedBatchNormGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -125,13 +116,10 @@ class FusedBatchNormGradientOp : public Operator<Context> {
template <typename T> void InferenceRunWithType();
protected:
TIndex axis, use_stats, N, C, S, NC, NS;
float eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex axis, N, C, S, NC, NS;
Tensor nc, *mean, *var, *x_norm;
string data_format;
int use_stats;
bool use_global_stats;
};
......@@ -156,7 +144,7 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
<< "CUDNN_BN_MIN_EPSILON instead.";
eps64 = std::max(eps64, CUDNN_BN_MIN_EPSILON);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNBatchNormOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -170,12 +158,12 @@ class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
template <typename T> void RunWithType();
protected:
TIndex N, C;
double eps64;
Tensor* mean, *var;
cudnnTensorDescriptor_t input_desc, output_desc, bn_desc;
cudnnBatchNormMode_t bn_mode;
TIndex N, C;
string data_format;
Tensor* mean, *var;
};
template <class Context>
......@@ -193,7 +181,7 @@ class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context>
<< "CUDNN_BN_MIN_EPSILON instead.";
eps64 = std::max(eps64, CUDNN_BN_MIN_EPSILON);
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNBatchNormGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -208,20 +196,18 @@ class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context>
template <typename T> void InferenceRunWithType();
protected:
TIndex N, C, S, NC, NS;
double eps64;
Tensor nc, *mean, *var;
cudnnTensorDescriptor_t input_desc, output_desc, bn_desc;
cudnnBatchNormMode_t bn_mode;
TIndex N, C, S, NC, NS;
string data_format;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* mean, *var, *stddev;
};
#endif
#endif // WITH_CUDNN
} // namespace dragon
} // namespace dragon
#endif // DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
\ No newline at end of file
......@@ -34,7 +34,7 @@ class BatchRenormOp : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -43,14 +43,12 @@ class BatchRenormOp : public Operator<Context> {
template <typename T> void InferenceRunWithType();
protected:
TIndex axis, use_stats, N, C, S, NC, NS;
float momentum, eps, r_max, d_max, t_delta;
float t_r_max, t_d_max, t_val;
Tensor mean, d, t_h_mean, t_h_var, num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* stddev, *r, *var, *x_norm;
TIndex axis, N, C, S, NC, NS;
Tensor nc, mean, d, t_h_mean, t_h_var;
Tensor* r, *var, *x_norm;
string data_format, mode;
int use_stats;
bool use_global_stats, is_recomputing;
};
......@@ -65,7 +63,7 @@ class BatchRenormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -76,12 +74,9 @@ class BatchRenormGradientOp final : public Operator<Context> {
template <typename T> void RunWithType();
protected:
Tensor mean, num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* stddev, *r, *var, *x_norm;
TIndex axis, N, C, S, NC, NS;
TIndex axis, use_stats, N, C, S, NC, NS;
Tensor nc, mean, *r, *var, *x_norm;
string data_format;
int use_stats;
bool use_global_stats;
};
......
......@@ -28,7 +28,7 @@ class GroupNormOp : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -37,10 +37,8 @@ class GroupNormOp : public Operator<Context> {
protected:
float eps;
Tensor mean, num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
Tensor nc, mean, *var;
string data_format;
};
......@@ -55,7 +53,7 @@ class GroupNormGradientOp final : public Operator<Context> {
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -63,10 +61,8 @@ class GroupNormGradientOp final : public Operator<Context> {
template <typename T> void RunWithType();
protected:
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
Tensor nc, *var;
string data_format;
};
......@@ -78,7 +74,7 @@ class FusedGroupNormOp : public Operator<Context> {
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -86,11 +82,9 @@ class FusedGroupNormOp : public Operator<Context> {
template <typename T> void RunWithType();
protected:
float eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
float eps;
Tensor nc, *mean, *var, *x_norm;
string data_format;
};
......@@ -101,7 +95,7 @@ class FusedGroupNormGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -109,10 +103,8 @@ class FusedGroupNormGradientOp : public Operator<Context> {
template <typename T> void RunWithType();
protected:
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
Tensor nc, *mean, *var, *x_norm;
string data_format;
};
......
......@@ -23,11 +23,10 @@ class InstanceNormOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-3f)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
if (axis != -1)
CHECK_EQ(axis, 1) << "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -35,10 +34,9 @@ class InstanceNormOp : public Operator<Context> {
template <typename T> void RunWithType();
protected:
float eps;
Tensor mean;
Tensor* spatial_multiplier, *stddev, *var;
TIndex axis, N, C, S, NC, CS;
float eps;
Tensor mean, *var;
string data_format;
};
......@@ -49,10 +47,9 @@ class InstanceNormGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
CHECK_EQ(axis, 1) << "\nThe axis can only be set to 1.";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Setup();
......@@ -60,9 +57,9 @@ class InstanceNormGradientOp final : public Operator<Context> {
template <typename T> void RunWithType();
protected:
Tensor* spatial_multiplier, *stddev, *var;
TIndex axis, N, C, S, NC, CS;
string data_format;
TIndex axis, N, C, S, NC, CS;
Tensor *var;
string data_format;
};
} // namespace dragon
......
......@@ -25,17 +25,17 @@ class L2NormOp final : public Operator<Context> {
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", 1e-5f)),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float eps;
TIndex axis, num_axes, end_axis;
float eps;
string mode;
bool across_inner;
Tensor* norm, *buffer, *multiplier;
Tensor* norm, buffer;
TIndex outer_dim, dim, inner_dim, spatial_dim;
};
......@@ -47,7 +47,7 @@ class L2NormGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -56,7 +56,7 @@ class L2NormGradientOp final : public Operator<Context> {
TIndex axis, num_axes, end_axis;
string mode;
bool across_inner;
Tensor* norm, *multiplier, *buffer, *buffer_inner;
Tensor* norm, buffer, buffer_inner;
TIndex outer_dim, dim, inner_dim;
};
......
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
#include "core/operator.h"
namespace dragon {
#ifdef WITH_CUDNN
#if CUDNN_VERSION_MIN(5, 0, 0)
#include "utils/cudnn_device.h"
// Owns a fixed-size array of cuDNN tensor descriptors.
//
// All descriptors are created in the constructor and destroyed in the
// destructor. The object is non-copyable: a copy would hold the very same
// raw handles and destroy them a second time.
class cudnnTensorDescriptors {
 public:
    // Create ``num_descs`` tensor descriptors; a creation failure is
    // reported through CUDNN_CHECK.
    explicit cudnnTensorDescriptors(const int num_descs) {
        descs_.resize(num_descs);
        for (auto& desc : descs_)
            CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc));
    }

    // Destroy every descriptor. The return status is deliberately not
    // checked so that the destructor never aborts.
    ~cudnnTensorDescriptors() {
        for (auto& desc : descs_)
            cudnnDestroyTensorDescriptor(desc);
    }

    // The handles are uniquely owned by this object.
    cudnnTensorDescriptors(const cudnnTensorDescriptors&) = delete;
    cudnnTensorDescriptors& operator=(const cudnnTensorDescriptors&) = delete;

    // Apply the same ``dims`` / ``strides`` to every owned descriptor.
    // Both vectors must have the same number of entries.
    template <typename T>
    void Set(const vector<TIndex>& dims, const vector<TIndex>& strides) {
        CHECK_EQ(dims.size(), strides.size());
        // Iterate by reference so the address passed on is that of the
        // stored handle rather than of a loop-local copy.
        for (auto& desc : descs_)
            cudnnSetTensorDesc<T>(&desc, dims, strides);
    }

    // Pointer to the first descriptor of the contiguous array.
    const cudnnTensorDescriptor_t* descs() const { return descs_.data(); }

 protected:
    vector<cudnnTensorDescriptor_t> descs_;
};
// Common base of the cuDNN recurrent forward and gradient operators.
//
// The constructor reads the operator arguments ("hidden_size",
// "num_layers", "bidirectional", "dropout_ratio", "rnn_mode",
// "rnn_input_mode", "phase"), translates the string modes into the
// matching cuDNN enums and creates the cuDNN descriptors, which are
// destroyed again in the destructor.
template <class Context>
class CuDNNRecurrentOpBase : public Operator<Context> {
 public:
    CuDNNRecurrentOpBase(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          // Listed in member declaration order, which is the order in
          // which the members are actually initialized.
          hidden_size(OperatorBase::GetSingleArg<int>("hidden_size", 0)),
          num_layers(OperatorBase::GetSingleArg<int>("num_layers", 1)),
          bidirectional(OperatorBase::GetSingleArg<bool>("bidirectional", false)),
          states_initialized(false),
          dropout_ratio(OperatorBase::GetSingleArg<float>("dropout_ratio", 1.0)),
          random_seed(op_def.device_option().random_seed()),
          // Start from a defined value instead of leaving the sizes
          // indeterminate until they are first computed.
          workspace_size(0), reserve_size(0), states_size(0) {
        // determine the rnn direction
        rnn_direction = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
        // determine the rnn mode
        const string mode = OperatorBase::GetSingleArg<string>("rnn_mode", "");
        if (mode == "rnn_tanh") rnn_mode = CUDNN_RNN_TANH;
        else if (mode == "rnn_relu") rnn_mode = CUDNN_RNN_RELU;
        else if (mode == "lstm") rnn_mode = CUDNN_LSTM;
        else if (mode == "gru") rnn_mode = CUDNN_GRU;
        else LOG(FATAL) << "Unsupported rnn mode: " << mode;
        // determine the rnn input mode
        const string input_mode = OperatorBase::GetSingleArg<string>("rnn_input_mode", "linear");
        if (input_mode == "skip") rnn_input_mode = CUDNN_SKIP_INPUT;
        else if (input_mode == "linear") rnn_input_mode = CUDNN_LINEAR_INPUT;
        else LOG(FATAL) << "Unsupported rnn input mode: " << input_mode;
        // override the running phase
        SwitchToPhase(OperatorBase::GetSingleArg<string>("phase", ""));
        CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc));
        CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc));
        CUDNN_CHECK(cudnnCreateFilterDescriptor(&w_desc));
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc));
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc));
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc));
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc));
    }
    USE_OPERATOR_FUNCTIONS;

    // Release every descriptor created by the constructor.
    virtual ~CuDNNRecurrentOpBase() {
        CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc));
        CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc));
        CUDNN_CHECK(cudnnDestroyFilterDescriptor(w_desc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc));
    }

    // Defined out of line; presumably (re)configures the descriptors for
    // the given input/state/output tensors -- confirm in the source file.
    template <typename T> void ResetDesc(Tensor* X, Tensor* Hx, Tensor* Cx,
                                         Tensor* Y, Tensor* Hy, Tensor* Cy);

 public:
    // Values read from the operator arguments.
    TIndex hidden_size, num_layers;
    bool bidirectional, states_initialized;
    float dropout_ratio;
    unsigned long long random_seed;
    // cuDNN handles owned by this object.
    cudnnRNNDescriptor_t rnn_desc;
    cudnnDropoutDescriptor_t dropout_desc;
    cudnnDirectionMode_t rnn_direction;
    cudnnRNNMode_t rnn_mode;
    cudnnRNNInputMode_t rnn_input_mode;
    cudnnFilterDescriptor_t w_desc;
    cudnnTensorDescriptor_t hx_desc, cx_desc;
    cudnnTensorDescriptor_t hy_desc, cy_desc;
    vector<TIndex> input_dims;
    // Byte sizes, zero until computed.
    size_t workspace_size, reserve_size, states_size;
    // Descriptor arrays, empty until allocated.
    std::unique_ptr<cudnnTensorDescriptors> xs_desc;
    std::unique_ptr<cudnnTensorDescriptors> ys_desc;
};
// Pulls the listed CuDNNRecurrentOpBase<Context> members into the scope
// of a class template deriving from it, so they can be named without
// qualification. Expands USE_OPERATOR_FUNCTIONS first and requires a
// template parameter named ``Context`` at the point of use.
// The last using-declaration carries no semicolon; the use site adds it.
#define USE_CUDNN_RECURRENT_FUNCTIONS \
USE_OPERATOR_FUNCTIONS; \
using CuDNNRecurrentOpBase<Context>::dropout_desc; \
using CuDNNRecurrentOpBase<Context>::rnn_desc; \
using CuDNNRecurrentOpBase<Context>::w_desc; \
using CuDNNRecurrentOpBase<Context>::hx_desc; \
using CuDNNRecurrentOpBase<Context>::cx_desc; \
using CuDNNRecurrentOpBase<Context>::hy_desc; \
using CuDNNRecurrentOpBase<Context>::cy_desc; \
using CuDNNRecurrentOpBase<Context>::xs_desc; \
using CuDNNRecurrentOpBase<Context>::ys_desc; \
using CuDNNRecurrentOpBase<Context>::input_dims; \
using CuDNNRecurrentOpBase<Context>::workspace_size; \
using CuDNNRecurrentOpBase<Context>::reserve_size
// Forward recurrent operator backed by cuDNN. All argument parsing and
// descriptor handling lives in CuDNNRecurrentOpBase; this class only
// declares the run entry points, which are defined out of line.
template <class Context>
class CuDNNRecurrentOp : public CuDNNRecurrentOpBase<Context> {
 public:
    USE_CUDNN_RECURRENT_FUNCTIONS;

    CuDNNRecurrentOp(const OperatorDef& op_def, Workspace* ws)
        : CuDNNRecurrentOpBase<Context>(op_def, ws) {}

    // Typed implementation.
    template <typename T> void RunWithType();

    // Operator entry point.
    void RunOnDevice() override;
};
// Gradient counterpart of the cuDNN recurrent operator. It reuses the
// argument parsing and descriptors of CuDNNRecurrentOpBase; the run entry
// points declared here are defined out of line.
template <class Context>
class CuDNNRecurrentGradientOp : public CuDNNRecurrentOpBase<Context> {
 public:
    USE_CUDNN_RECURRENT_FUNCTIONS;

    CuDNNRecurrentGradientOp(const OperatorDef& op_def, Workspace* ws)
        : CuDNNRecurrentOpBase<Context>(op_def, ws) {}

    // Typed implementation.
    template <typename T> void RunWithType();

    // Operator entry point.
    void RunOnDevice() override;
};
#endif
#endif // WITH_CUDNN
} // namespace dragon
#endif // DRAGON_OPERATORS_RECURRENT_CUDNN_RECURRENT_OP_H_
\ No newline at end of file
......@@ -9,45 +9,36 @@
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_
#ifndef DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
#define DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class LSTMUnitOp : public Operator<Context> {
class LSTMCellOp : public Operator<Context> {
public:
LSTMUnitOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
has_cont(OperatorBase::GetSingleArg<string>("cont_t", "")) {}
USE_OPERATOR_FUNCTIONS(Context);
LSTMCellOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void ResetDesc();
template <typename T> void RunWithType();
protected:
TIndex num, channels;
string has_cont;
Tensor* cont_t;
};
template <class Context>
class LSTMUnitGradientOp : public Operator<Context> {
class LSTMCellGradientOp : public Operator<Context> {
public:
LSTMUnitGradientOp(const OperatorDef& op_def, Workspace* ws)
LSTMCellGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex num, channels;
Tensor* zeros;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_
\ No newline at end of file
#endif // DRAGON_OPERATORS_RECURRENT_LSTM_CELL_OP_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
// NOTE: this guard must be unique to this header. It previously reused
// the guard of cudnn_recurrent_op.h, so whichever of the two headers was
// included second in a translation unit was silently skipped.
#ifndef DRAGON_OPERATORS_RECURRENT_RECURRENT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_RECURRENT_OP_H_

#include "core/operator.h"

namespace dragon {

// Placeholder recurrent operator: construction always fails with a fatal
// log message, and RunOnDevice() does nothing.
template <class Context>
class RecurrentOp : public Operator<Context> {
 public:
    RecurrentOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws) {
        LOG(FATAL) << "RNN Operators require CuDNN support.";
    }
    USE_OPERATOR_FUNCTIONS;

    void RunOnDevice() override {}
};

// Placeholder gradient operator: construction always fails with a fatal
// log message, and RunOnDevice() does nothing.
template <class Context>
class RecurrentGradientOp : public Operator<Context> {
 public:
    RecurrentGradientOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws) {
        LOG(FATAL) << "RNN Operators require CuDNN support.";
    }
    USE_OPERATOR_FUNCTIONS;

    void RunOnDevice() override {}
};

}    // namespace dragon

#endif    // DRAGON_OPERATORS_RECURRENT_RECURRENT_OP_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
// along with the software. If not, See,
//
// <https://opensource.org/licenses/BSD-2-Clause>
//
// -------------------------------------------------------------
#ifndef DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
#define DRAGON_OPERATORS_RECURRENT_RNN_PARAM_OP_H_
#include "core/operator.h"
namespace dragon {
// Operator configured by the arguments "param_type", "rnn_mode",
// "num_layers", "num_directions", "input_size", "hidden_size",
// "layer_id" and "param_id". The run methods are defined out of line.
template <class Context>
class RNNParamSetOp : public Operator<Context> {
 public:
    RNNParamSetOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          param_type(OperatorBase::GetSingleArg<string>("param_type", "matrix")),
          rnn_mode(OperatorBase::GetSingleArg<string>("rnn_mode", "rnn_tanh")),
          num_layers(OperatorBase::GetSingleArg<int>("num_layers", 1)),
          num_directions(OperatorBase::GetSingleArg<int>("num_directions", 1)),
          input_size(OperatorBase::GetSingleArg<int>("input_size", 0)),
          hidden_size(OperatorBase::GetSingleArg<int>("hidden_size", 0)),
          layer_id(OperatorBase::GetSingleArg<int>("layer_id", 0)),
          param_id(OperatorBase::GetSingleArg<int>("param_id", 0)) {
        // Per-mode (num_params, spliter) pairs: 2/1 for both plain RNN
        // modes, 8/4 for LSTM, 6/3 for GRU. Presumably the parameter
        // count per layer and the split point between the input and
        // recurrent halves -- confirm against the implementation.
        const bool plain_rnn =
            (rnn_mode == "rnn_tanh") || (rnn_mode == "rnn_relu");
        if (plain_rnn) {
            num_params = 2;
            spliter = 1;
        } else if (rnn_mode == "lstm") {
            num_params = 8;
            spliter = 4;
        } else if (rnn_mode == "gru") {
            num_params = 6;
            spliter = 3;
        } else {
            LOG(FATAL) << "Unsupported rnn mode: " << rnn_mode;
        }
        input_ex_size = hidden_size * num_directions;
    }
    USE_OPERATOR_FUNCTIONS;

    void RunOnDevice() override;
    template <typename T> void RunWithType();

 protected:
    string param_type, rnn_mode;
    TIndex num_layers, num_directions, num_params, spliter;
    TIndex input_size, input_ex_size, hidden_size;
    TIndex layer_id, param_id;
};
} // namespace dragon
#endif
\ No newline at end of file
......@@ -21,7 +21,7 @@ class AdamUpdateOp final : public UpdateOpBase<Context> {
public:
AdamUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws), t(0) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override;
......
......@@ -27,7 +27,7 @@ class CollectiveUpdateOp : public Operator<Context> {
InitMPI();
if (mode.find("NCCL") != string::npos) InitNCCL();
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void InitMPI();
void InitNCCL();
......@@ -41,7 +41,6 @@ class CollectiveUpdateOp : public Operator<Context> {
protected:
int comm_size, comm_rank, comm_root;
int world_size, world_rank;
Tensor* buffer;
string mode;
MPI_Comm comm;
......
......@@ -22,7 +22,7 @@ class MovingAverageOp final : public Operator<Context> {
MovingAverageOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
decay(OperatorBase::GetSingleArg<float>("decay", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -21,7 +21,7 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> {
public:
NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override;
......
......@@ -21,7 +21,7 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
public:
RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override;
......
......@@ -22,7 +22,7 @@ class SGDUpdateOp final : public UpdateOpBase<Context> {
SGDUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
old_lr(-1.f), correction(1.f) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override;
......
......@@ -27,7 +27,7 @@ class UpdateOpBase : public Operator<Context> {
zero_grad(OperatorBase::GetSingleArg<bool>("zero_grad", true)) {
CHECK(!slot.empty()) << "\nRequired a non-empty slot";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
float Param(const string& name) const;
string Slot();
......
......@@ -9,8 +9,8 @@
//
// ------------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
#ifndef DRAGON_OPERATORS_VISION_BIAS_ADD_OP_H_
#define DRAGON_OPERATORS_VISION_BIAS_ADD_OP_H_
#include "core/operator.h"
......@@ -22,7 +22,7 @@ class BiasAddOp : public Operator<Context> {
BiasAddOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -30,7 +30,6 @@ class BiasAddOp : public Operator<Context> {
protected:
TIndex outer_dim, dim, inner_dim;
string data_format;
Tensor* bias_multiplier;
};
template <class Context>
......@@ -39,7 +38,7 @@ class BiasAddGradientOp final : public Operator<Context> {
BiasAddGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -47,9 +46,8 @@ class BiasAddGradientOp final : public Operator<Context> {
protected:
int outer_dim, dim, inner_dim;
string data_format;
Tensor* bias_multiplier;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
\ No newline at end of file
#endif // DRAGON_OPERATORS_VISION_BIAS_ADD_OP_H_
\ No newline at end of file
......@@ -30,7 +30,7 @@ class BilinearResizeOp : public Operator<Context> {
else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -48,7 +48,7 @@ class BilinearResizeGradientOp : public Operator<Context> {
BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -24,7 +24,7 @@ class Conv2dOp : public ConvOpBase<Context> {
this->num_spatial_axes = 2;
Setup();
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context);
bool ReverseDimensions() override { return false; }
......@@ -39,7 +39,7 @@ class Conv2dGradientOp : public Conv2dOp<Context> {
public:
Conv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context);
bool HasBias() override { return Output(2)->name() != "ignore"; }
......@@ -61,7 +61,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else
cudnn_group = this->group;
cudnn_group = group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group];
......@@ -77,11 +77,11 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
if (data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dOp() {
......@@ -109,7 +109,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc;
size_t workspace_fwd_data_size;
size_t fwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
bool enable_tensor_core;
......@@ -124,7 +124,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else
cudnn_group = this->group;
cudnn_group = group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group * 3];
......@@ -139,11 +139,11 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
if (data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dGradientOp() {
......@@ -172,7 +172,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc;
size_t workspace_bwd_filter_size, workspace_bwd_data_size;
size_t bwd_filter_size, bwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
bool enable_tensor_core;
......
......@@ -34,19 +34,18 @@ class ConvOpBase : public Operator<Context> {
else LOG(FATAL) << "Unknown data format: " << data_format;
num_spatial_axes = -1; // unknown
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
protected:
public:
vector<TIndex> kernel_size, stride, pad, dilation;
string data_format, padding;
vector<TIndex> input_shape, output_shape, bottom_shape, top_shape, col_shape;
vector<TIndex> input_shape, output_shape, bottom_shape, top_shape;
vector<TIndex> weight_shape, bias_shape;
Tensor* col_buffer, *bias_multiplier;
TIndex num_output, group;
TIndex spatial_axis, num_spatial_axes;
TIndex channels, out_spatial_dim;
TIndex conv_in_channels, conv_out_channels;
TIndex conv_out_spatial_dim, kernel_dim;
TIndex conv_out_spatial_dim, kernel_dim, col_dim;
TIndex col_offset, output_offset, weight_offset, x_offset, y_offset;
DECLARE_ARGUMENTS_WITH_DESC(int, output_dims);
bool is_1x1;
......@@ -58,10 +57,15 @@ class ConvOpBase : public Operator<Context> {
virtual bool ReverseDimensions() = 0;
virtual bool HasBias() = 0;
template <typename T> void Wx(const T* x, const T* weights, T* y, bool skip_im2col = false);
template <typename T> void Wx(const T* x,
const T* weights, T* y, bool skip_im2col = false);
template <typename T> void Pb(const T* bias, T* y);
template <typename T> void Dx(const T* dy, const T* weights, T* dx);
template <typename T> void Dw(const T* dy, const T* x, T *dw);
template <typename T> void Dw(const T* dy, const T* x, T* dw);
template <typename T> void Db(const T* dy, T* db);
private:
......@@ -108,7 +112,20 @@ DEFINE_ARGUMENTS_WITH_DESC(int, ConvOpBase, output_dims);
using ConvOpBase<context>::Pb; \
using ConvOpBase<context>::Dx; \
using ConvOpBase<context>::Dw; \
using ConvOpBase<context>::Db
using ConvOpBase<context>::Db; \
using ConvOpBase<context>::kernel_size; \
using ConvOpBase<context>::stride; \
using ConvOpBase<context>::pad; \
using ConvOpBase<context>::dilation; \
using ConvOpBase<context>::group; \
using ConvOpBase<context>::channels; \
using ConvOpBase<context>::num_output; \
using ConvOpBase<context>::data_format; \
using ConvOpBase<context>::x_offset; \
using ConvOpBase<context>::y_offset; \
using ConvOpBase<context>::weight_offset; \
using ConvOpBase<context>::weight_shape; \
using ConvOpBase<context>::bias_shape
} // namespace dragon
......
......@@ -24,7 +24,7 @@ class Conv2dTransposeOp: public ConvOpBase<Context> {
this->num_spatial_axes = 2;
Setup();
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context);
bool ReverseDimensions() override { return true; }
......@@ -43,7 +43,7 @@ class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> {
public:
Conv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeOp<Context>(def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context);
bool HasBias() override { return Output(2)->name() != "ignore"; }
......@@ -65,12 +65,12 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else
cudnn_group = this->group;
cudnn_group = group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group];
stream = new cudaStream_t[cudnn_group];
for (int g = 0; g < this->group; g++) {
for (int g = 0; g < cudnn_group; g++) {
CUDA_CHECK(cudaStreamCreate(&stream[g]));
CUDNN_CHECK(cudnnCreate(&handle[g]));
CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
......@@ -80,11 +80,11 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
if (data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dTransposeOp() {
......@@ -112,7 +112,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc;
size_t workspace_fwd_data_size;
size_t fwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
bool enable_tensor_core;
......@@ -127,7 +127,7 @@ public:
cudnn_group = 1;
enable_tensor_core &= TENSOR_CORE_AVAILABLE();
#else
cudnn_group = this->group;
cudnn_group = group;
enable_tensor_core = false;
#endif
handle = new cudnnHandle_t[cudnn_group * 3];
......@@ -142,11 +142,11 @@ public:
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
if (data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
USE_CONVOLUTION_FUNCTIONS(Context);
~CuDNNConv2dTransposeGradientOp() {
......@@ -175,7 +175,7 @@ public:
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc;
size_t workspace_bwd_filter_size, workspace_bwd_data_size;
size_t bwd_filter_size, bwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
bool enable_tensor_core;
......
......@@ -21,7 +21,7 @@ class DenseConcatOp final : public ConcatOp<Context> {
public:
DenseConcatOp(const OperatorDef& op_def, Workspace* ws)
: ConcatOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
};
template <class Context>
......@@ -30,7 +30,7 @@ class DenseConcatGradientOp : public ConcatGradientOp<Context> {
DenseConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: ConcatGradientOp<Context>(op_def, ws),
growth_rate(OperatorBase::GetSingleArg<int>("growth_rate", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void ElimateCorruption() override;
template <typename T> void RestoreX1();
......
......@@ -29,7 +29,7 @@ class LRNOp : public Operator<Context> {
k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -60,7 +60,7 @@ class LRNGradientOp : public Operator<Context> {
k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -97,7 +97,7 @@ class CuDNNLRNOp : public LRNOp<Context> {
this->beta,
this->k));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNLRNOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -126,7 +126,7 @@ class CuDNNLRNGradientOp : public LRNGradientOp<Context > {
this->beta,
this->k));
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNLRNGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
......@@ -30,7 +30,7 @@ class NNResizeOp : public Operator<Context> {
else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -48,7 +48,7 @@ class NNResizeGradientOp : public Operator<Context> {
NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -41,7 +41,7 @@ class Pooling2dOp: public Operator <Context> {
}
}
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Reshape();
void RunOnDevice() override;
......@@ -81,7 +81,7 @@ class Pooling2dGradientOp: public Operator<Context> {
}
}
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void Reshape();
void RunOnDevice() override;
......@@ -116,7 +116,7 @@ class CuDNNPooling2dOp final : public Pooling2dOp<Context> {
pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
} else LOG(FATAL) << "Unsupported pooling mode: " << this->mode;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNPooling2dOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......@@ -151,7 +151,7 @@ class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> {
pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
} else LOG(FATAL) << "Unsupported pooling mode: " << this->mode;
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
~CuDNNPooling2dGradientOp() {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
......
......@@ -28,7 +28,7 @@ class ROIAlignOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -50,7 +50,7 @@ class ROIAlignGradientOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
......@@ -27,7 +27,7 @@ class ROIPoolingOp : public Operator<Context> {
CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0";
}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -35,7 +35,6 @@ class ROIPoolingOp : public Operator<Context> {
protected:
int pool_h, pool_w;
float spatial_scale;
Tensor* mask;
};
template <class Context>
......@@ -46,7 +45,7 @@ class ROIPoolingGradientOp final : public Operator<Context> {
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -54,7 +53,6 @@ class ROIPoolingGradientOp final : public Operator<Context> {
protected:
int pool_h, pool_w;
float spatial_scale;
Tensor* mask;
};
} // namespace dragon
......
......@@ -33,118 +33,203 @@ namespace math {
/******************** Level-0 ********************/
template <typename T, class Context>
void Set(const int n, const T alpha, T* x);
void Set(
const int n,
const T alpha,
T* x);
template <typename T, class Context>
void RandomUniform(const int n, const float low, const float high, T *x);
void RandomUniform(
const int n,
const float low,
const float high,
T* x);
template <typename T, class Context>
void RandomNormal(const int n, const float mu, const float sigma, T* x);
void RandomNormal(
const int n,
const float mu,
const float sigma,
T* x);
template <typename T, class Context>
void RandomTruncatedNormal(const int n,
const float mu,
const float sigma,
const float low,
const float high,
T* x);
void RandomTruncatedNormal(
const int n,
const float mu,
const float sigma,
const float low,
const float high,
T* x);
template <typename T, class Context>
void RandomBernoulli(const int n, const float p, uint32_t* x);
void RandomBernoulli(
const int n,
const float p,
uint32_t* x);
/******************** Level-1 ********************/
template <typename T, class Context>
void Add(const int n, const T* a, const T* b, T* y);
void Add(
const int n,
const T* a,
const T* b,
T* y);
template <typename T, class Context>
void Sub(const int n, const T* a, const T* b, T* y);
void Sub(
const int n,
const T* a,
const T* b,
T* y);
template <typename T, class Context>
void Mul(const int n, const T* a, const T* b, T* y);
void Mul(
const int n,
const T* a,
const T* b,
T* y);
template <typename T, class Context>
void Div(const int n, const T* a, const T* b, T* y);
void Div(
const int n,
const T* a,
const T* b,
T* y);
template <typename T, class Context>
void Clip(const int n, const float low, const float high, T* x);
void Clip(
const int n,
const float low,
const float high,
T* x);
template <typename T, class Context>
void Exp(const int n, const T* x, T* y);
void Exp(
const int n,
const T* x,
T* y);
template <typename T, class Context>
void Log(const int n, const T* x, T* y);
void Log(
const int n,
const T* x,
T* y);
template <typename T, class Context>
void Square(const int n, const T* x, T* y);
void Square(
const int n,
const T* x,
T* y);
template <typename T, class Context>
void Sqrt(const int n, const T* x, T* y);
void Sqrt(
const int n,
const T* x,
T* y);
template <typename T, class Context>
void Pow(const int n, const float alpha, const T* x, T* y);
void Pow(
const int n,
const float alpha,
const T* x,
T* y);
template <typename T, class Context>
void Inv(const int n, const float numerator, const T* x, T* y);
void Inv(
const int n,
const float numerator,
const T* x,
T* y);
/******************** Level-2 ********************/
template <typename T, class Context>
void Scal(const int n, const float alpha, T* y);
void Scal(
const int n,
const float alpha,
T* y);
template <typename T, class Context>
void Scale(const int n, const float alpha, const T* x, T* y);
void Scale(
const int n,
const float alpha,
const T* x,
T* y);
template <typename T, class Context>
T StridedDot(const int n,
const T* a,
const int incx,
const T* b,
const int incy);
T StridedDot(
const int n,
const T* a,
const int incx,
const T* b,
const int incy);
template <typename T, class Context>
float Dot(const int n, const T* a, const T* b);
float Dot(
const int n,
const T* a,
const T* b);
template<typename T, class Context>
float ASum(const int n, const T *x);
float ASum(
const int n,
const T* x);
template<typename T, class Context>
void AddScalar(const int n, const float alpha, T* y);
void AddScalar(
const int n,
const float alpha,
T* y);
template<typename T, class Context>
void MulScalar(const int n, const float alpha, T* y);
void MulScalar(
const int n,
const float alpha,
T* y);
template<typename T, class Context>
void Axpy(const int n, float alpha, const T* x, T *y);
void Axpy(
const int n,
float alpha,
const T* x,
T* y);
template<typename T, class Context>
void Axpby(const int n, float alpha, const T* x, float beta, T *y);
void Axpby(
const int n,
float alpha,
const T* x,
float beta,
T* y);
/******************** Level-3 ********************/
template <typename T, class Context>
void Gemm(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const float alpha,
const T* A,
const T* B,
const float beta,
T* C,
TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
void Gemm(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const float alpha,
const T* A,
const T* B,
const float beta,
T* C,
TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
template<typename T, class Context>
void Gemv(const CBLAS_TRANSPOSE transA,
const int M,
const int N,
const float alpha,
const T* A,
const T* x,
const float beta,
T* y,
TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
void Gemv(
const CBLAS_TRANSPOSE TransA,
const int M,
const int N,
const float alpha,
const T* A,
const T* x,
const float beta,
T* y,
TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
} // namespace math
......
......@@ -35,11 +35,13 @@ inline std::vector<std::string> SplitString(const std::string& str,
return ret;
}
template<> inline std::string dragon_cast<std::string, int>(int val) {
std::stringstream ss;
ss << val;
return ss.str();
}
#define DEFINE_NUMBER2STRING(T) \
template<> inline std::string dragon_cast<std::string, T>(T val) { \
std::stringstream ss; ss << val; return ss.str(); \
}
DEFINE_NUMBER2STRING(int);
DEFINE_NUMBER2STRING(unsigned long long);
template<> inline int dragon_cast<int, std::string>(std::string val) {
return atoi(val.c_str());
......
......@@ -24,7 +24,7 @@ Workspace* CreateWorkspace(const std::string& name){
unique_ptr<Workspace> new_workspace(new Workspace(name));
g_workspaces[name] = std::move(new_workspace);
sub_workspaces[name] = vector<string>();
return new_workspace.get();
return g_workspaces[name].get();
}
Workspace* ResetWorkspace(const std::string& name) {
......
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!