Commit ff2ab80b by Ting PAN

try it

1 parent 33789951
# ---------------- Welcome To Use Dragon ----------------
PROJECT(dragon)
CMAKE_MINIMUM_REQUIRED(VERSION 2.8.0)
# ---------------- Welcome To Use Dragon ----------------
# ---------------- User Config ----------------
# set optional libraries
option(WITH_CUDA "Set to ON to use CUDA" ON)
option(WITH_CUDNN "Set to ON to use CUDNN" OFF)
option(WITH_BLAS "Set to ON to use BLAS" OFF)
option(WITH_SSE "Set to ON to use SSE 4.1" ON)
option(WITH_MPI "Set to ON to use MPI" OFF)
option(WITH_MPI_CUDA "Set to ON to use MPI_CUDA_AWARE" OFF)
option(WITH_CUDA_FP16 "Set to ON to use FP16" ON)
# set your 3rdparty
set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/3rdparty)
# set these paths if you want to build pydragon
#set(PYTHON_DIR /usr/include/python2.7)
set(ANACONDA_DIR /xxx/anaconda)
set(NUMPY_DIR /xxx/numpy)
# set CUDA compiling architecture
set(CUDA_ARCH -gencode arch=compute_20,code=sm_20
-gencode arch=compute_30,code=sm_30
-gencode arch=compute_35,code=sm_35
-gencode arch=compute_50,code=sm_50
-gencode arch=compute_60,code=sm_60)
# ---------------- User Config ----------------
# ---------------- Do Not Edit Following Items ----------------
# __----~~~~~~~~~~~------___
# . . ~~//====...... __--~ ~~
# -. \_|// |||\\ ~~~~~~::::... /~
# ___-==_ _-~o~ \/ ||| \\ _/~~-
# __---~~~.==~||\=_ -_--~/_-~|- |\\ \\ _/~
# _-~~ .=~ | \\-_ '-~7 /- / || \ /
# .~ .~ | \\ -_ / /- / || \ /
# / ____ / | \\ ~-_/ /|- _/ .|| \ /
# |~~ ~~|--~~~~--_ \ ~==-/ | \~--===~~ .\
# ' ~-| /| |-~\~~ __--~~
# |-~~-_/ | | ~\_ _-~ /\
# / \ \__ \/~ \__
# _--~ _/ | .-~~____--~-/ ~~==.
# ((->/~ '.|||' -_| ~~-/ , . _||
# -_ ~\ ~~---l__i__i__i--~~_/
# _-~-__ ~) \--______________--~~
# //.-~~~-~_--~- |-------~~~~~~~~
# //.-~~~--\
#
# ---------------- If You Are Not Good At CMake ----------------
# ---[ Dependencies
if (WITH_CUDA)
FIND_PACKAGE(CUDA REQUIRED)
endif()
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
message(STATUS "C++11 support has been enabled by default.")
# ---[ Config types
set(CMAKE_BUILD_TYPE Release CACHE STRING "set build type to release")
set(CMAKE_CONFIGURATION_TYPES Release CACHE STRING "set build type to release" FORCE)
# ---[ Includes
set(INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include)
include_directories(${INCLUDE_DIR})
include_directories(${3RDPARTY_DIR}/include)
include_directories(${3RDPARTY_DIR}/include/mpi)
include_directories(${CUDA_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/src)
include_directories(${NUMPY_DIR}/core/include)
include_directories(${ANACONDA_DIR}/include/python2.7)
include_directories(${PYTHON_DIR})
include_directories(${ANACONDA_DIR}/include)
# ---[ libs
set(3RDPARTY_LIBS ${3RDPARTY_DIR}/lib)
link_directories(${3RDPARTY_LIBS})
link_directories(/usr/local/cuda/lib64)
# ---[ Install
set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix" FORCE)
# ---[ defines
if (WITH_CUDA)
ADD_DEFINITIONS(-DWITH_CUDA)
message(STATUS "Use CUDA [Optional]")
endif()
if (WITH_CUDNN)
ADD_DEFINITIONS(-DWITH_CUDNN)
message(STATUS "Use CUDNN [Optional]")
endif()
if (WITH_BLAS)
ADD_DEFINITIONS(-DWITH_BLAS)
message(STATUS "Use BLAS [Optional]")
else()
message(STATUS "Unuse BLAS [Optional]"
"\n -- > GEMM/GEMV is disabled"
"\n -- > prefer not to run as CPU Mode")
endif()
if (WITH_SSE)
ADD_DEFINITIONS(-DWITH_SSE)
message(STATUS "Use SSE [Optional]")
if(UNIX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1")
endif()
endif()
if (WITH_MPI)
ADD_DEFINITIONS(-DWITH_MPI)
message(STATUS "Use MPI [Optional]")
endif()
if (WITH_MPI_CUDA)
ADD_DEFINITIONS(-DWITH_CUDA_AWARE)
message(STATUS "Use MPI-CUDA [Optional]")
endif()
if (WITH_CUDA_FP16)
ADD_DEFINITIONS(-DWITH_CUDA_FP16)
message(STATUS "Use CUDA FP16 [Optional]")
endif()
# ---[ Flags
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}")
if(WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
endif()
if(UNIX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -O2 -m64 -fpermissive -std=c++11")
endif()
# ---[ Warnings
# ---[ Subdirectories
add_subdirectory(modules/python)
# ---[ Utils
file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/../lib)
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_COMMON_H_
#define DRAGON_CORE_COMMON_H_
#include <climits>
#include <memory>
#include <string>
#include <queue>
#include <stack>
#include <vector>
#include <set>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <mutex>
#include "core/types.h"
#include "protos/dragon.pb.h"
#include "utils/logging.h"
namespace dragon {
using std::string;
using std::queue;
using std::stack;
using std::vector;
using std::pair;
using std::set;
using std::map;
using std::mutex;
using std::unique_ptr;
using std::shared_ptr;
template <typename Key, typename Value>
using Map = std::unordered_map<Key, Value>;
template <typename Value>
using Set = std::unordered_set<Value>;
#define CONCATENATE_IMPL(s1, s2) s1##s2
#define CONCATENATE(s1, s2) CONCATENATE_IMPL(s1,s2)
#define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __LINE__)
#define NOT_IMPLEMENTED LOG(FATAL) << "this module is not implemented"
} // namespace dragon
#endif // DRAGON_CORE_COMMON_H_
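// --------------------------------------------------------
// A minimal, self-contained sketch of the macro trio above:
// the two-step CONCATENATE forces __LINE__ to expand before
// token pasting, so ANONYMOUS_VARIABLE yields a unique name
// per line. All names below are hypothetical.
// --------------------------------------------------------
#include <iostream>
#define SKETCH_CONCAT_IMPL(s1, s2) s1##s2
#define SKETCH_CONCAT(s1, s2) SKETCH_CONCAT_IMPL(s1, s2)
#define SKETCH_ANONYMOUS(str) SKETCH_CONCAT(str, __LINE__)
static int SKETCH_ANONYMOUS(reg_) = 1;  // expands to, e.g., "reg_11"
static int SKETCH_ANONYMOUS(reg_) = 2;  // a different line gives a different name
int main() { std::cout << "two distinct statics declared\n"; return 0; }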
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_CONTEXT_H_
#define DRAGON_CORE_CONTEXT_H_
#include <random>
#include <ctime>
#include "common.h"
#include "utils/logging.h"
#ifdef WITH_CUDA
#include "utils/cuda_device.h"
#endif
namespace dragon {
class CPUObject{
public:
unique_ptr<std::mt19937> rand_generator;
};
class CPUContext{
public:
CPUContext(): random_seed_(3) { generator(); }
CPUContext(unsigned int random_seed): random_seed_(random_seed) { generator(); }
CPUContext(const DeviceOption& option): random_seed_(option.has_random_seed() ?
option.random_seed() : 3) { generator(); }
virtual ~CPUContext() {}
inline void SwitchToDevice() {}
inline void FinishDeviceCompution() { return; }
inline static void* New(size_t nbytes) {
void* data;
#ifdef WITH_CUDA_HOST_MEM
CUDA_CHECK(cudaMallocHost(&data, nbytes));
#else
data = malloc(nbytes);
#endif
CHECK(data) << "malloc mem: " << nbytes << " bytes failed.";
return data;
}
inline static void Memset(size_t nbytes, void* ptr) { memset(ptr, 0, nbytes); }
template<class DstContext, class SrcContext>
inline static void Memcpy(size_t nbytes, void* dst, const void* src) { memcpy(dst, src, nbytes); }
inline static void Delete(void* data) { free(data); }
template<typename T, class DstContext, class SrcContext>
inline static void Copy(int n, T* dst, const T* src){
if (dst == src) return;
// only the basic types(e.g. int/float) can memcpy correctly
if (std::is_fundamental<T>::value)
Memcpy<DstContext, SrcContext>(n * sizeof(T), (void*)dst, (const void*)src);
else for (int i = 0; i < n; i++) dst[i] = src[i];
}
inline std::mt19937* generator() {
auto& generator = cpu_object_.rand_generator;
if (!generator.get())
generator.reset(new std::mt19937(random_seed_));
return generator.get();
}
static CPUObject cpu_object_;
private:
unsigned int random_seed_;
};
static inline std::mt19937* rand_generator() {
return CPUContext::cpu_object_.rand_generator.get();
}
} // namespace dragon
#endif // DRAGON_CORE_CONTEXT_H_
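// --------------------------------------------------------
// A minimal sketch of the CPUContext allocation contract,
// assuming only this header: New/Memset/Copy/Delete are
// static, and Copy falls back to element-wise assignment
// for non-fundamental types.
// --------------------------------------------------------
#include "core/context.h"
using namespace dragon;
void CPUContextSketch() {
  float* buf = (float*)CPUContext::New(4 * sizeof(float));
  CPUContext::Memset(4 * sizeof(float), buf);              // zero-fill
  float src[4] = {1, 2, 3, 4};
  CPUContext::Copy<float, CPUContext, CPUContext>(4, buf, src);
  CPUContext::Delete(buf);
}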
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_CONTEXT_CUDA_H_
#define DRAGON_CORE_CONTEXT_CUDA_H_
#include "common.h"
#include "context.h"
#include "utils/cuda_device.h"
#include "utils/cudnn_device.h"
namespace dragon {
#ifdef WITH_CUDA
#define MAX_GPUS 8
/**************************************************************************
 * The cuXXX libraries wrap "Context" as "Handle".
 * It is well known that each "Context" binds to certain "Devices" (as in
 * OpenCL), so we must create a separate handle for each device, or the
 * computations will all be dispatched to the same GPU.
 * Read more: http://docs.nvidia.com/cuda/cublas/, section 2.1.2
 * Also, a "Handle" is thread-safe, so it seems unnecessary to create
 * separate handles for different threads.
 *************************************************************************/
class CUDAObject{
public:
CUDAObject(): cur_gpu(0) {
for (int i = 0; i < MAX_GPUS; i++) {
cublas_handle[i] = nullptr;
curand_generator[i] = nullptr;
#ifdef WITH_CUDNN
cudnn_handle[i] = nullptr;
#endif
}
}
~CUDAObject() {
for (int i = 0; i < MAX_GPUS; i++) {
if (cublas_handle[i]) cublasDestroy_v2(cublas_handle[i]);
if (curand_generator[i]) curandDestroyGenerator(curand_generator[i]);
#ifdef WITH_CUDNN
if (cudnn_handle[i]) cudnnDestroy(cudnn_handle[i]);
#endif
}
}
int cur_gpu;
cublasHandle_t cublas_handle[MAX_GPUS];
curandGenerator_t curand_generator[MAX_GPUS];
#ifdef WITH_CUDNN
cudnnHandle_t cudnn_handle[MAX_GPUS];
#endif
};
class CUDAContext {
public:
CUDAContext(const DeviceOption& option)
: gpu_id_(option.gpu_id()),
random_seed_(option.has_random_seed() ? option.random_seed() : 3) {
CPUContext context(option);
CHECK_EQ(option.device_type(), CUDA);
cublas_handle();
curand_generator();
#ifdef WITH_CUDNN
cudnn_handle();
#endif
}
CUDAContext(const int gpu_id = 0)
: gpu_id_(gpu_id), random_seed_(3) {
CPUContext context;
cublas_handle();
curand_generator();
#ifdef WITH_CUDNN
cudnn_handle();
#endif
}
void SwitchToDevice() {
CUDA_CHECK(cudaSetDevice(gpu_id_));
cuda_object_.cur_gpu = gpu_id_;
}
void FinishDeviceCompution() {
cudaStreamSynchronize(cudaStreamDefault);
cudaError_t error = cudaGetLastError();
CHECK_EQ(error, cudaSuccess)
<< "cuda error: " << cudaGetErrorString(error);
}
inline static void* New(size_t nbytes) {
void* data;
CUDA_CHECK(cudaMalloc(&data, nbytes));
CHECK(data) << "malloc cuda mem: " << nbytes << " bytes failed.";
return data;
}
inline static void Memset(size_t nbytes, void* ptr) { cudaMemset(ptr, 0, nbytes); }
template<class DstContext, class SrcContext>
inline static void Memcpy(size_t nbytes, void* dst, const void* src) {
CUDA_CHECK(cudaMemcpy(dst, src, nbytes, cudaMemcpyDefault));
}
inline static void MemcpyAsync(size_t nbytes, void* dst, const void* src) {
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDefault, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaStreamDestroy(stream));
}
inline static void Delete(void* data) { cudaFree(data); }
template<typename T, class DstContext, class SrcContext>
static void Copy(int n, T* dst, const T* src){
if (dst == src) return;
Memcpy<SrcContext, DstContext>(n * sizeof(T), (void*)dst, (const void*)src);
}
cublasHandle_t& cublas_handle() {
auto& handle = cuda_object_.cublas_handle[gpu_id_];
if (handle) {
return handle;
} else {
DeviceGuard guard(gpu_id_);
CUBLAS_CHECK(cublasCreate_v2(&handle));
return handle;
}
}
curandGenerator_t& curand_generator() {
auto& generator = cuda_object_.curand_generator[gpu_id_];
if (generator) {
return generator;
} else {
DeviceGuard guard(gpu_id_);
CURAND_CHECK(curandCreateGenerator(&generator, CURAND_RNG_PSEUDO_DEFAULT));
CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(generator, random_seed_));
return generator;
}
}
#ifdef WITH_CUDNN
cudnnHandle_t cudnn_handle(){
auto& handle = cuda_object_.cudnn_handle[gpu_id_];
if (handle) {
return handle;
} else{
DeviceGuard guard(gpu_id_);
CUDNN_CHECK(cudnnCreate(&handle));
return handle;
}
}
#endif
static CUDAObject cuda_object_;
private:
int gpu_id_, random_seed_;
};
static inline cublasHandle_t& cublas_handle() {
int cur_gpu = CUDAContext::cuda_object_.cur_gpu;
CHECK(CUDAContext::cuda_object_.cublas_handle[cur_gpu] != nullptr);
return CUDAContext::cuda_object_.cublas_handle[cur_gpu];
}
static inline curandGenerator_t& curand_generator() {
int cur_gpu = CUDAContext::cuda_object_.cur_gpu;
CHECK(CUDAContext::cuda_object_.curand_generator[cur_gpu] != nullptr);
return CUDAContext::cuda_object_.curand_generator[cur_gpu];
}
#ifdef WITH_CUDNN
static inline cudnnHandle_t& cudnn_handle() {
int cur_gpu = CUDAContext::cuda_object_.cur_gpu;
CHECK(CUDAContext::cuda_object_.cudnn_handle[cur_gpu] != nullptr);
return CUDAContext::cuda_object_.cudnn_handle[cur_gpu];
}
#endif
#else // WITH_CUDA
class CUDAContext{
public:
CUDAContext(const DeviceOption& option) { LOG(FATAL) << "CUDA is not compiled."; }
CUDAContext(const int gpu_id = 0) { LOG(FATAL) << "CUDA is not compiled."; }
template<class DstContext, class SrcContext>
static void Memcpy(size_t nbytes, void* dst, const void* src) {
LOG(FATAL) << "CUDA is not compilied.";
}
};
#endif // WITH_CUDA
} // namespace dragon
#endif // DRAGON_CORE_CONTEXT_CUDA_H_
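// --------------------------------------------------------
// A minimal sketch, assuming a WITH_CUDA build, one visible
// GPU, and the static cuda_object_ defined in a .cc file:
// a context lazily creates per-device handles, and
// SwitchToDevice() publishes the device through cur_gpu so
// the free functions above can find the right handle.
// --------------------------------------------------------
#include "core/context_cuda.h"
using namespace dragon;
void CUDAContextSketch() {
  CUDAContext ctx(0);             // creates handles for GPU 0
  ctx.SwitchToDevice();           // cudaSetDevice(0); cur_gpu = 0
  void* d = CUDAContext::New(16);
  CUDAContext::Memset(16, d);
  ctx.FinishDeviceCompution();    // sync + error check
  CUDAContext::Delete(d);
}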
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_GRAPH_H_
#define DRAGON_CORE_GRAPH_H_
#include "core/common.h"
#include "core/operator.h"
namespace dragon {
class GraphBase {
public:
struct Node {
vector<string> parents;
vector<string> childs;
int op_idx = -1;
string op_type;
};
GraphBase(const GraphDef& graph_def, Workspace* ws);
virtual bool Create(const GraphDef& graph_def, Workspace* ws) = 0;
virtual bool Run(const string& include, const string& exclude) = 0;
protected:
string name_, phase_;
Map<string, Argument> args_;
Workspace* ws_;
};
class Graph final : public GraphBase {
public:
Graph(const GraphDef& graph_def, Workspace* ws);
bool Create(const GraphDef& graph_def, Workspace* ws) override;
bool Run(const string& include, const string& exclude) override;
GraphDef Prune(const GraphDef& graph_def);
GraphDef Share(const GraphDef& graph_def);
GraphDef MakeUpdate(const GraphDef& graph_def);
inline Workspace* ws() const { return ws_; }
private:
void ForwardShareDyeing(string u, string ancestor);
void ForwardPruneDyeing(string u, string leaf, vector<string> path);
void BackwardPruneDyeing(string v);
vector<OperatorBase*> ops_;
Map<string, Node> dag_;
Map<string, bool> visited_, colored_;
Map<string, string> renamed_;
Set<string> targets_;
};
GraphBase* NewGraph(const GraphDef& graph_def, Workspace* ws);
DECLARE_REGISTRY(GraphRegistry, GraphBase, const GraphDef&, Workspace*);
} // namespace dragon
#endif // DRAGON_CORE_GRAPH_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_GRAPH_GRADIENT_H_
#define DRAGON_CORE_GRAPH_GRADIENT_H_
#include "core/common.h"
namespace dragon {
typedef pair<bool, vector<pair<string, int> > > CheckTuple;
class GraphGradientMaker {
public:
GraphGradientMaker(const GraphDef& forward_def,
const vector<string>& targets)
: cur_op_idx_(0),
forward_def_(forward_def) {
for (auto& target : targets) targets_set_.insert(target);
}
GraphDef Make();
inline void SetTerms(const Map<string, string>& terms) { terms_ = terms; }
inline void SetOperatorPrefix(const string& prefix) { op_prefix_ = prefix; }
inline void SetOperatorSuffix(const string& suffix) { op_suffix_ = suffix; }
inline void AddExternalGrad(const string& name) { external_grads_.insert(name); }
private:
CheckTuple CheckMissingGrad(OperatorDef* forward_op);
string GetOperatorName();
GraphDef forward_def_, new_def_;
Map<string, string> terms_, inputs_to_grads_;
Set<string> targets_set_, blacklist_set_, external_grads_;
string op_prefix_, op_suffix_;
int cur_op_idx_;
};
} // namespace dragon
#endif // DRAGON_CORE_GRAPH_GRADIENT_H_
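// --------------------------------------------------------
// A minimal sketch, assuming a GraphDef "forward_def" whose
// ops compute a scalar "loss": the maker derives a backward
// GraphDef for the given targets. All names are hypothetical.
// --------------------------------------------------------
#include "core/graph_gradient.h"
using namespace dragon;
GraphDef MakeBackward(const GraphDef& forward_def) {
  GraphGradientMaker maker(forward_def, vector<string>({"loss"}));
  maker.SetOperatorPrefix("grad/");     // name the emitted gradient ops
  maker.AddExternalGrad("loss_grad");   // seed gradient fed from outside
  return maker.Make();
}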
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_MIXEDMEM_H_
#define DRAGON_CORE_MIXEDMEM_H_
#include "typeid.h"
#include "context.h"
#include "context_cuda.h"
namespace dragon {
class MixedMemory{
public:
enum State { UNINITIALIZED, STATE_AT_CPU, STATE_AT_CUDA, SWITCHED, SYNCED };
MixedMemory()
: state_(UNINITIALIZED),
cpu_ptr_(nullptr), cuda_ptr_(nullptr),
nbytes_(0) {}
MixedMemory(const TypeMeta& meta, const size_t nbytes)
: state_(UNINITIALIZED), meta_(meta),
cpu_ptr_(nullptr), cuda_ptr_(nullptr),
nbytes_(nbytes) {}
~MixedMemory();
const void* cpu_data();
const void* cuda_data();
void* mutable_cpu_data();
void* mutable_cuda_data();
#ifdef WITH_CUDA
void async_cuda_data(const cudaStream_t& stream);
#endif
void SwitchToDevice();
inline size_t nbytes() const { return nbytes_; }
inline void* cpu_ptr() { state_ = STATE_AT_CPU; return cpu_ptr_; }
inline void* cuda_ptr() { state_ = STATE_AT_CUDA; return cuda_ptr_; }
inline State state() { return state_; }
private:
void ToCUDA();
void ToCPU();
void* cpu_ptr_, *cuda_ptr_;
State state_;
size_t nbytes_;
TypeMeta meta_;
};
} // namespace dragon
#endif // DRAGON_CORE_MIXEDMEM_H_
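// --------------------------------------------------------
// A minimal sketch of the state machine, assuming the usual
// Caffe-style semantics implied by the accessors: const
// accessors sync data to the requested device, and mutable
// accessors additionally mark the data dirty there.
// --------------------------------------------------------
#include "core/mixedmem.h"
using namespace dragon;
void MixedMemorySketch() {
  MixedMemory mem(TypeMeta::Make<float>(), 4 * sizeof(float));
  float* host = (float*)mem.mutable_cpu_data();       // now STATE_AT_CPU
  host[0] = 1.0f;
#ifdef WITH_CUDA
  const float* dev = (const float*)mem.cuda_data();   // copies to GPU
  (void)dev;
#endif
}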
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_OPERATOR_H_
#define DRAGON_CORE_OPERATOR_H_
#include "core/registry.h"
#include "core/context.h"
#include "core/tensor.h"
#include "core/operator_gradient.h"
#include "core/operator_schema.h"
#include "utils/cast.h"
#ifdef WITH_MPI
#include <mpi/mpi.h>
#endif
namespace dragon {
class Workspace;
class OperatorBase{
public:
OperatorBase(const OperatorDef& op_def, Workspace* ws);
inline Tensor& input(int idx) {
CHECK_LT(idx, (int)inputs_.size());
CHECK_GE(idx, -(int)inputs_.size());
if (idx >= 0) return *inputs_[idx];
else return *inputs_[idx + inputs_.size()];
}
inline Tensor* output(int idx) {
CHECK_LT(idx, (int)outputs_.size());
CHECK_GE(idx, -(int)outputs_.size());
if (idx >= 0) return outputs_[idx];
else return outputs_[idx + outputs_.size()];
}
inline size_t InputSize() { return inputs_.size(); }
inline size_t OutputSize() { return outputs_.size(); }
inline void SwitchToPhase(const string& phase) { this->phase_ = phase; }
virtual void Run() { NOT_IMPLEMENTED; }
inline const string& name() const { return op_def_.name(); }
inline const string& type() const { return op_def_.type(); }
inline const string& phase() const { return phase_; }
inline Workspace* ws() const { return ws_; }
template <typename T>
T GetSingleArg(const string& name, const T& default_value);
template <typename T>
vector<T> GetRepeatedArg(const string& name);
inline const Map<std::string, const Argument*>& args() { return args_; }
inline const Argument& arg(const string& name) { return *(args_[name]); }
inline const OperatorDef& op_def() const { return op_def_; }
inline const string debug_string() const { return op_def_.DebugString(); }
protected:
string phase_;
Map<std::string, const Argument*> args_;
vector<Tensor*> inputs_, outputs_;
OperatorDef op_def_;
Workspace* ws_;
};
template <class Context>
class Operator : public OperatorBase {
public:
Operator(const OperatorDef& op_def, Workspace* ws)
: OperatorBase(op_def, ws), ctx_(op_def.device_option()) {
allow_run_ = true;
allow_run_ &= _MPICheck();
allow_run_ &= (!(OutputSize() == 1 && output(0)->name() == "ignore"));
}
virtual void Run() final {
if (!allow_run_) return;
ctx_.SwitchToDevice();
if (!op_def_.debug_mode()) ShareBeforeRun();
MemorySwitch();
RunOnDevice();
if (!op_def_.debug_mode()) ClearAfterRun();
ctx_.FinishDeviceCompution();
}
void MemorySwitch() {
for (int i = 0; i < InputSize(); i++)
if (input(i).name() != "ignore")
input(i).SwitchToDevice();
for (int i = 0; i < OutputSize(); i++)
if (output(i)->name() != "ignore")
output(i)->SwitchToDevice();
}
virtual void ShareBeforeRun() { /*** share tensors here if necessary ***/ }
virtual void RunOnDevice() = 0;
virtual void ClearAfterRun() { /*** clear tensors here if necessary ***/ }
inline Context& ctx() { return ctx_; }
inline string anchor() { return GetSingleArg("anchor", name()); }
inline bool allow_run() { return allow_run_; }
protected:
Context ctx_;
bool allow_run_;
private:
bool _MPICheck() {
#ifndef WITH_MPI
return true;
#else
vector<int> allow_ranks = Operator::GetRepeatedArg<int>("mpi_rank");
if (allow_ranks.empty()) return true;
int cur_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &cur_rank);
for (auto mpi_rank : allow_ranks)
if (cur_rank == mpi_rank) return true;
return false;
#endif
}
};
OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
#define USE_SIMPLE_CTOR_DTOR(name) \
name(const OperatorDef& op_def, Workspace* ws) \
: Operator<Context>(op_def, ws) {} \
virtual ~name() {}
DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
DECLARE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
#define TENSOR_FILL(tensor, shape) \
if (tensor.count() == 0) { \
CHECK(ws()->GetFiller(tensor.name())) \
<< "Tensor(" << tensor.name() << ") is empty. \n" \
<< "may be specify a filler for it ?"; \
tensor.Reshape(shape); \
unique_ptr< Filler<T, Context> > filler( \
CreateFiller<T, Context>(*ws()->GetFiller(tensor.name()))); \
filler->Fill(&tensor); \
} else { \
TIndex count = 1; \
for(int i = 0; i < shape.size(); i++) count *= shape[i]; \
CHECK_EQ(count, tensor.count()) \
<< "\nmodel request " << "Tensor(" << tensor.name() << ")'s " \
<< "size is " << count << "\n" \
<< "but now is " << tensor.count() << "\n" \
<< "may be feed the incorrect Tensor before ?"; \
tensor.Reshape(shape); \
}
#define INIT_MULTIPLIER(ptr_tensor, size) { \
ptr_tensor = ws()->CreateTensor("_t_multiplier"); \
if (size > ptr_tensor->count()) { \
ptr_tensor->Reshape(vector<TIndex>(1, size)); \
math::Set<T, Context>(size, dragon_cast<T, float>(1.0f), \
ptr_tensor->template mutable_data<T, Context>()); \
} \
}
#define INSTANTIATE_OPERATOR(name, context) \
template class name##Op<context>;
#define INSTANTIATE_CUDNN_OPERATOR(name) \
template class CuDNN##name##Op<CUDAContext>;
#define REGISTER_CPU_OPERATOR(name, ...) \
REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_CUDA_OPERATOR(name, ...) \
REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_CUDNN_OPERATOR(name, ...) \
REGISTER_CLASS(CUDNNOperatorRegistry, name, __VA_ARGS__)
#define DEPLOY_CPU(name) \
REGISTER_CPU_OPERATOR(name, name##Op<CPUContext>); \
INSTANTIATE_OPERATOR(name, CPUContext);
#define DEPLOY_CUDA(name) \
REGISTER_CUDA_OPERATOR(name, name##Op<CUDAContext>); \
INSTANTIATE_OPERATOR(name, CUDAContext);
#define DEPLOY_CPU_CUDA(name) \
REGISTER_CUDA_OPERATOR(name, name##Op<CPUContext>); \
INSTANTIATE_OPERATOR(name, CPUContext);
#define DEPLOY_CUDNN(name) \
REGISTER_CUDNN_OPERATOR(name, CuDNN##name##Op<CUDAContext>); \
INSTANTIATE_CUDNN_OPERATOR(name);
} // namespace dragon
#endif // DRAGON_CORE_OPERATOR_H_
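// --------------------------------------------------------
// A minimal sketch, using only the declarations above, of
// defining and deploying a CPU operator. "MyScale" and its
// "scale" argument are hypothetical names; registration
// assumes the registries are defined elsewhere in the repo.
// --------------------------------------------------------
#include "core/operator.h"
namespace dragon {
template <class Context>
class MyScaleOp final : public Operator<Context> {
 public:
  MyScaleOp(const OperatorDef& op_def, Workspace* ws)
      : Operator<Context>(op_def, ws),
        scale_(OperatorBase::GetSingleArg<float>("scale", 1.0f)) {}
  void RunOnDevice() override {
    Tensor& X = this->input(0);
    Tensor* Y = this->output(0);
    Y->ReshapeLike(X);
    const float* Xdata = X.data<float, Context>();
    float* Ydata = Y->mutable_data<float, Context>();
    // CPU-only loop; a CUDA deployment would need a kernel instead
    for (int i = 0; i < X.count(); i++) Ydata[i] = scale_ * Xdata[i];
  }
 protected:
  float scale_;
};
DEPLOY_CPU(MyScale);
OPERATOR_SCHEMA(MyScale).NumInputs(1).NumOutputs(1);
} // namespace dragon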
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_OPERATOR_GRADIENT_H_
#define DRAGON_CORE_OPERATOR_GRADIENT_H_
#include "core/common.h"
#include "core/registry.h"
#include "core/operator.h"
#include "utils/proto_utils.h"
namespace dragon {
struct Gradient {
vector<OperatorDef> ops;
vector<string> g_inputs;
vector<float> defaults;
Gradient(const vector<OperatorDef>& ops,
const vector<string>& g_inputs,
const vector<float>& defaults)
: ops(ops), g_inputs(g_inputs), defaults(defaults) {}
};
class GradientMakerBase {
public:
GradientMakerBase(const OperatorDef& def,
const vector<string>& g_outputs)
: def(def), g_outputs_(g_outputs), g_inputs_(def.input_size()) {}
virtual ~GradientMakerBase() {}
inline virtual bool CopyDeviceOption() const { return true; }
inline virtual bool CopyEngine() const { return true; }
inline virtual bool CopyArguments() const { return true; }
inline virtual Gradient Make() {
vector<OperatorDef> new_defs = MakeDefs();
Argument anchor;
anchor.set_name("anchor"); anchor.set_s(def.name());
for (int i = 0; i < new_defs.size(); i++)
new_defs[i].add_arg()->CopyFrom(anchor);
return Gradient(new_defs, g_inputs_, DefaultValues());
};
virtual inline vector<OperatorDef> MakeDefs() {
NOT_IMPLEMENTED;
return vector<OperatorDef>();
}
virtual inline vector<float> DefaultValues() {
return vector<float>(g_outputs_.size(), 1.0);
}
template <class... Args>
inline static vector<OperatorDef> SingleDef(const Args& ... args) {
return vector<OperatorDef> { MakeOperatorDef(args...) };
}
inline string I(const int i) { return def.input(i); }
inline string O(const int i) { return def.output(i); }
inline string GI(const int i) {
if (i >= g_inputs_.size()) return "ignore";
g_inputs_[i] = def.input(i) + "_grad";
return g_inputs_[i];
}
inline string GO(const int i) { return g_outputs_[i]; }
protected:
const OperatorDef& def;
vector<string> g_inputs_;
const vector<string>& g_outputs_;
};
// implemented in operator.cpp
Gradient MakeGradientForOp(const OperatorDef& op_def, const vector<string>& g_outputs);
#define GRADIENT_MAKER_CTOR(name) \
name(const OperatorDef& def, const vector<string>& g_output) \
: GradientMakerBase(def, g_output) {}
class NoGradient : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(NoGradient);
vector<OperatorDef> MakeDefs() override {
return vector<OperatorDef>();
}
};
DECLARE_REGISTRY(GradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<string>&);
DECLARE_REGISTRY(NoGradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<string>&);
// these registries are defined in operator.cpp
#define REGISTER_GRADIENT(name, ...) \
REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
#define NO_GRADIENT(name) \
REGISTER_GRADIENT(name, NoGradient); \
REGISTER_CLASS(NoGradientRegistry, name, NoGradient)
} // namespace dragon
#endif // DRAGON_CORE_OPERATOR_GRADIENT_H_
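// --------------------------------------------------------
// A minimal sketch of a gradient maker for the hypothetical
// "MyScale" op above, assuming MakeOperatorDef(type, name,
// inputs, outputs) as in similar frameworks: one gradient op
// maps the output-gradient GO(0) to the input-gradient GI(0).
// --------------------------------------------------------
#include "core/operator_gradient.h"
namespace dragon {
class GetMyScaleGradient : public GradientMakerBase {
 public:
  GRADIENT_MAKER_CTOR(GetMyScaleGradient);
  vector<OperatorDef> MakeDefs() override {
    return SingleDef("MyScaleGradient", "",
                     vector<string>({GO(0)}),    // dL/dY
                     vector<string>({GI(0)}));   // dL/dX
  }
};
REGISTER_GRADIENT(MyScale, GetMyScaleGradient);
} // namespace dragon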
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_OPERATOR_SCHEMA_H_
#define DRAGON_CORE_OPERATOR_SCHEMA_H_
#include <functional>
#include <limits>
#include "common.h"
namespace dragon {
class OpSchema{
public:
OpSchema()
: op_type_("unknown"), file_("unknown"), line_(0) { Init(); }
OpSchema(const string& op_type, const string& file, const int line)
: op_type_(op_type), file_(file), line_(line) { Init(); }
bool Verify(const OperatorDef& def) const;
OpSchema& Inplace(set<pair<int, int> > inplace);
std::function<bool(int, int)> CheckInplace;
inline bool AllowInplace() const { return allow_inplace_; }
OpSchema& NumInputs(int n);
OpSchema& NumInputs(int min_num, int max_num);
OpSchema& NumOutputs(int n);
OpSchema& NumOutputs(int min_num, int max_num);
private:
void Init() {
min_input_ = min_output_= 0;
max_input_ = max_output_ = std::numeric_limits<int>::max();
CheckInplace = [](int, int) { return false; };
allow_inplace_ = false;
}
string op_type_, file_;
int line_, min_input_, max_input_;
int min_output_, max_output_;
bool allow_inplace_;
};
class OpSchemaRegistry {
public:
static OpSchema& NewSchema(const string& op_type, const string& file, const int line) {
auto& m = schema_map();
CHECK(!m.count(op_type))
<< "\nOpSchema(" << op_type << ") has registered before."
<< "\nat file: " << file
<< "\n line: " << line;
m.emplace(std::make_pair(op_type, OpSchema(op_type, file, line)));
return m[op_type];
}
static const OpSchema* Schema(const string& op_type) {
auto& m = schema_map();
if (m.count(op_type)) return &m[op_type];
else LOG(FATAL) << "OpSchema(" << op_type << ") has not registered yet.";
return nullptr;
}
private:
static Map<string, OpSchema>& schema_map() {
static Map<string, OpSchema> schema_map_;
return schema_map_;
}
};
#define OPERATOR_SCHEMA(name) \
static OpSchema& ANONYMOUS_VARIABLE(name) = \
OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__)
} // namespace dragon
#endif // DRAGON_CORE_OPERATOR_SCHEMA_H_
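// --------------------------------------------------------
// A minimal sketch of a schema declaration for a
// hypothetical "MyRelu" op: one input, one output, and
// input 0 may share its buffer with output 0.
// --------------------------------------------------------
#include "core/operator_schema.h"
namespace dragon {
OPERATOR_SCHEMA(MyRelu)
    .NumInputs(1).NumOutputs(1)
    .Inplace({{0, 0}});
} // namespace dragon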
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_REGISTRY_H_
#define DRAGON_CORE_REGISTRY_H_
#include <functional>
#include "core/common.h"
#include "utils/logging.h"
namespace dragon {
template <class SrcType, class ObjType, class... Args>
class Registry {
public:
typedef std::function<ObjType*(Args ...)> Creator;
void Register(const SrcType& key, Creator creator) {
CHECK(!registry_.count(key)) << "Key(" << key << ") has already registered.";
registry_[key] = creator;
}
ObjType* Create(const SrcType& key, Args ... args) {
CHECK(registry_.count(key)) << "Key(" << key << ") has not registered yet.";
return registry_[key](args...);
}
bool Has(const SrcType& key) { return (registry_.count(key)) != 0; }
vector<SrcType> keys() {
vector<SrcType> ret;
for (const auto& it : registry_) ret.push_back(it.first);
return ret;
}
private:
Map<SrcType, Creator> registry_;
};
template <class SrcType, class ObjType, class... Args>
class Registerer {
public:
Registerer(const SrcType& key, Registry<SrcType, ObjType, Args...>* registry,
typename Registry<SrcType, ObjType, Args...>::Creator creator, const string& help_msg = "") {
registry->Register(key, creator);
}
template <class DerivedType>
static ObjType* defaultCreator(Args ... args) {return new DerivedType(args...);}
};
// use in *.h files
#define DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjType,...) \
dragon::Registry<SrcType, ObjType,##__VA_ARGS__>* RegistryName(); \
typedef dragon::Registerer<SrcType,ObjType,##__VA_ARGS__> Registerer##RegistryName;
// use in *.cc files
#define DEFINE_TYPED_REGISTRY(RegistryName,SrcType, ObjType,...) \
Registry<SrcType,ObjType,##__VA_ARGS__>* RegistryName() { \
static Registry<SrcType,ObjType,##__VA_ARGS__>* registry = \
new Registry<SrcType,ObjType,##__VA_ARGS__>(); \
return registry; \
}
#define DECLARE_REGISTRY(RegistryName, ObjType, ...) \
DECLARE_TYPED_REGISTRY(RegistryName, string, ObjType, ##__VA_ARGS__)
#define DEFINE_REGISTRY(RegistryName, ObjType, ...) \
DEFINE_TYPED_REGISTRY(RegistryName, string, ObjType, ##__VA_ARGS__)
#define REGISTER_TYPED_CLASS(RegistryName, key, ...) \
static Registerer##RegistryName ANONYMOUS_VARIABLE(g_##RegistryName) ( \
key, RegistryName(), Registerer##RegistryName::defaultCreator<__VA_ARGS__>)
#define REGISTER_CLASS(RegistryName, key, ...) \
REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)
} // namespace dragon
#endif // DRAGON_CORE_REGISTRY_H_
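// --------------------------------------------------------
// A minimal sketch of the registry machinery on its own,
// with hypothetical types: a string-keyed factory that
// default-creates subclasses.
// --------------------------------------------------------
#include "core/registry.h"
namespace dragon {
class Animal { public: virtual ~Animal() {} };
class Cat final : public Animal {};
DECLARE_REGISTRY(AnimalRegistry, Animal);
DEFINE_REGISTRY(AnimalRegistry, Animal);
REGISTER_CLASS(AnimalRegistry, Cat, Cat);
Animal* MakeCat() { return AnimalRegistry()->Create("Cat"); }
} // namespace dragon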
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_TENSOR_H_
#define DRAGON_CORE_TENSOR_H_
#include <vector>
#include "core/common.h"
#include "core/typeid.h"
#include "core/mixedmem.h"
namespace dragon {
typedef int64_t TIndex;
typedef size_t TSize;
class Tensor {
public:
Tensor() {}
Tensor(const string& name) : name_(name) {}
void Reshape(const vector<TIndex>& dims) {
dims_ = dims;
TIndex new_size = 1;
for (auto d : dims_) {
CHECK_GT(d, 0);
new_size *= d;
}
if (size_ != new_size &&
capacity_ < TIndex(new_size * meta_.itemsize())) {
memory_.reset();
capacity_ = 0;
}
size_ = new_size;
}
void ReshapeLike(const Tensor& other) {
Reshape(other.dims_);
}
inline const string& name() const { return name_; }
inline TIndex axis(const TIndex i) const {
CHECK_GE(i, -(TIndex)ndim());
CHECK_LT(i, (TIndex)ndim());
if (i < 0) return i + ndim();
else return i;
}
inline TSize ndim() const { return dims_.size(); }
inline TIndex dim(const TIndex i) const{ return dims_[axis(i)];}
inline const vector<TIndex>& dims() const { return dims_; }
inline TSize nbytes() const { return size_ * meta_.itemsize(); }
inline TIndex count(const TIndex start, const TIndex end) const {
TIndex ret = 1;
for (TIndex i = start; i < end; i++) ret *= dim(i);
return ret;
}
inline TIndex count() const { return size_; }
inline TIndex count(const TIndex start) const { return count(start, ndim()); }
inline TIndex offset(const TIndex n, const TIndex c = 0,
const TIndex h = 0, const TIndex w = 0) {
CHECK_LE(n, dim(0));
CHECK_LE(c, dim(1));
CHECK_LE(h, dim(2));
CHECK_LE(w, dim(3));
return ((n * dim(1) + c) * dim(2) + h) * dim(3) + w;
}
inline TIndex offset(const vector<TIndex>& vec) {
CHECK_LE(vec.size(), ndim());
TIndex offset = 0;
for (int i = 0; i < ndim(); i++){
offset = offset * dim(i);
if (vec.size() > i) offset += vec[i];
}
return offset;
}
inline string dim_string() const {
std::stringstream ss;
ss << "(";
for (int i = 0; i < ndim() - 1; i++) ss << dim(i) << ",";
ss << dim(ndim() - 1) << ")";
return ss.str();
}
MixedMemory::State memory_state() const { return memory_->state(); }
MixedMemory* memory() const { return memory_.get(); }
void SwitchToDevice() { if(memory_) memory_->SwitchToDevice(); }
const TypeMeta& meta() const { return meta_; }
void SetMeta(const TypeMeta& meta) { meta_ = meta; }
template <typename T> inline bool IsType() { return meta_.Match<T>(); }
template <class Context>
const void* raw_data() const {
CHECK(memory_.get()) << "memory access before allocating.";
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>())
return memory_->cpu_data();
else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>())
return memory_->cuda_data();
else LOG(FATAL) << "Unknown memory type; only CPU or CUDA are supported.";
return nullptr;
}
template <typename T, class Context>
const T* data() const {
return static_cast<const T*>(raw_data<Context>());
}
template <class Context>
void active_data_ptr(void** data_ptr) {
if (!memory_) {
*data_ptr = nullptr;
} else {
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) {
*data_ptr = memory_->mutable_cpu_data();
} else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) {
*data_ptr = memory_->mutable_cuda_data();
}
}
}
template <class Context>
void* raw_mutable_data(const TypeMeta& meta){
void* data_ptr;
active_data_ptr<Context>(&data_ptr);
if (meta_ == meta && data_ptr) {
return data_ptr;
} else {
meta_ = meta; // copy-assign the meta
CHECK_GT(size_, 0); // must specify a valid size
memory_.reset(new MixedMemory(meta, size_* meta_.itemsize()));
// malloc
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>())
data_ptr = memory_->mutable_cpu_data();
else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>())
data_ptr = memory_->mutable_cuda_data();
// init each structured element if necessary
if (meta.ctor()) meta_.ctor()(data_ptr, size_);
}
capacity_ = size_ * meta_.itemsize();
return data_ptr;
}
template <class Context>
void* raw_mutable_data() {
CHECK_NE(meta_.id(), 0)
<< "\nTensor(" << name_ << "): unknown type, "
<< "or does not have a type.";
return raw_mutable_data<Context>(meta_);
}
template <typename T, class Context>
T* mutable_data() {
void* data_ptr;
active_data_ptr<Context>(&data_ptr);
if (data_ptr && meta_ == TypeMeta::Make<T>()) return static_cast<T*>(data_ptr);
return static_cast<T*>(raw_mutable_data<Context>(TypeMeta::Make<T>()));
}
void Share(const Tensor& other) {
CHECK_EQ(size_, other.size_);
memory_ = other.memory_;
meta_ = other.meta_;
capacity_ = other.capacity_;
}
void Replace(const Tensor& other) {
memory_ = other.memory_;
meta_ = other.meta_;
capacity_ = other.capacity_;
size_ = other.size_;
dims_ = other.dims_;
}
void Reset() {
size_ = capacity_ = 0;
meta_ = TypeMeta();
dims_.clear();
memory_.reset();
}
private:
vector<TIndex> dims_;
TIndex size_ = 0, capacity_ = 0;
TypeMeta meta_;
string name_;
shared_ptr<MixedMemory> memory_;
};
} // namespace dragon
#endif // DRAGON_CORE_TENSOR_H_
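// --------------------------------------------------------
// A minimal sketch of the Tensor lifecycle: Reshape only
// records dimensions; memory is allocated lazily on the
// first typed mutable access.
// --------------------------------------------------------
#include "core/tensor.h"
using namespace dragon;
void TensorSketch() {
  Tensor t("x");
  t.Reshape(vector<TIndex>({2, 3}));                  // size = 6, no memory yet
  float* data = t.mutable_data<float, CPUContext>();  // allocates 6 floats
  for (int i = 0; i < t.count(); i++) data[i] = 0.0f;
  // t.dim_string() == "(2,3)", t.nbytes() == 24
}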
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_TYPEID_H_
#define DRAGON_CORE_TYPEID_H_
#include <cstdlib>
#include <iostream>
#include <map>
namespace dragon {
typedef intptr_t TypeId;
class TypeMeta {
public:
template <typename T>
struct TypeRegister {
static TypeId id() {
static bool type_id_bit[1];
return (TypeId)(type_id_bit);
}
};
typedef void(*PlacementNew)(void*, size_t);
typedef void(*TypedCopy)(const void*, void*, size_t);
typedef void(*TypedDestructor)(void*, size_t);
TypeMeta()
: id_(0), itemsize_(0),
ctor_(nullptr), copy_(nullptr), dtor_(nullptr) {}
TypeMeta(const TypeMeta& src)
: id_(src.id_), itemsize_(src.itemsize_),
ctor_(src.ctor_), copy_(src.copy_), dtor_(src.dtor_) {}
TypeMeta& operator = (const TypeMeta& src) {
if (this == &src) return *this;
id_ = src.id_;
itemsize_ = src.itemsize_;
ctor_ = src.ctor_;
copy_ = src.copy_;
dtor_ = src.dtor_;
return *this;
}
bool operator == (const TypeMeta& other) const {
return (id_ == other.id_);
}
bool operator != (const TypeMeta& other) const {
return (id_ != other.id_);
}
const TypeId& id() const { return id_; }
const size_t& itemsize() const { return itemsize_; }
PlacementNew ctor() const { return ctor_; }
TypedCopy copy() const { return copy_; }
TypedDestructor dtor() const { return dtor_; }
template <typename T>
static TypeId Id() {
// return T's id, using an intptr_t as the hash key
return TypeRegister<T>::id();
}
template <typename T>
static size_t Itemsize() { return sizeof(T); }
template <typename T>
bool Match() const { return (id_ == Id<T>()); }
template <typename T>
static void Ctor(void* ptr, size_t n){
T* typed_ptr = static_cast<T*>(ptr);
for (unsigned int i = 0; i < n; i++) new(typed_ptr + i) T;
}
template <typename T>
static void Copy(const void* src, void* dst, size_t n){
const T* typed_src = static_cast<const T*>(src);
T* typed_dst = static_cast<T*>(dst);
for (unsigned int i = 0; i < n; i++) typed_dst[i] = typed_src[i];
}
template <typename T>
static void Dtor(void* ptr, size_t n){
T* typed_ptr = static_cast<T*>(ptr);
for (unsigned int i = 0; i < n; i++) typed_ptr[i].~T();
}
#define FundMeta std::enable_if<std::is_fundamental<T>::value,TypeMeta>::type
template <typename T>
static typename FundMeta Make() {
return TypeMeta(Id<T>(), Itemsize<T>(), nullptr, nullptr, nullptr);
}
#define StructMeta std::enable_if<!std::is_fundamental<T>::value && std::is_copy_assignable<T>::value, TypeMeta>::type
template<typename T>
static typename StructMeta Make() {
return TypeMeta(Id<T>(), Itemsize<T>(), Ctor<T>, Copy<T>, Dtor<T>);
}
private:
TypeMeta(TypeId id, size_t itemsize,
PlacementNew ctor, TypedCopy copy, TypedDestructor dtor)
: id_(id), itemsize_(itemsize),
ctor_(ctor), copy_(copy), dtor_(dtor) {}
private:
TypeId id_;
size_t itemsize_;
PlacementNew ctor_;
TypedCopy copy_;
TypedDestructor dtor_;
};
} // namespace dragon
#endif // DRAGON_CORE_TYPEID_H_
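// --------------------------------------------------------
// A minimal sketch of TypeMeta: Make<T>() records the id
// and itemsize (plus ctor/copy/dtor for non-fundamental T),
// and Match<T>() compares ids.
// --------------------------------------------------------
#include <string>
#include "core/typeid.h"
using namespace dragon;
void TypeMetaSketch() {
  TypeMeta f = TypeMeta::Make<float>();        // fundamental: no ctor/dtor
  TypeMeta s = TypeMeta::Make<std::string>();  // structured: placement-new ctor
  bool same = f.Match<float>();                // true
  bool diff = (f != s);                        // true: distinct ids
  (void)same; (void)diff;
}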
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_TYPES_H_
#define DRAGON_CORE_TYPES_H_
namespace dragon {
#ifdef _MSC_VER
typedef struct __declspec(align(2)) {
unsigned short x;
} float16;
typedef struct __declspec(align(4)) {
unsigned int x;
} float32;
#else
typedef struct {
unsigned short x;
} __attribute__((aligned(2))) float16;
typedef struct {
unsigned int x;
} __attribute__((aligned(4))) float32;
#endif
} // namespace dragon
#endif // DRAGON_CORE_TYPES_H_
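// --------------------------------------------------------
// A minimal compile-time check of the layout above: float16
// is an opaque 2-byte, 2-aligned carrier for half values.
// --------------------------------------------------------
#include "core/types.h"
static_assert(sizeof(dragon::float16) == 2, "float16 must be 2 bytes");
static_assert(sizeof(dragon::float32) == 4, "float32 must be 4 bytes");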
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CORE_WORKSPACE_H_
#define DRAGON_CORE_WORKSPACE_H_
#include "core/common.h"
#include "core/graph.h"
#include "utils/string.h"
namespace dragon {
#define WORKSPACE_MIN_BUFFER_SIZE 3
#define WORKSPACE_MAX_BUFFER_SIZE 3
class Workspace{
public:
typedef Map<string, unique_ptr<Tensor> > TensorMap;
typedef Map<string, unique_ptr<mutex> > LockMap;
typedef Map<string, unique_ptr<GraphBase> > GraphMap;
typedef Map<string, TensorFiller> FillerMap;
typedef Map<string, string> RenameMap;
Workspace(): root_folder_(".") { init(); }
Workspace(string root_folder) : root_folder_(root_folder) { init(); }
void init() {
CreateTensor("ignore");
for (int i = 0; i < WORKSPACE_MIN_BUFFER_SIZE; i++) CreateBuffer();
}
/******************** Tensor ********************/
inline string GetTensorName(const string& name) {
if (rename_map_.count(name)) return rename_map_[name];
else return name;
}
inline bool HasTensor(const string& name) {
string query = GetTensorName(name);
return tensor_map_.count(query) > 0;
}
inline Tensor* CreateTensor(const string& name){
string query = GetTensorName(name);
if (!HasTensor(query))
tensor_map_[query] = unique_ptr<Tensor>(new Tensor(query));
return tensor_map_[query].get();
}
inline Tensor* GetTensor(const string& name){
string query = GetTensorName(name);
CHECK(HasTensor(query))
<< "Tensor(" << name << ") does not exist.";
return tensor_map_[query].get();
}
inline void LockTensor(const string& name){
string query = GetTensorName(name);
if (!lock_map_.count(query))
lock_map_[query] = unique_ptr<mutex>(new mutex);
lock_map_[query]->lock();
}
inline void UnlockTensor(const string& name){
string query = GetTensorName(name);
if (!lock_map_.count(query))
lock_map_[query] = unique_ptr<mutex>(new mutex);
lock_map_[query]->unlock();
}
inline void ReleaseTensor(const string& name) {
CHECK(HasTensor(name)) << "\nTensor(" << name << ") does not "
<< "belong to workspace, could not release it.";
string query = GetTensorName(name);
tensor_map_[query]->Reset();
}
inline vector<string> GetTensors() {
vector<string> names;
for (auto& it : tensor_map_) names.push_back(it.first);
return names;
}
/******************** Filler ********************/
inline void CreateFiller(const TensorFiller filler) {
CHECK_GT(filler.tensor().size(), 0)
<< "Tensor without a valid name can not be filled.";
if (filler_map_.count(filler.tensor())) return;
filler_map_[filler.tensor()] = filler;
}
inline const TensorFiller* GetFiller(const string& name) {
if (filler_map_.count(name) > 0) return &filler_map_[name];
else return nullptr;
}
/******************** Buffer ********************/
inline Tensor* CreateBuffer() {
int buffer_idx = 1;
string name;
while (1) {
name = "_t_buffer_" + dragon_cast<string, int>(buffer_idx++);
if (!HasTensor(name)) break;
}
buffer_stack_.push(name);
return CreateTensor(name);
}
inline Tensor* GetBuffer() {
if (!buffer_stack_.empty()) {
string name = buffer_stack_.top();
buffer_stack_.pop();
return GetTensor(name);
}
LOG(FATAL) << "buffers are not enough, add more if necessary.";
return nullptr;
}
inline void ReleaseBuffer(Tensor* tensor, bool force_release=false) {
// release directly
if (buffer_stack_.size() >= WORKSPACE_MAX_BUFFER_SIZE || force_release) {
ReleaseTensor(tensor->name());
} else { // recycle it as an available buffer
buffer_stack_.push(tensor->name());
}
}
/******************** Graph ********************/
GraphBase* CreateGraph(const GraphDef& graph_def);
inline bool RunGraph(const string& graph_name,
const string& include, const string& exclude) {
if (!graph_map_.count(graph_name)) {
LOG(ERROR) << "Graph(" << graph_name << ") does not exist.";
return false;
}
return graph_map_[graph_name]->Run(include, exclude);
}
inline vector<string> GetGraphs(){
vector<string> names;
for (auto& it : graph_map_) names.push_back(it.first);
return names;
}
/******************** Utility ********************/
inline const string& GetRootFolder() const { return root_folder_; }
inline void CreateRename(const string& old_tensor,
const string& new_tensor) {
rename_map_[old_tensor] = new_tensor;
}
private:
TensorMap tensor_map_;
LockMap lock_map_;
GraphMap graph_map_;
FillerMap filler_map_;
RenameMap rename_map_;
string root_folder_;
stack<string> buffer_stack_;
};
} // namespace dragon
#endif // DRAGON_CORE_WORKSPACE_H_
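// --------------------------------------------------------
// A minimal sketch of workspace usage with hypothetical
// names: tensors are created on demand by name, and buffers
// are borrowed from / returned to a small stack.
// --------------------------------------------------------
#include "core/workspace.h"
using namespace dragon;
void WorkspaceSketch() {
  Workspace ws;
  Tensor* x = ws.CreateTensor("data");  // created (or fetched) by name
  CHECK(ws.HasTensor("data"));
  Tensor* buf = ws.GetBuffer();         // borrow a "_t_buffer_*" tensor
  ws.ReleaseBuffer(buf);                // push it back for reuse
  ws.ReleaseTensor("data");             // reset contents, keep the entry
  (void)x;
}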
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
#include "core/operator.h"
#include "utils/math_functions.h"
namespace dragon {
template <class Context>
class DropoutOp final : public Operator<Context> {
public:
DropoutOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
prob(OperatorBase::GetSingleArg<float>("prob", 0)) {
bool use_scale = OperatorBase::GetSingleArg<bool>("scale", true);
threshold = static_cast<unsigned int>(UINT_MAX * prob);
if (use_scale) scale = 1.0 / (1.0 - prob);
else scale = 1.0;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float prob, scale;
unsigned int threshold;
Tensor* mask;
};
template <class Context>
class DropoutGradientOp final : public Operator<Context> {
public:
DropoutGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
prob(OperatorBase::GetSingleArg<float>("prob", 0)) {
bool use_scale = OperatorBase::GetSingleArg<bool>("scale", true);
threshold = static_cast<unsigned int>(UINT_MAX * prob);
if (use_scale) scale = 1.0 / (1.0 - prob);
else scale = 1.0;
}
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
float prob, scale;
unsigned int threshold;
Tensor* mask;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_DROPOUT_OP_H_
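// --------------------------------------------------------
// A worked sketch (hypothetical values) of the constants in
// the constructors above: with prob = 0.5, a uniform uint32
// draw below threshold drops the unit (an assumption about
// the kernel's comparison), and survivors are rescaled by
// 1/(1-prob) so the expected activation stays unchanged.
// --------------------------------------------------------
#include <climits>
#include <cstdio>
int main() {
  float prob = 0.5f;
  unsigned int threshold = static_cast<unsigned int>(UINT_MAX * prob);
  float scale = 1.0f / (1.0f - prob);   // == 2.0f
  unsigned int r = 123456789u;          // stand-in for a random draw
  float x = 0.8f, y = (r < threshold) ? 0.0f : x * scale;
  std::printf("threshold=%u scale=%.1f y=%.2f\n", threshold, scale, y);
  return 0;
}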
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ReluOp : public Operator<Context> {
public:
ReluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float slope;
};
template <class Context>
class ReluGradientOp : public Operator<Context> {
public:
ReluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float slope;
};
#ifdef WITH_CUDNN
template <class Context>
class CuDNNReluOp final : public ReluOp<Context> {
public:
CuDNNReluOp(const OperatorDef& op_def, Workspace* ws)
: ReluOp<Context>(op_def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
cudnnTensorDescriptor_t input_desc, output_desc;
cudnnActivationDescriptor_t act_desc;
};
template <class Context>
class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
public:
CuDNNReluGradientOp(const OperatorDef& op_def, Workspace* ws)
: ReluGradientOp<Context>(op_def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
cudnnTensorDescriptor_t input_desc, output_desc;
cudnnActivationDescriptor_t act_desc;
};
#endif // WITH_CUDNN
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class SigmoidOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SigmoidOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class SigmoidGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SigmoidGradientOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_SIGMOID_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class SoftmaxOp final : public Operator<Context> {
public:
SoftmaxOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
int axis;
TIndex outer_dim, inner_dim;
Tensor* sum_multiplier, *scale;
};
template <class Context>
class SoftmaxGradientOp final : public Operator<Context> {
public:
SoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
int axis;
TIndex outer_dim, inner_dim;
Tensor* sum_multiplier, *scale;
};
#ifdef WITH_CUDNN
template <class Context>
class CuDNNSoftmaxOp final : public Operator<Context> {
public:
CuDNNSoftmaxOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
int axis;
TIndex outer_dim, inner_dim;
cudnnTensorDescriptor_t input_desc, output_desc;
};
template <class Context>
class CuDNNSoftmaxGradientOp final : public Operator<Context> {
public:
CuDNNSoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
int axis;
TIndex outer_dim, inner_dim;
cudnnTensorDescriptor_t input_desc, output_desc;
};
#endif // WITH_CUDNN
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_SOFTMAX_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class TanhOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(TanhOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class TanhGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(TanhGradientOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_TANH_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_ADD_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_ADD_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class AddOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(AddOp);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class AddGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(AddGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_ADD_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class BiasAddOp : public Operator<Context> {
public:
BiasAddOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override;
template <typename T> void NCHWRunWithType();
template <typename T> void NHWCRunWithType();
protected:
TIndex outer_dim, dim, inner_dim;
string data_format;
Tensor* bias_multiplier;
};
template <class Context>
class BiasAddGradientOp final : public Operator<Context> {
public:
BiasAddGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void NCHWRunWithType();
template <typename T> void NHWCRunWithType();
protected:
int outer_dim, dim, inner_dim;
string data_format;
Tensor* bias_multiplier;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
#include <float.h>
#include "core/operator.h"
namespace dragon {
template <class Context>
class ClipOp final : public Operator<Context> {
public:
ClipOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
low(OperatorBase::GetSingleArg<float>("low", -FLT_MAX)),
high(OperatorBase::GetSingleArg<float>("high", FLT_MAX)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float low, high;
Tensor* mask;
};
template <class Context>
class ClipGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ClipGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
Tensor* mask;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_CLIP_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_DIV_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_DIV_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class DivOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(DivOp);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class DivGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(DivGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_DIV_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
#include "core/operator.h"
namespace dragon {
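// Dot appears to dispatch on the ranks of its two inputs: DotRunWithType()
// for vector-vector, GemvRunWithType() for matrix-vector, and GemmRunWithType()
// for matrix-matrix, with "TransA"/"TransB" controlling transposition
// (M, K1, K2, N1, N2 cache the resolved shapes).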
template <class Context>
class DotOp final : public Operator<Context> {
public:
DotOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
void RunOnDevice() override;
template <typename T> void DotRunWithType();
template <typename T> void GemmRunWithType();
template <typename T> void GemvRunWithType();
protected:
bool transA, transB;
TIndex M, K1, K2, N1, N2;
};
template <class Context>
class DotGradientOp final : public Operator<Context> {
public:
DotGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
void RunOnDevice() override;
template <typename T> void DotRunWithType();
template <typename T> void GemmRunWithType();
template <typename T> void GemvRunWithType();
protected:
bool transA, transB;
TIndex M, K1, K2, N1, N2;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_DOT_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
#include "core/operator.h"
namespace dragon {
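// Eltwise reduces all inputs element-wise. With operation "SUM" the result is
// presumably y = sum_i coeffs[i] * x_i (coeffs default to 1 when omitted);
// with "PROD" it is presumably the plain element-wise product of the inputs.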
template <class Context>
class EltwiseOp final : public Operator<Context> {
public:
EltwiseOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "SUM")),
coeffs(OperatorBase::GetRepeatedArg<float>("coeffs")) {
if (coeffs.size() > 0) {
CHECK_EQ(coeffs.size(), InputSize())
<< "\nop has " << InputSize() << " inputs, "
<< "but provided " << coeffs.size() << " coeffs.";
} else coeffs.resize(InputSize(), float(1));
}
void RunOnDevice() override;
template <typename T> void SumRunWithType();
template <typename T> void ProdRunWithType();
protected:
string operation;
vector<float> coeffs;
};
template <class Context>
class EltwiseGradientOp final : public Operator<Context> {
public:
EltwiseGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "SUM")),
coeffs(OperatorBase::GetRepeatedArg<float>("coeffs")) {
if (coeffs.size() > 0) {
CHECK_EQ(coeffs.size(), InputSize())
<< "\nop has " << InputSize() << " inputs, "
<< "but " << coeffs.size() << " coeffs were provided.";
} else coeffs.resize(InputSize(), float(1));
}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void SumRunWithType();
template <typename T> void ProdRunWithType();
protected:
string operation;
vector<float> coeffs;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_ELTWISE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ExpOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ExpOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class ExpGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ExpGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_EXP_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
#include "core/operator.h"
namespace dragon {
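// GramMatrix presumably computes, for each of the outer_dim slices, the Gram
// matrix Y = X * X^T of the [dim, inner_dim] block starting at "axis";
// x_offset/y_offset step between consecutive slices.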
template <class Context>
class GramMatrixOp final : public Operator<Context> {
public:
GramMatrixOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, outer_dim, dim, inner_dim;
TIndex x_offset, y_offset;
};
template <class Context>
class GramMatrixGradientOp final : public Operator<Context> {
public:
GramMatrixGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
TIndex axis, outer_dim, dim, inner_dim;
TIndex x_offset, y_offset;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_GRAM_MATRIX_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class InnerProductOp: public Operator<Context> {
public:
InnerProductOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
void RunOnDevice() override;
template <typename T> void TransRunWithType();
template <typename T> void NoTransRunWithType();
protected:
TIndex axis, num_output, M, K;
bool transW;
Tensor* bias_multiplier;
};
template <class Context>
class InnerProductGradientOp final : public Operator<Context> {
public:
InnerProductGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
TIndex axis, num_output, M, K;
bool transW;
Tensor* bias_multiplier;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_INNER_PRODUCT_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class LogOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(LogOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class LogGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(LogGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_LOG_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class MatmulOp final : public Operator<Context> {
public:
MatmulOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
bool transA, transB;
TIndex n, x1_offset, x2_offset, y_offset;
TIndex M, K1, K2, N;
};
template <class Context>
class MatmulGradientOp final : public Operator<Context> {
public:
MatmulGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
bool transA, transB;
TIndex n, x1_offset, x2_offset, y_offset;
TIndex M, K1, K2, N;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_MUL_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_MUL_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class MulOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(MulOp);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class MulGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(MulGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_MUL_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
#include "core/operator.h"
namespace dragon {
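// Pow presumably follows the Caffe-style power layer: y = (shift + scale * x)^power,
// so dy/dx = power * scale * (shift + scale * x)^(power - 1); the cached
// power_scale = power * scale is the constant factor of that gradient.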
template <class Context>
class PowOp: public Operator<Context> {
public:
PowOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
scale(OperatorBase::GetSingleArg<float>("scale", 1.0)),
shift(OperatorBase::GetSingleArg<float>("shift", 0.0)),
power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
power_scale = power * scale;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float scale, shift, power, power_scale;
};
template <class Context>
class PowGradientOp final : public Operator<Context> {
public:
PowGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
scale(OperatorBase::GetSingleArg<float>("scale", 1.0)),
shift(OperatorBase::GetSingleArg<float>("shift", 0.0)),
power(OperatorBase::GetSingleArg<float>("power", 1.0)) {
power_scale = power * scale;
}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
float scale, shift, power, power_scale;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_POW_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ScaleOp : public Operator<Context> {
public:
ScaleOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, num_axes, inner_dim;
Tensor* bias_multiplier;
};
template <class Context>
class ScaleGradientOp final : public Operator<Context> {
public:
ScaleGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void BiasRunWithType();
template <typename T> void ScaleRunWithType();
template <typename T> void RunWithType();
protected:
TIndex axis, num_axes;
TIndex outer_dim, inner_dim, scale_dim, sum_dim, dim;
Tensor* bias_multiplier, *sum_multiplier;
Tensor sum_result;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_SCALE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class SquareOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SquareOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class SquareGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SquareGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_SQUARE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ARITHMETIC_SUB_OP_H_
#define DRAGON_OPERATORS_ARITHMETIC_SUB_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class SubOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SubOp);
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
template <class Context>
class SubGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SubGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
protected:
Tensor* bcast_multiplier;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_SUB_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_ARGMAX_OP_H_
#define DRAGON_OPERATORS_COMMON_ARGMAX_OP_H_
#include "core/operator.h"
namespace dragon {
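// Argmax returns, along "axis", the indices of the top_k largest elements
// (top_k defaults to 1, i.e. a plain argmax). No gradient operator is defined
// in this header, presumably because the output is index-valued.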
template <class Context>
class ArgmaxOp final : public Operator<Context> {
public:
ArgmaxOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, top_k, count, inner_dim;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_ARGMAX_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_AT_OP_H_
#define DRAGON_OPERATORS_COMMON_AT_OP_H_
#include "core/operator.h"
namespace dragon {
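// At presumably gathers slices of the input along "axis" according to an
// index input (x_slice_dim/y_slice_dim track the source and gathered extents);
// "acc_gradient" lets the backward pass accumulate into an existing gradient
// instead of overwriting it.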
template <class Context>
class AtOp final : public Operator<Context> {
public:
AtOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, outer_dim, inner_dim, x_slice_dim, y_slice_dim;
vector<TIndex> output_dims;
};
template <class Context>
class AtGradientOp final : public Operator<Context> {
public:
AtGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
TIndex axis, outer_dim, inner_dim, x_slice_dim, y_slice_dim;
bool acc_grad;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_AT_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_CONCAT_OP_H_
#define DRAGON_OPERATORS_COMMON_CONCAT_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ConcatOp final : public Operator<Context> {
public:
ConcatOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, nin, outer_dim, inner_dim, x_concat_dim, y_concat_dim;
TIndex x_offset, y_offset, concat_offset;
vector<TIndex> concat_dims;
};
template <class Context>
class ConcatGradientOp : public Operator<Context> {
public:
ConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
TIndex axis, nin, outer_dim, inner_dim, x_concat_dim, y_concat_dim;
TIndex x_offset, y_offset, concat_offset;
vector<TIndex> concat_dims;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_CONCAT_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_CROP_OP_H_
#define DRAGON_OPERATORS_COMMON_CROP_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class CropOp: public Operator<Context> {
public:
CropOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 2)),
offsets_param(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
CHECK(shape.size() * shape_like.size() == 0)
<< "\ncannot set both shape and shape_like.";
CHECK(shape.size() + shape_like.size() != 0)
<< "\nmust set either shape or shape_like.";
}
void ComputeOutputShape();
void RunOnDevice() override;
template <typename T> void RunWithType();
template <typename T> void RecursiveRunWithType(vector<TIndex> idxs,
const vector<TIndex>& offsets,
int cur_dim,
Tensor* x,
Tensor* y);
protected:
TIndex axis;
vector<int> offsets_param, shape;
vector<TIndex> output_shape, offsets;
string shape_like;
};
template <class Context>
class CropGradientOp final : public Operator<Context> {
public:
CropGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 2)),
offsets_param(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
CHECK(shape.size() * shape_like.size() == 0)
<< "\ncannot set both shape and shape_like.";
CHECK(shape.size() + shape_like.size() != 0)
<< "\nmust set either shape or shape_like.";
}
void ComputeOutputShape();
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
template <typename T> void RecursiveRunWithType(vector<TIndex> idxs,
const vector<TIndex>& offsets,
int cur_dim,
Tensor* dy,
Tensor* dx);
protected:
TIndex axis;
vector<int> offsets_param, shape;
vector<TIndex> output_shape, offsets;
string shape_like;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_CROP_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_EXPAND_DIMS_OP_H_
#define DRAGON_OPERATORS_COMMON_EXPAND_DIMS_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ExpandDimsOp final : public Operator<Context> {
public:
ExpandDimsOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {}
void RunOnDevice() override;
protected:
TIndex axis;
};
template <class Context>
class ExpandDimsGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ExpandDimsGradientOp);
void RunOnDevice() override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_EXPAND_DIMS_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_FLATTEN_OP_H_
#define DRAGON_OPERATORS_COMMON_FLATTEN_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class FlattenOp final : public Operator<Context> {
public:
FlattenOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
void RunOnDevice() override;
protected:
TIndex axis, num_axes;
};
template <class Context>
class FlattenGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(FlattenGradientOp);
void RunOnDevice() override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_FLATTEN_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_PYTHON_OP_H_
#define DRAGON_OPERATORS_COMMON_PYTHON_OP_H_
#include <Python.h>
#include "core/operator.h"
namespace dragon {
template <class Context>
class RunOp : public Operator<Context> {
public:
RunOp(const OperatorDef& op_def, Workspace* ws);
PyObject* String(const char* str) {
return PyBytes_FromStringAndSize(str, string(str).size());
}
void RunOnDevice() override;
protected:
PyObject* self, *inputs, *outputs;
string module, op, param_str;
};
template <class Context>
class TemplateOp : public RunOp<Context> {
public:
TemplateOp(const OperatorDef& op_def, Workspace* ws)
: RunOp<Context>(op_def, ws) {}
};
template <class Context>
class TemplateGradientOp : public TemplateOp<Context> {
public:
TemplateGradientOp(const OperatorDef& op_def, Workspace* ws)
: TemplateOp<Context>(op_def, ws) {}
void RunOnDevice() override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_PYTHON_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_REDUCE_OP_H_
#define DRAGON_OPERATORS_COMMON_REDUCE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ReduceOp final : public Operator<Context> {
public:
ReduceOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")),
keep_dims(OperatorBase::GetSingleArg<bool>("keep_dims", false)) {}
void RunOnDevice() override;
template <typename T> void SumRunWithType();
template <typename T> void MeanRunWithType();
protected:
bool keep_dims;
string operation;
TIndex axis, axis_dim, count, inner_dim;
Tensor* multiplier;
};
template <class Context>
class ReduceGradientOp final : public Operator<Context> {
public:
ReduceGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void SumRunWithType();
template <typename T> void MeanRunWithType();
protected:
string operation;
TIndex axis, axis_dim, count, inner_dim;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_REDUCE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_RESHAPE_OP_H_
#define DRAGON_OPERATORS_COMMON_RESHAPE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ReshapeOp final : public Operator<Context> {
public:
ReshapeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
shape(OperatorBase::GetRepeatedArg<int>("shape")) {
new_shape.resize(shape.size());
}
void RunOnDevice() override;
protected:
vector<int> shape;
vector<TIndex> new_shape;
};
template <class Context>
class ReshapeGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(ReshapeGradientOp);
void RunOnDevice() override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_RESHAPE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_SCAN_OP_H_
#define DRAGON_OPERATORS_COMMON_SCAN_OP_H_
#include "core/operator.h"
namespace dragon {
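// Scan presumably unrolls a sub-graph along "axis" for nsteps steps:
// InitTemplate() captures the step body, UnrollTemplate() stitches the copies
// into new_def, and UpdateTerms() rebinds the per-step tensors; unrolled
// graphs are cached in "graphs", keyed presumably by the step count.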
template <class Context>
class ScanOp final: public Operator<Context> {
public:
ScanOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nsteps(OperatorBase::GetSingleArg<int>("nsteps", 0)),
step_type(OperatorBase::GetSingleArg<string>("step_type", "Static")),
step_tensor(OperatorBase::GetSingleArg<string>("step_tensor", "")),
nseqs(OperatorBase::GetSingleArg<int>("nseqs", 0)),
default_outputs(OperatorBase::GetRepeatedArg<string>("default_outputs")),
nout((int)default_outputs.size()),
debug_mode(OperatorBase::GetSingleArg<bool>("debug_mode", false)) {
InitTemplate();
}
void RunOnDevice() override;
void InitTemplate();
void UnrollTemplate();
void UpdateTerms(int cur_step);
protected:
GraphDef func_def, template_def, new_def;
Map<int, unique_ptr<Graph>> graphs;
Graph* cur_graph;
Map<string, string> terms;
vector<string> default_outputs;
TIndex axis, nseqs, nsteps, nrepeats, nout;
string step_type, step_tensor;
bool debug_mode;
};
template <class Context>
class ScanGradientOp final: public Operator<Context> {
public:
ScanGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
nsteps(OperatorBase::GetSingleArg<int>("nsteps", 0)),
step_type(OperatorBase::GetSingleArg<string>("step_type", "Static")),
step_tensor(OperatorBase::GetSingleArg<string>("step_tensor", "")),
forward_inputs(OperatorBase::GetRepeatedArg<string>("inputs_name")),
forward_outputs(OperatorBase::GetRepeatedArg<string>("outputs_name")) {
// handle GO(x)
for (int i = 0; i < forward_outputs.size(); i++)
terms[forward_outputs[i] + "_grad"] = input(i + (int)OutputSize()).name();
// handle GI(x)
for (int i = 0; i < forward_inputs.size(); i++)
terms[forward_inputs[i] + "_grad"] = output(i)->name();
}
void RunOnDevice() override;
void MakeGradientOps();
protected:
GraphDef forward_def, new_def;
Map<string, string> terms;
Map<int, unique_ptr<Graph>> graphs;
vector<string> forward_inputs, forward_outputs;
Graph* cur_graph;
TIndex axis, nsteps;
string step_type, step_tensor;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_SCAN_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_SLICE_OP_H_
#define DRAGON_OPERATORS_COMMON_SLICE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class SliceOp : public Operator<Context> {
public:
SliceOp(const OperatorDef& op_def, Workspace* ws):
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, nout, steps;
TIndex outer_dim, inner_dim, x_slice_dim, y_slice_dim;
TIndex slice_offset;
vector<TIndex> slice_dims;
};
template <class Context>
class SliceGradientOp final : public Operator<Context> {
public:
SliceGradientOp(const OperatorDef& op_def, Workspace* ws):
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, nout;
TIndex outer_dim, inner_dim, x_slice_dim, y_slice_dim;
TIndex x_offset, y_offset, slice_offset;
vector<TIndex> slice_dims;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_SLICE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_TILE_OP_H_
#define DRAGON_OPERATORS_COMMON_TILE_OP_H_
#include "core/operator.h"
namespace dragon {
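// Tile repeats the input along each axis by the given "multiples". Only axes
// with a multiple > 1 are collected into process_axes, so unit multiples cost
// nothing; the gradient visits the same axes in reverse order, presumably
// accumulating dy back into the untiled shape.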
template <class Context>
class TileOp : public Operator<Context> {
public:
TileOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
multiples(OperatorBase::GetRepeatedArg<int>("multiples")) {
for (int i = 0; i < multiples.size(); i++)
if (multiples[i] > 1)
process_axes.push_back({ i, multiples[i] });
}
void RunOnDevice() override;
template<typename T> void TileRunWithType();
protected:
vector<int> multiples;
vector< pair<int, int> > process_axes;
TIndex axis, multiple, outer_dim, dim, inner_dim;
Tensor* dest, *source;
};
template <class Context>
class TileGradientOp : public Operator<Context> {
public:
TileGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
multiples(OperatorBase::GetRepeatedArg<int>("multiples")) {
for (int i = (int)multiples.size() - 1; i >= 0; i--)
if (multiples[i] > 1)
process_axes.push_back({ i, multiples[i] });
}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template<typename T> void TileRunWithType();
protected:
vector<int> multiples;
vector< pair<int, int> > process_axes;
TIndex axis, multiple, outer_dim, dim, inner_dim;
Tensor* dest, *source;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_TILE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_TRANSPOSE_OP_H_
#define DRAGON_OPERATORS_COMMON_TRANSPOSE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class TransposeOp final: public Operator<Context> {
public:
TransposeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
perm(OperatorBase::GetRepeatedArg<int>("perm")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
vector<int> perm;
Tensor* order, *old_steps, *new_steps;
};
template <class Context>
class TransposeGradientOp final : public Operator<Context> {
public:
TransposeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
Tensor* order, *old_steps, *new_steps;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_TRANSPOSE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_UTILS_OP_H_
#define DRAGON_OPERATORS_COMMON_UTILS_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class CopyOp final: public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(CopyOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class AccuracyOp final: public Operator<Context> {
public:
AccuracyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
top_k(OperatorBase::GetSingleArg<int>("top_k", 1)) {
vector<int> args = OperatorBase::GetRepeatedArg<int>("ignore_labels");
if (args.size()) {
ignore_labels.Reshape(vector<TIndex>(1, args.size()));
int* ignore_data = ignore_labels.mutable_data<int, CPUContext>();
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
}
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex top_k, outer_num, inner_num, classes;
Tensor ignore_labels;
};
template <class Context>
class OneHotOp final : public Operator<Context> {
public:
OneHotOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
depth(OperatorBase::GetSingleArg<int>("depth", -1)),
on_value(OperatorBase::GetSingleArg<int>("on_value", 1)),
off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex depth, on_value, off_value;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_UTILS_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class L1LossOp : public Operator<Context> {
public:
L1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
coeff(OperatorBase::GetSingleArg<float>("coeff", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float coeff;
Tensor* diff;
string normalization;
};
template <class Context>
class L1LossGradientOp final : public Operator<Context> {
public:
L1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
coeff(OperatorBase::GetSingleArg<float>("coeff", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float coeff;
Tensor* diff;
string normalization;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_L1_LOSS_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class L2LossOp : public Operator<Context> {
public:
L2LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
coeff(OperatorBase::GetSingleArg<float>("coeff", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float coeff;
Tensor* diff;
string normalization;
};
template <class Context>
class L2LossGradientOp final : public Operator<Context> {
public:
L2LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
coeff(OperatorBase::GetSingleArg<float>("coeff", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float coeff;
Tensor* diff;
string normalization;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_L2_LOSS_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_
#include "core/operator.h"
namespace dragon {
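// SigmoidCrossEntropyLoss presumably fuses the sigmoid with the cross-entropy
// ("prob" caches the sigmoid output, "losses" the per-element terms);
// "normalization" ("FULL" by default) selects how the summed losses are
// averaged into a scalar.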
template <class Context>
class SigmoidCrossEntropyLossOp final : public Operator<Context> {
public:
SigmoidCrossEntropyLossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
Tensor losses;
Tensor* prob;
string normalization;
};
template <class Context>
class SigmoidCrossEntropyLossGradientOp final : public Operator<Context> {
public:
SigmoidCrossEntropyLossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
Tensor* prob;
string normalization;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
#include "core/operator.h"
namespace dragon {
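// SmoothL1Loss presumably follows Fast R-CNN: with d = x - t and
// sigma2 = sigma^2, f(d) = 0.5 * sigma2 * d^2 if |d| < 1 / sigma2,
// otherwise |d| - 0.5 / sigma2; "diff" caches d and "error" the
// per-element losses.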
template <class Context>
class SmoothL1LossOp final : public Operator<Context> {
public:
SmoothL1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
sigma2(OperatorBase::GetSingleArg<float>("sigma", 1.0)) {
sigma2 *= sigma2;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float sigma2;
Tensor* diff, *error;
};
template <class Context>
class SmoothL1LossGradientOp final : public Operator<Context> {
public:
SmoothL1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
sigma2(OperatorBase::GetSingleArg<float>("sigma", 1.0)) {
sigma2 *= sigma2;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float sigma2;
Tensor* diff;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_LOSS_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class SoftmaxCrossEntropyLossOp final : public Operator<Context> {
public:
SoftmaxCrossEntropyLossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {
OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ input(0).name() }),
vector<string>({ "_t_" + anchor() + "_softmax_prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (op_def.has_device_option())
softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
softmax_op.reset(CreateOperator(softmax_def, ws));
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, outer_dim, inner_dim;
Tensor losses;
Tensor* prob;
unique_ptr<OperatorBase> softmax_op;
string normalization;
};
template <class Context>
class SoftmaxCrossEntropyLossGradientOp final : public Operator<Context> {
public:
SoftmaxCrossEntropyLossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, outer_dim, inner_dim;
Tensor* prob;
string normalization;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_LOSS_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_LOSS_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class SoftmaxLossOp final : public Operator<Context> {
public:
SoftmaxLossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {
vector<int> args = OperatorBase::GetRepeatedArg<int>("ignore_labels");
if (args.size()) {
ignore.Reshape(vector<TIndex>(1, args.size()));
int* ignore_data = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
}
OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ input(0).name() }),
vector<string>({ "_t_" + anchor() + "_softmax_prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (op_def.has_device_option())
softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
softmax_op.reset(CreateOperator(softmax_def, ws));
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, outer_dim, inner_dim;
Tensor ignore, valid, losses;
Tensor* prob;
unique_ptr<OperatorBase> softmax_op;
string normalization;
};
template <class Context>
class SoftmaxLossGradientOp final : public Operator<Context> {
public:
SoftmaxLossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {
vector<int> args = OperatorBase::GetRepeatedArg<int>("ignore_labels");
if (args.size()) {
ignore.Reshape(vector<TIndex>(1, args.size()));
int* ignore_data = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
}
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, outer_dim, inner_dim;
Tensor ignore, valid;
Tensor* prob;
string normalization;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SOFTMAX_LOSS_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
#define DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
#ifdef WITH_MPI
#include "core/operator.h"
#include "mpi/mpi.h"
namespace dragon {
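// ModelMPIBase resolves the process topology once at construction: it records
// sizes/ranks for both MPI_COMM_WORLD and the layer communicator, then
// translates the world-level "root" rank into comm_root within the layer
// group, failing loudly if the root lies outside that group.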
template <class Context>
class ModelMPIBase : public Operator<Context> {
public:
ModelMPIBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
comm((MPI_Comm)OperatorBase::GetSingleArg<int>("comm", 0)),
group((MPI_Group)OperatorBase::GetSingleArg<int>("group", 0)) {
if (comm == MPI_COMM_NULL) return;
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
MPI_Comm_size(comm, &comm_size);
MPI_Comm_rank(comm, &comm_rank);
MPI_Group world_group;
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
int world_root = OperatorBase::GetSingleArg<int>("root", 0);
MPI_Group_translate_ranks(world_group, 1, &world_root, group, &comm_root);
CHECK(comm_root != MPI_UNDEFINED) << "MPI root is not included in the layer group.";
}
protected:
MPI_Comm comm;
MPI_Group group;
int comm_size, comm_rank, comm_root;
int world_size, world_rank;
};
} // namespace dragon
#endif // WITH_MPI
#endif // DRAGON_OPERATORS_MPI_BASE_MPI_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
#ifdef WITH_MPI
#include "operators/mpi/base_mpi_op.h"
namespace dragon {
template <class Context>
class MPIBroadcastOp final : public ModelMPIBase<Context> {
public:
MPIBroadcastOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class MPIBroadcastGradientOp final : public ModelMPIBase<Context> {
public:
MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
};
} // namespace dragon
#endif // WITH_MPI
#endif // DRAGON_OPERATORS_MPI_MPI_BROADCAST_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
#define DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
#ifdef WITH_MPI
#include "operators/mpi/base_mpi_op.h"
namespace dragon {
template <class Context>
class MPIGatherOp final : public ModelMPIBase<Context> {
public:
MPIGatherOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class MPIGatherGradientOp final : public ModelMPIBase<Context> {
public:
MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
};
} // namespace dragon
#endif // WITH_MPI
#endif // DRAGON_OPERATORS_MPI_MPI_GATHER_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class BatchNormOp : public Operator<Context> {
public:
BatchNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
inplace(OperatorBase::GetSingleArg<bool>("inplace", true)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float momentum, eps;
Tensor mean, num_by_chans;
Tensor* num_multiplier, *spatial_multiplier, *stddev, *var;
TIndex num, channels, spatial_dim, nbychans;
int use_stats;
bool use_global_stats, inplace;
};
template <class Context>
class BatchNormGradientOp final : public Operator<Context> {
public:
BatchNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
Tensor num_by_chans;
Tensor* num_multiplier, *spatial_multiplier, *stddev, *var;
TIndex num, channels, spatial_dim, nbychans;
int use_stats;
bool use_global_stats;
};
template <class Context>
class BNOp : public Operator<Context> {
public:
BNOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { }
void RunOnDevice() override { NOT_IMPLEMENTED; }
template <typename T> void RunWithType() { NOT_IMPLEMENTED; }
protected:
float momentum, eps;
int use_stats;
bool use_global_stats;
};
template <class Context>
class BNGradientOp : public Operator<Context> {
public:
BNGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { }
void ShareBeforeRun() override;
void RunOnDevice() override { NOT_IMPLEMENTED; }
void ClearAfterRun() override;
template <typename T> void RunWithType() { NOT_IMPLEMENTED; }
protected:
float eps;
int use_stats;
bool use_global_stats;
};
#ifdef WITH_CUDNN
#if CUDNN_VERSION_MIN(5, 0, 0)
#include "utils/cudnn_device.h"
template <class Context>
class CuDNNBNOp final : public BNOp<Context> {
public:
CuDNNBNOp(const OperatorDef& op_def, Workspace* ws)
: BNOp<Context>(op_def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
this->eps = std::max(this->eps, float(CUDNN_BN_MIN_EPSILON));
}
void RunOnDevice() override;
template <typename T> void SpatialRunWithType();
template <typename T> void PerActivationRunWithType();
protected:
cudnnTensorDescriptor_t input_desc, output_desc, bn_desc;
TIndex num, channels, spatial_dim;
Tensor* mean, *var;
bool use_global_stats;
};
template <class Context>
class CuDNNBNGradientOp final : public BNGradientOp<Context> {
public:
CuDNNBNGradientOp(const OperatorDef& op_def, Workspace* ws)
: BNGradientOp<Context>(op_def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
this->eps = std::max(this->eps, float(CUDNN_BN_MIN_EPSILON));
}
void RunOnDevice() override;
template <typename T> void SpatialRunWithType();
template <typename T> void PerActivationRunWithType();
protected:
cudnnTensorDescriptor_t input_desc, output_desc, bn_desc;
Tensor num_by_chans;
Tensor* num_multiplier, *spatial_multiplier;
Tensor* mean, *var, *stddev;
TIndex num, channels, spatial_dim, nbychans;
bool use_global_stats;
};
#endif
#endif // WITH_CUDNN
} // namespace dragon
#endif // DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
#define DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
#include "core/operator.h"
namespace dragon {
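// BatchRenorm presumably follows Ioffe's batch renormalization: correction
// terms r (clipped to [1/r_max, r_max]) and d (clipped to [-d_max, d_max])
// rescale the mini-batch statistics toward the moving averages, with
// t_r_max/t_d_max/t_val relaxing the clipping bounds by t_delta over time.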
template <class Context>
class BatchRenormOp : public Operator<Context> {
public:
BatchRenormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
r_max(OperatorBase::GetSingleArg<float>("r_max", float(3.0))),
d_max(OperatorBase::GetSingleArg<float>("d_max", float(5.0))),
t_delta(OperatorBase::GetSingleArg<float>("t_delta", float(1.0))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
inplace(OperatorBase::GetSingleArg<bool>("inplace", true)),
t_r_max(float(1.0)), t_d_max(float(0.0)), t_val(float(0.0)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float momentum, eps, r_max, d_max, t_delta;
float t_r_max, t_d_max, t_val;
Tensor mean, d, t_h_mean, t_h_var, num_by_chans;
Tensor* num_multiplier, *spatial_multiplier;
Tensor* stddev, *r, *var, *x_norm;
TIndex num, channels, spatial_dim, nbychans;
int use_stats;
bool use_global_stats, inplace;
};
template <class Context>
class BatchRenormGradientOp final : public Operator<Context> {
public:
BatchRenormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
Tensor mean, num_by_chans;
Tensor* num_multiplier, *spatial_multiplier;
Tensor* stddev, *r, *var, *x_norm;
TIndex num, channels, spatial_dim, nbychans;
int use_stats;
bool use_global_stats;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_NORM_BATCH_RENORM_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class InstanceNormOp : public Operator<Context> {
public:
InstanceNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
inplace(OperatorBase::GetSingleArg<bool>("inplace", true)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float eps;
Tensor mean;
Tensor* spatial_multiplier, *stddev, *var;
TIndex num, channels, spatial_dim, nbychans;
bool inplace;
};
template <class Context>
class InstanceNormGradientOp final : public Operator<Context> {
public:
InstanceNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
Tensor* spatial_multiplier, *stddev, *var;
TIndex num, channels, spatial_dim, nbychans;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_NORM_INSTANCE_NORM_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_L2_NORM_H_
#define DRAGON_OPERATORS_NORM_L2_NORM_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class L2NormOp final : public Operator<Context> {
public:
L2NormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-5))) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float eps;
TIndex axis, num_axes, end_axis;
bool across_inner;
Tensor* norm, *buffer, *multiplier;
TIndex outer_dim, dim, inner_dim, spatial_dim;
};
template <class Context>
class L2NormGradientOp final : public Operator<Context> {
public:
L2NormGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType();
protected:
TIndex axis, num_axes, end_axis;
bool across_inner;
Tensor* norm, *multiplier, *buffer, *buffer_inner;
TIndex outer_dim, dim, inner_dim;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_NORM_L2_NORM_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_
#define DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class LSTMUnitOp : public Operator<Context> {
public:
LSTMUnitOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
has_cont(OperatorBase::GetSingleArg<string>("cont_t", "")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex num, channels;
string has_cont;
Tensor* cont_t;
};
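// NOTE (editor's sketch, not part of the original source): a Caffe-style LSTM
// unit. With previous cell state C_prev and gate pre-activations split along
// the channel axis as [i, f, o, g], the standard recurrence is:
//   i = sigmoid(X_i), f = sigmoid(X_f), o = sigmoid(X_o), g = tanh(X_g)
//   C = f .* C_prev + i .* g
//   H = o .* tanh(C)
// The optional tensor named by the "cont_t" argument presumably marks
// sequence continuation, zeroing the carried state at sequence starts.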
template <class Context>
class LSTMUnitGradientOp : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(LSTMUnitGradientOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex num, channels;
Tensor* zeros;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_RECURRENT_LSTM_UNIT_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
#include "operators/update/update_op_base.h"
namespace dragon {
template <class Context>
class AdamUpdateOp final : public UpdateOpBase<Context> {
public:
AdamUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
t(0),
eps(param("eps")),
beta1(param("beta1")),
beta2(param("beta2")) {}
void ComputeRunWithFloat() override;
protected:
unique_ptr<Tensor> m, v, tmp;
float lr, beta1, beta2, eps, coeff;
int t;
};
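// NOTE (editor's sketch, not part of the original source): the members above
// match the standard Adam rule (Kingma & Ba, 2015), with coeff as the bias
// correction folded into the step size:
//   t     <- t + 1
//   m     <- beta1 * m + (1 - beta1) * g
//   v     <- beta2 * v + (1 - beta2) * g^2
//   coeff <- sqrt(1 - beta2^t) / (1 - beta1^t)
//   w     <- w - lr * coeff * m / (sqrt(v) + eps)
// The concrete kernel is in ComputeRunWithFloat(), defined elsewhere.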
} // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_ADAM_UPDATE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_ASYNC_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_ASYNC_UPDATE_OP_H_
#ifdef WITH_MPI
#include "operators/update/update_op_base.h"
#include "utils/thread.h"
namespace dragon {
template <class Context>
class AsyncUpdateOp final : public UpdateOpBase<Context> {
public:
AsyncUpdateOp(const OperatorDef& op_def, Workspace* ws);
int GetDelay(int tag);
void UpdateTimestamp(int tag);
void RunOnDevice() override;
void ComputeRunWithFloat() override { /* do nothing */ }
template <typename T> void RootRunWithType();
template <typename T> void ThreadRunWithType();
protected:
string mode;
unique_ptr<Tensor> recv_buffer;
Tensor** acc_buffers;
string* tags;
TIndex update_count;
int node_id, nsync, max_recv;
Map<int, int> local_timestamp;
std::unique_ptr<std::thread> thread;
#ifdef WITH_CUDA_AWARE
cudaStream_t stream;
cublasHandle_t handle;
#endif
};
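// NOTE (editor's reading of this interface, not part of the original source):
// the declarations suggest an asynchronous parameter-server style update: a
// background thread (ThreadRunWithType) receives tagged gradients into
// recv_buffer and accumulates them in acc_buffers, while GetDelay() and
// UpdateTimestamp() track per-worker staleness via local_timestamp. The
// CUDA-aware branch keeps a dedicated stream and cuBLAS handle for transfers.
// The actual protocol lives in the implementation omitted from this diff.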
} // namespace dragon
#endif // WITH_MPI
#endif // DRAGON_OPERATORS_UPDATE_ASYNC_UPDATE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
#define DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class MovingAverageOp final : public Operator<Context> {
public:
MovingAverageOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
decay(OperatorBase::GetSingleArg<float>("decay", 1.0)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float decay;
};
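// NOTE (editor's sketch, not part of the original source): the conventional
// exponential moving average this declaration suggests is
//   y <- decay * y + (1 - decay) * x
// so the default decay of 1.0 leaves the target unchanged until a smaller
// decay is supplied.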
} // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_MOVING_AVERAGE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
#include "operators/update/update_op_base.h"
namespace dragon {
template <class Context>
class RMSPropUpdateOp final : public UpdateOpBase<Context> {
public:
RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
eps(param("eps")),
decay(param("decay")) {}
void ComputeRunWithFloat() override;
protected:
float lr, decay, eps;
unique_ptr<Tensor> history;
Tensor buffer;
};
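// NOTE (editor's sketch, not part of the original source): the standard
// RMSProp rule (Tieleman & Hinton, 2012) implied by the members above:
//   h <- decay * h + (1 - decay) * g^2
//   w <- w - lr * g / (sqrt(h) + eps)
// with history holding h and buffer serving as scratch space.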
} // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_RMSPROP_UPDATE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
#define DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
#include "operators/update/update_op_base.h"
namespace dragon {
template <class Context>
class SGDUpdateOp final : public UpdateOpBase<Context> {
public:
SGDUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
momentum(param("momentum")) {}
void ComputeRunWithFloat() override;
protected:
float lr, momentum;
unique_ptr<Tensor> history;
};
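// NOTE (editor's sketch, not part of the original source): classic SGD with
// momentum, with history holding the velocity:
//   h <- momentum * h + lr * g
//   w <- w - h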
} // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_SGD_UPDATE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
#define DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class UpdateOpBase : public Operator<Context> {
public:
UpdateOpBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
allow_parallel(false),
async_tag(-1),
lr_mult(OperatorBase::GetSingleArg<float>("lr_mult", 1.0)),
decay_mult(OperatorBase::GetSingleArg<float>("decay_mult", 1.0)),
domain(OperatorBase::GetSingleArg<string>("domain", "_")),
mode(OperatorBase::GetSingleArg<string>("mode", "Sync")) { InitMPI(); }
float param(const string& name) const;
void InitMPI();
void ShareBeforeRun() override;
void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void ReduceRunWithType();
template <typename T> void PreprocessRunWithType();
virtual void ComputeRunWithFloat() = 0;
template <typename T> void UpdateRunWithType();
template <typename T> void RecvRunWithType();
protected:
float lr_mult, decay_mult;
float l2_decay, clip_thresh, scale_factor;
int comm_size, comm_rank, comm_root;
int world_size, world_rank;
bool allow_parallel;
int async_tag;
Tensor* buffer;
string domain, mode;
#ifdef WITH_MPI
MPI_Comm comm;
MPI_Group group;
#endif // WITH_MPI
};
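// NOTE (editor's reading of this interface, not part of the original source):
// UpdateOpBase looks like a template-method pipeline: RunOnDevice()
// presumably chains PreprocessRunWithType (scaling by scale_factor, L2 weight
// decay via l2_decay, gradient clipping via clip_thresh), the
// subclass-specific ComputeRunWithFloat(), and UpdateRunWithType() to apply
// the result, while ReduceRunWithType/RecvRunWithType cover the MPI paths.
// param(name) presumably fetches hyper-parameters scoped by `domain` from
// the workspace.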
} // namespace dragon
#endif // DRAGON_OPERATORS_UPDATE_UPDATE_OP_BASE_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UTILS_CAST_OP_H_
#define DRAGON_OPERATORS_UTILS_CAST_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class FloatToHalfOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(FloatToHalfOp);
void RunOnDevice() override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_UTILS_CAST_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UTILS_COMPARE_OP_H_
#define DRAGON_OPERATORS_UTILS_COMPARE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class CompareOp final : public Operator<Context> {
public:
CompareOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
void RunOnDevice() override;
template <typename T> void EqualRunWithType();
protected:
string operation;
};
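// NOTE (editor): only EqualRunWithType() is declared, so "EQUAL" is
// presumably the sole comparison wired up at this point; other values of
// `operation` would need additional Run*WithType branches.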
} // namespace dragon
#endif // DRAGON_OPERATORS_UTILS_COMPARE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UTILS_GRADIENT_GENERATE_OP_H_
#define DRAGON_OPERATORS_UTILS_GRADIENT_GENERATE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class GradientGenerateOp final : public Operator<Context> {
public:
GradientGenerateOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
defaults(OperatorBase::GetRepeatedArg<float>("defaults")) {
CHECK_EQ(InputSize(), OutputSize());
CHECK_EQ(defaults.size(), OutputSize());
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
vector<float> defaults;
};
template <class Context>
class GradientGatherOp final : public Operator<Context> {
public:
GradientGatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
for (int i = 0; i < InputSize(); i++)
if (input(i).name() != "ignore") indices.push_back(i);
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
vector<int> indices;
};
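// NOTE (editor's reading, not part of the original source): judging by the
// constructors, GradientGenerateOp materializes missing input gradients as
// constant tensors filled with defaults[i], while GradientGatherOp sums the
// gradient inputs whose name is not "ignore" (the indices recorded above);
// StopGradientOp below simply forwards data and blocks backpropagation.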
template <class Context>
class StopGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(StopGradientOp);
void RunOnDevice() override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_UTILS_GRADIENT_GENERATE_OP_H_
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UTILS_INITIALIZE_OP_H_
#define DRAGON_OPERATORS_UTILS_INITIALIZE_OP_H_
#include "core/operator.h"
#include "utils/filler.h"
namespace dragon {
template <class Context>
class InitializeOp : public Operator<Context> {
public:
InitializeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
static_shape(OperatorBase::GetRepeatedArg<int>("static_shape")),
dynamic_shape(OperatorBase::GetSingleArg<string>("dynamic_shape", "")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TensorFiller filler;
vector<int> static_shape;
string dynamic_shape;
};
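// NOTE (editor's reading, not part of the original source): the output shape
// is either fixed at graph construction through static_shape or, when
// dynamic_shape names a workspace tensor, presumably fetched from that
// tensor's contents at run time; the subclasses below only configure filler.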
template <class Context>
class FillOp final : public InitializeOp<Context> {
public:
FillOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
this->filler.set_type("constant");
this->filler.set_value(OperatorBase::GetSingleArg<float>("value", 0.0));
}
};
template <class Context>
class RandomUniformOp final : public InitializeOp<Context> {
public:
RandomUniformOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
this->filler.set_type("uniform");
this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -1.0));
this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 1.0));
}
};
template <class Context>
class RandomNormalOp final : public InitializeOp<Context> {
public:
RandomNormalOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
this->filler.set_type("normal");
this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0));
this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0));
}
};
template <class Context>
class TruncatedNormalOp final : public InitializeOp<Context> {
public:
TruncatedNormalOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
this->filler.set_type("truncated_normal");
this->filler.set_mean(OperatorBase::GetSingleArg<float>("mean", 0.0));
this->filler.set_std(OperatorBase::GetSingleArg<float>("std", 1.0));
this->filler.set_low(OperatorBase::GetSingleArg<float>("low", -2.0));
this->filler.set_high(OperatorBase::GetSingleArg<float>("high", 2.0));
}
};
template <class Context>
class GlorotUniformOp final : public InitializeOp<Context> {
public:
GlorotUniformOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
string mode = OperatorBase::GetSingleArg<string>("mode", "fan_in");
float scale = OperatorBase::GetSingleArg<float>("scale", 3.0);
this->filler.set_type("xavier");
if (mode == "fan_avg") {
this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_AVG);
} else if (mode == "fan_out") {
this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_OUT);
} else {
this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_IN);
}
this->filler.set_scale(scale);
}
};
template <class Context>
class GlorotNormalOp final : public InitializeOp<Context> {
public:
GlorotNormalOp(const OperatorDef& op_def, Workspace* ws)
: InitializeOp<Context>(op_def, ws) {
string mode = OperatorBase::GetSingleArg<string>("mode", "fan_in");
float scale = OperatorBase::GetSingleArg<float>("scale", 2.0);
this->filler.set_type("msra");
if (mode == "fan_avg") {
this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_AVG);
} else if (mode == "fan_out") {
this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_OUT);
} else {
this->filler.set_variance_norm(TensorFiller_VarianceNorm_FAN_IN);
}
this->filler.set_scale(scale);
}
};
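// NOTE (editor's sketch, not part of the original source): with fan chosen by
// `mode` (fan_in, fan_out, or their average), the conventional fillers these
// settings imply are
//   xavier (GlorotUniform): w ~ U(-sqrt(scale / fan), +sqrt(scale / fan))
//   msra   (GlorotNormal):  w ~ N(0, sigma^2 = scale / fan)
// matching the defaults scale = 3.0 and scale = 2.0 respectively
// (Glorot & Bengio, 2010; He et al., 2015).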
} // namespace dragon
#endif // DRAGON_OPERATORS_UTILS_INITIALIZE_OP_H_