Commit 176b7bbb by Ting PAN

Dragon 0.2.2 Preview

1 parent 310bcb5f
Showing with 488 additions and 428 deletions
# DragonDocker
**Note: use [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) to run all GPU builds.**
This folder contains some basic Dockerfiles.
You can modify them to match your environment.
For prebuilt official images, see https://hub.docker.com/r/seetaresearch/dragon
\ No newline at end of file
FROM ubuntu:16.04
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
wget \
unzip \
ssh \
vim \
libprotobuf-dev \
protobuf-compiler \
libopenblas-dev \
python3-pip \
python3-dev \
python3-pyqt4 \
python3-tk \
&& rm -rf /var/lib/apt/lists/*
RUN pip3 install --no-cache-dir --upgrade pip setuptools wheel && \
pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
numpy \
protobuf \
lmdb \
opencv-python \
six \
Pillow \
matplotlib \
pyyaml
RUN wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cpu-openblas/3rdparty.zip && \
unzip 3rdparty.zip && rm 3rdparty.zip && cd 3rdparty && bash ./setup_mpi.sh && rm -f *.gz && \
cd openmpi && ls | grep -v install | xargs rm -r && cp install/bin/mpirun /usr/bin
RUN git clone https://github.com/seetaresearch/Dragon.git && \
cd Dragon/Dragon && rm CMakeLists.txt && \
wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cpu-openblas/CMakeLists.txt && \
mkdir build && cd build && cmake .. && make install -j8 && cd .. && rm -rf build && \
cd python && python3 setup.py install
\ No newline at end of file
FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
wget \
unzip \
ssh \
vim \
libprotobuf-dev \
protobuf-compiler \
libopenblas-dev \
libnccl2 \
libnccl-dev \
python3-pip \
python3-dev \
python3-pyqt4 \
python3-tk \
&& rm -rf /var/lib/apt/lists/*
RUN pip3 install --no-cache-dir --upgrade pip setuptools wheel && \
pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
numpy \
protobuf \
lmdb \
opencv-python \
six \
Pillow \
matplotlib \
pyyaml
RUN wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cuda9.0-cudnn7/3rdparty.zip && \
unzip 3rdparty.zip && rm 3rdparty.zip && cd 3rdparty && bash ./setup_mpi.sh && rm -f *.gz && \
cd openmpi && ls | grep -v install | xargs rm -r && cp install/bin/mpirun /usr/bin
RUN git clone https://github.com/seetaresearch/Dragon.git && \
cd Dragon/Dragon && rm CMakeLists.txt && \
wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cuda9.0-cudnn7/CMakeLists.txt && \
mkdir build && cd build && cmake .. && make install -j8 && cd .. && rm -rf build && \
cd python && python3 setup.py install
\ No newline at end of file
...@@ -30,6 +30,7 @@ set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty) ...@@ -30,6 +30,7 @@ set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty)
# set(PYTHON_EXECUTABLE X:/Anaconda/python) # Win, Anaconda # set(PYTHON_EXECUTABLE X:/Anaconda/python) # Win, Anaconda
# Set CUDA compiling architecture # Set CUDA compiling architecture
# Remove "compute_70/sm_70" if using CUDA 8.0
set(CUDA_ARCH -gencode arch=compute_30,code=sm_30 set(CUDA_ARCH -gencode arch=compute_30,code=sm_30
-gencode arch=compute_35,code=sm_35 -gencode arch=compute_35,code=sm_35
-gencode arch=compute_50,code=sm_50 -gencode arch=compute_50,code=sm_50
...@@ -37,6 +38,7 @@ set(CUDA_ARCH -gencode arch=compute_30,code=sm_30 ...@@ -37,6 +38,7 @@ set(CUDA_ARCH -gencode arch=compute_30,code=sm_30
-gencode arch=compute_70,code=sm_70) -gencode arch=compute_70,code=sm_70)
# Set CUDNN Library Dir if necessary (Linux Only) # Set CUDNN Library Dir if necessary (Linux Only)
# For Win, Recommend to use ``3RDPARTY_DIR/lib``
set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64) set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64)
# ---------------- User Config ---------------- # ---------------- User Config ----------------
...@@ -97,7 +99,7 @@ if (WITH_MPI) ...@@ -97,7 +99,7 @@ if (WITH_MPI)
include_directories(${3RDPARTY_DIR}/include/mpi) include_directories(${3RDPARTY_DIR}/include/mpi)
endif() endif()
# ---[ libs # ---[ Lib Directories
set(3RDPARTY_LIBS ${3RDPARTY_DIR}/lib) set(3RDPARTY_LIBS ${3RDPARTY_DIR}/lib)
link_directories(${3RDPARTY_LIBS}) link_directories(${3RDPARTY_LIBS})
if (WITH_CUDNN) if (WITH_CUDNN)
...@@ -108,7 +110,7 @@ endif() ...@@ -108,7 +110,7 @@ endif()
set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix" FORCE) set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix" FORCE)
set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${3RDPARTY_LIBS}) set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${3RDPARTY_LIBS})
# ---[ defines # ---[ Defines
if (WITH_PYTHON) if (WITH_PYTHON)
ADD_DEFINITIONS(-DWITH_PYTHON) ADD_DEFINITIONS(-DWITH_PYTHON)
if (${PYTHON_VERSION_MAJOR} STREQUAL "2") if (${PYTHON_VERSION_MAJOR} STREQUAL "2")
...@@ -184,7 +186,7 @@ endif() ...@@ -184,7 +186,7 @@ endif()
# ---[ Warnings # ---[ Warnings
# ---[ execute # ---[ Commands
set (PROTOS_DIR ${PROJECT_SOURCE_DIR}/src/protos) set (PROTOS_DIR ${PROJECT_SOURCE_DIR}/src/protos)
message(STATUS "Generate Protobuf Files") message(STATUS "Generate Protobuf Files")
execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS_DIR}/caffemodel.proto) execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS_DIR}/caffemodel.proto)
...@@ -192,7 +194,7 @@ execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS ...@@ -192,7 +194,7 @@ execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS
# ---[ Subdirectories # ---[ Subdirectories
add_subdirectory(modules/python) add_subdirectory(modules/python)
#add_subdirectory(modules/cc) # Compile CC module if necessary #add_subdirectory(modules/cxx) # Compile CXX module if necessary
# ---[ Utils # ---[ Utils
file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/../lib) file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/../lib)
\ No newline at end of file
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
#include <random> #include <random>
#include <ctime> #include <ctime>
#include "common.h" #include "core/common.h"
#include "utils/logging.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
#include "utils/cuda_device.h" #include "utils/cuda_device.h"
...@@ -38,7 +37,7 @@ class CPUContext { ...@@ -38,7 +37,7 @@ class CPUContext {
virtual ~CPUContext() {} virtual ~CPUContext() {}
inline void SwitchToDevice() {} inline void SwitchToDevice() {}
inline void FinishDeviceCompution() { return; } inline void static FinishDeviceCompution() { return; }
inline static void* New(size_t nbytes) { inline static void* New(size_t nbytes) {
void* data; void* data;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
#ifndef DRAGON_CORE_CONTEXT_CUDA_H_ #ifndef DRAGON_CORE_CONTEXT_CUDA_H_
#define DRAGON_CORE_CONTEXT_CUDA_H_ #define DRAGON_CORE_CONTEXT_CUDA_H_
#include "common.h" #include "core/common.h"
#include "context.h" #include "core/context.h"
#include "utils/cuda_device.h" #include "utils/cuda_device.h"
#include "utils/cudnn_device.h" #include "utils/cudnn_device.h"
...@@ -21,8 +21,6 @@ namespace dragon { ...@@ -21,8 +21,6 @@ namespace dragon {
#ifdef WITH_CUDA #ifdef WITH_CUDA
#define MAX_GPUS 8
/************************************************************************** /**************************************************************************
* cuXXX libraries wrapper "Context" as "Handle". * cuXXX libraries wrapper "Context" as "Handle".
* It's well known that each "Context" binds to some "Devices" in OpenCL. * It's well known that each "Context" binds to some "Devices" in OpenCL.
...@@ -66,7 +64,7 @@ class CUDAObject { ...@@ -66,7 +64,7 @@ class CUDAObject {
class CUDAContext { class CUDAContext {
public: public:
CUDAContext(const DeviceOption& option) CUDAContext(const DeviceOption& option)
: gpu_id_(option.gpu_id()), : gpu_id_(option.device_id()),
random_seed_(option.has_random_seed() ? option.random_seed() : 3) { random_seed_(option.has_random_seed() ? option.random_seed() : 3) {
CPUContext context(option); CPUContext context(option);
CHECK_EQ(option.device_type(), CUDA); CHECK_EQ(option.device_type(), CUDA);
...@@ -92,11 +90,10 @@ class CUDAContext { ...@@ -92,11 +90,10 @@ class CUDAContext {
cuda_object_.cur_gpu = gpu_id_; cuda_object_.cur_gpu = gpu_id_;
} }
void FinishDeviceCompution() { inline static void FinishDeviceCompution() {
cudaStreamSynchronize(cudaStreamDefault); cudaStreamSynchronize(cudaStreamDefault);
cudaError_t error = cudaGetLastError(); cudaError_t error = cudaGetLastError();
CHECK_EQ(error, cudaSuccess) CHECK_EQ(error, cudaSuccess) << "CUDA Error: " << cudaGetErrorString(error);
<< "CUDA Error: " << cudaGetErrorString(error);
} }
inline static void* New(size_t nbytes) { inline static void* New(size_t nbytes) {
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -51,6 +51,8 @@ class Graph final : public GraphBase { ...@@ -51,6 +51,8 @@ class Graph final : public GraphBase {
GraphDef Prune(const GraphDef& meta_graph); GraphDef Prune(const GraphDef& meta_graph);
GraphDef MakeUpdate(const GraphDef& meta_graph); GraphDef MakeUpdate(const GraphDef& meta_graph);
GraphDef Share(const GraphDef& optimized_graph); GraphDef Share(const GraphDef& optimized_graph);
void ShareGrads(GraphDef& optimized_graph);
void RecomputingAware(const GraphDef& optimized_graph, Workspace* ws); void RecomputingAware(const GraphDef& optimized_graph, Workspace* ws);
inline Workspace* ws() const { return ws_; } inline Workspace* ws() const { return ws_; }
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -16,31 +16,30 @@ ...@@ -16,31 +16,30 @@
namespace dragon { namespace dragon {
typedef pair<bool, vector<pair<string, int> > > CheckTuple;
class GraphGradientMaker { class GraphGradientMaker {
public: public:
GraphGradientMaker(const GraphDef& forward_def, GraphGradientMaker() : cur_op_idx_(0) {}
const vector<string>& targets)
: cur_op_idx_(0), void Make(const GraphDef& forward_def,
forward_def_(forward_def) { const vector<string>& targets,
for (auto& target : targets) targets_set_.insert(target); GraphDef& new_def);
}
GraphDef Make(); void Share(const string& grads_prefix, GraphDef& graph);
inline void SetTerms(const Map<string, string>& terms) { terms_ = terms; } inline void SetTerms(const Map<string, string>& terms) { terms_ = terms; }
inline void SetOperatorPrefix(const string& prefix) { op_prefix_ = prefix; } inline void SetOperatorPrefix(const string& prefix) { op_prefix_ = prefix; }
inline void SetOperatorSuffix(const string& suffix) { op_suffix_ = suffix; } inline void SetOperatorSuffix(const string& suffix) { op_suffix_ = suffix; }
inline void AddExternalGrad(const string& name) { external_grads_.insert(name); } inline void AddExternalGrad(const string& name) { external_grads_.insert(name); }
inline void AddIgnoreGrad(const string& name) { ignore_grads_.insert(name); }
private: private:
CheckTuple CheckMissingGrad(OperatorDef* forward_op); bool CheckGrad(const OperatorDef& forward_op,
const Set<string>& targets,
vector< pair<string, int> >& gen_grads);
string GetOperatorName(); string GetOperatorName();
GraphDef forward_def_, new_def_;
Map<string, string> terms_, inputs_to_grads_; Map<string, string> terms_, inputs_to_grads_;
Set<string> targets_set_, blacklist_set_, external_grads_; Set<string> blacklist_set_, external_grads_, ignore_grads_;
string op_prefix_, op_suffix_; string op_prefix_, op_suffix_;
int cur_op_idx_; int cur_op_idx_;
}; };
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#ifndef DRAGON_CORE_MIXEDMEM_H_ #ifndef DRAGON_CORE_MIXEDMEM_H_
#define DRAGON_CORE_MIXEDMEM_H_ #define DRAGON_CORE_MIXEDMEM_H_
#include "typeid.h"
#include "context.h" #include "context.h"
#include "context_cuda.h" #include "context_cuda.h"
...@@ -21,40 +20,40 @@ namespace dragon { ...@@ -21,40 +20,40 @@ namespace dragon {
class MixedMemory { class MixedMemory {
public: public:
enum State { UNINITIALIZED, STATE_AT_CPU, STATE_AT_CUDA, SWITCHED, SYNCED }; enum State { UNINITIALIZED, STATE_AT_CPU, STATE_AT_CUDA, SWITCHED, SYNCED };
MixedMemory() MixedMemory() : cpu_ptr_(nullptr), cuda_ptr_(nullptr) {}
: state_(UNINITIALIZED),
cpu_ptr_(nullptr), cuda_ptr_(nullptr),
nbytes_(0) {}
MixedMemory(const TypeMeta& meta, const size_t nbytes) MixedMemory(const TypeMeta& meta, const size_t nbytes)
: state_(UNINITIALIZED), meta_(meta), : meta_(meta), nbytes_(nbytes),
cpu_ptr_(nullptr), cuda_ptr_(nullptr), cpu_ptr_(nullptr), cuda_ptr_(nullptr) {}
nbytes_(nbytes) {}
~MixedMemory(); ~MixedMemory();
const void* cpu_data(); const void* cpu_data();
const void* cuda_data(); const void* cuda_data();
void* mutable_cpu_data(); void* mutable_cpu_data();
void* mutable_cuda_data(); void* mutable_cuda_data();
void set_cpu_data(void* cpu_ptr, size_t nbytes);
#ifdef WITH_CUDA #ifdef WITH_CUDA
void async_cuda_data(const cudaStream_t& stream); void async_cuda_data(const cudaStream_t& stream);
#endif #endif
void SwitchToDevice(); void SwitchToDevice();
void SwitchToCUDADevice(int device_id);
inline size_t nbytes() const { return nbytes_; } inline size_t nbytes() const { return nbytes_; }
inline void* cpu_ptr() { state_ = STATE_AT_CPU; return cpu_ptr_; } inline void* cpu_ptr() { state_ = STATE_AT_CPU; return cpu_ptr_; }
inline void* cuda_ptr() { state_ = STATE_AT_CUDA; return cuda_ptr_; } inline void* cuda_ptr() { state_ = STATE_AT_CUDA; return cuda_ptr_; }
inline State state() { return state_; } inline State state() const { return state_; }
const Map<string, string> info() const;
private:
void ToCUDA(); void ToCUDA();
void ToCPU(); void ToCPU();
private:
void* cpu_ptr_, *cuda_ptr_; void* cpu_ptr_, *cuda_ptr_;
State state_; bool own_cpu_ptr_ = true;
size_t nbytes_; State state_ = UNINITIALIZED;
size_t nbytes_ = 0;
TypeMeta meta_; TypeMeta meta_;
}; };
......
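The constructors above now rely on in-class initializers for `state_`, `nbytes_`, and the new `own_cpu_ptr_`. A minimal CPU-side usage sketch, not part of the commit; the include path and the lazy-allocation behavior of `mutable_cpu_data()` are assumptions:

```cpp
#include <iostream>
#include "core/mixedmem.h"

int main() {
  using namespace dragon;
  // 16 floats, typed through TypeMeta exactly as in the constructor shown above.
  MixedMemory mem(TypeMeta::Make<float>(), 16 * sizeof(float));
  // Assumed behavior: the first mutable CPU access allocates and moves the state to STATE_AT_CPU.
  float* data = static_cast<float*>(mem.mutable_cpu_data());
  for (int i = 0; i < 16; i++) data[i] = float(i);
  std::cout << "nbytes: " << mem.nbytes()
            << ", at cpu: " << (mem.state() == MixedMemory::STATE_AT_CPU)
            << std::endl;
  return 0;
}
```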
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -38,12 +38,18 @@ class OperatorBase { ...@@ -38,12 +38,18 @@ class OperatorBase {
inline size_t InputSize() { return inputs_.size(); } inline size_t InputSize() { return inputs_.size(); }
inline size_t OutputSize() { return outputs_.size(); } inline size_t OutputSize() { return outputs_.size(); }
void MutableOp(const OperatorDef& op_def);
void MutableOp(const vector<string>& inputs,
const vector<string>& outputs,
const string& anchor);
inline void SwitchToPhase(const string& phase) { this->phase_ = phase; } inline void SwitchToPhase(const string& phase) { this->phase_ = phase; }
virtual void Run() { NOT_IMPLEMENTED; } virtual void Run() { NOT_IMPLEMENTED; }
inline const string& name() const { return op_def_.name(); } inline const string& name() const { return op_def_.name(); }
inline const string& type() const { return op_def_.type(); } inline const string& type() const { return op_def_.type(); }
inline const string& phase() const { return phase_; } inline const string& phase() const { return phase_; }
inline const string& anchor() { return anchor_; }
inline Workspace* ws() const { return ws_; } inline Workspace* ws() const { return ws_; }
template <typename T> template <typename T>
...@@ -60,10 +66,11 @@ class OperatorBase { ...@@ -60,10 +66,11 @@ class OperatorBase {
void set_recompute_map(RecomputeMap recompute_map) { recompute_map_ = recompute_map; } void set_recompute_map(RecomputeMap recompute_map) { recompute_map_ = recompute_map; }
inline const OperatorDef& op_def() const { return op_def_; } inline const OperatorDef& op_def() const { return op_def_; }
inline const string DebugString() const { return op_def_.DebugString(); } inline string DebugString() const { return op_def_.DebugString(); }
string DTypeHelper(const Tensor& tensor, const Set<string>& dtypes) const;
protected: protected:
string phase_; string phase_, anchor_;
Map<std::string, const Argument*> args_; Map<std::string, const Argument*> args_;
Map<string, vector<OperatorBase*> > recompute_map_; Map<string, vector<OperatorBase*> > recompute_map_;
vector<Tensor*> inputs_, outputs_; vector<Tensor*> inputs_, outputs_;
...@@ -75,47 +82,41 @@ template <class Context> ...@@ -75,47 +82,41 @@ template <class Context>
class Operator : public OperatorBase { class Operator : public OperatorBase {
public: public:
Operator(const OperatorDef& op_def, Workspace* ws) Operator(const OperatorDef& op_def, Workspace* ws)
: OperatorBase(op_def, ws), ctx_(op_def.device_option()) { : OperatorBase(op_def, ws), ctx_(op_def.device_option()),
do_synchronize_(Operator::GetSingleArg<bool>("do_synchronize", false)),
recomputing_aware_(Operator::GetSingleArg<bool>("recomputing_aware", false)) {
allow_run_ = true; allow_run_ = true;
allow_run_ &= _MPICheck(); allow_run_ &= _MPICheck();
allow_run_ &= (!(OutputSize() == 1 && Output(0)->name() == "ignore")); allow_run_ &= (!(OutputSize() == 1 && Output(0)->name() == "ignore"));
allow_share_grads_ = (!op_def.debug_mode());
allow_share_grads_ &= op_def.share_grads();
allow_share_grads_ &= (type().find("Gradient") != string::npos);
} }
virtual void Run() final { virtual void Run() final {
if (!allow_run_) return; if (!allow_run_) return;
MakeResource(); if (recomputing_aware_) MakeResource();
ctx_.SwitchToDevice(); ctx_.SwitchToDevice();
MemorySwitch(); MemorySwitch();
RunOnDevice(); RunOnDevice();
ctx_.FinishDeviceCompution(); if (do_synchronize_) ctx_.FinishDeviceCompution();
CleanResource(); if (recomputing_aware_) CleanResource();
} }
virtual void ElimateCorruption(); virtual void ElimateCorruption();
virtual void ShareGradient();
virtual void MakeResource(); virtual void MakeResource();
virtual void CleanResource(); virtual void CleanResource();
void MemorySwitch() { void MemorySwitch() {
for (int i = 0; i < InputSize(); i++) for (auto* I : inputs_) if(I->name() != "ignore") I->SwitchToDevice();
if (Input(i).name() != "ignore") Input(i).SwitchToDevice(); for (auto* O : outputs_) if(O->name() != "ignore") O->SwitchToDevice();
for (int i = 0; i < OutputSize(); i++)
if (Output(i)->name() != "ignore") Output(i)->SwitchToDevice();
} }
virtual void RunOnDevice() = 0; virtual void RunOnDevice() = 0;
inline Context& ctx() { return ctx_; } inline Context& ctx() { return ctx_; }
inline string Anchor() { return GetSingleArg("anchor", name()); }
inline bool AllowRun() { return allow_run_; } inline bool AllowRun() { return allow_run_; }
protected: protected:
Context ctx_; Context ctx_;
bool allow_run_, allow_share_grads_; bool allow_run_, recomputing_aware_, do_synchronize_;
private: private:
bool _MPICheck() { bool _MPICheck() {
...@@ -147,15 +148,16 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws); ...@@ -147,15 +148,16 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
using OperatorBase::name; \ using OperatorBase::name; \
using OperatorBase::type; \ using OperatorBase::type; \
using OperatorBase::phase; \ using OperatorBase::phase; \
using OperatorBase::anchor; \
using OperatorBase::op_def; \ using OperatorBase::op_def; \
using OperatorBase::InputSize; \ using OperatorBase::InputSize; \
using OperatorBase::OutputSize; \ using OperatorBase::OutputSize; \
using OperatorBase::DebugString \ using OperatorBase::DebugString; \
using OperatorBase::DTypeHelper \
#define USE_OPERATOR_FUNCTIONS(context) \ #define USE_OPERATOR_FUNCTIONS(context) \
USE_OPERATOR_BASE_FUNCTIONS; \ USE_OPERATOR_BASE_FUNCTIONS; \
using Operator<context>::ctx; \ using Operator<context>::ctx; \
using Operator<context>::Anchor; \
using Operator<context>::AllowRun using Operator<context>::AllowRun
DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*); DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*);
...@@ -238,8 +240,11 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp ...@@ -238,8 +240,11 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
return argument##_tensor->template data<type, CPUContext>()[0]; \ return argument##_tensor->template data<type, CPUContext>()[0]; \
} }
#define DISABLE_SHARE_GRADIENT \ #define GET_ARGUMENTS_SIZE(argument) \
this->allow_share_grads_ = false std::max(argument##_value.size(), argument##_desc.size())
#define XIsType(x, dtype) \
x.template IsType<dtype>()
#define INSTANTIATE_OPERATOR(name, context) \ #define INSTANTIATE_OPERATOR(name, context) \
template class name##Op<context>; template class name##Op<context>;
...@@ -272,6 +277,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp ...@@ -272,6 +277,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
#define DEPLOY_CUDNN(name) \ #define DEPLOY_CUDNN(name) \
REGISTER_CUDNN_OPERATOR(name, CuDNN##name##Op<CUDAContext>); \ REGISTER_CUDNN_OPERATOR(name, CuDNN##name##Op<CUDAContext>); \
INSTANTIATE_CUDNN_OPERATOR(name); INSTANTIATE_CUDNN_OPERATOR(name);
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_OPERATOR_H_ #endif // DRAGON_CORE_OPERATOR_H_
\ No newline at end of file
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include <functional> #include <functional>
#include "core/common.h" #include "core/common.h"
#include "utils/logging.h"
namespace dragon { namespace dragon {
...@@ -24,11 +23,11 @@ class Registry { ...@@ -24,11 +23,11 @@ class Registry {
public: public:
typedef std::function<ObjType*(Args ...)> Creator; typedef std::function<ObjType*(Args ...)> Creator;
void Register(const SrcType& key, Creator creator) { void Register(const SrcType& key, Creator creator) {
CHECK(!registry_.count(key)) << "Key(" << key << ") has already registered."; CHECK(!registry_.count(key)) << "\nKey(" << key << ") has already registered.";
registry_[key] = creator; registry_[key] = creator;
} }
ObjType* Create(const SrcType& key, Args ... args) { ObjType* Create(const SrcType& key, Args ... args) {
CHECK(registry_.count(key)) << "Key(" << key << ") has not registered yet."; CHECK(registry_.count(key)) << "\nKey(" << key << ") has not registered yet.";
return registry_[key](args...); return registry_[key](args...);
} }
bool Has(const SrcType& key) { return (registry_.count(key)) != 0; } bool Has(const SrcType& key) { return (registry_.count(key)) != 0; }
......
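The hunk above only tweaks the error messages, but the surrounding Register/Create/Has interface is worth illustrating. A self-contained sketch of the same creator-registry pattern; the names below are illustrative only, not Dragon's actual template parameters:

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Op { virtual ~Op() = default; virtual const char* type() const = 0; };
struct ReluOp final : Op { const char* type() const override { return "Relu"; } };

// Same shape as the Registry above: creators keyed by string, invoked on Create().
class MiniRegistry {
 public:
  typedef std::function<Op*()> Creator;
  void Register(const std::string& key, Creator creator) { registry_[key] = creator; }
  Op* Create(const std::string& key) { return registry_.at(key)(); }
  bool Has(const std::string& key) { return registry_.count(key) != 0; }
 private:
  std::map<std::string, Creator> registry_;
};

int main() {
  MiniRegistry registry;
  registry.Register("Relu", []() -> Op* { return new ReluOp(); });
  if (registry.Has("Relu")) {
    std::unique_ptr<Op> op(registry.Create("Relu"));
    std::cout << op->type() << std::endl;  // prints "Relu"
  }
  return 0;
}
```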
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include <vector> #include <vector>
#include "core/common.h" #include "core/common.h"
#include "core/typeid.h"
#include "core/mixedmem.h" #include "core/mixedmem.h"
namespace dragon { namespace dragon {
...@@ -68,6 +67,7 @@ class Tensor { ...@@ -68,6 +67,7 @@ class Tensor {
inline const vector<TIndex>& dims() const { return dims_; } inline const vector<TIndex>& dims() const { return dims_; }
inline TSize nbytes() const { return size_ * meta_.itemsize(); } inline TSize nbytes() const { return size_ * meta_.itemsize(); }
inline TSize capacity() const { return capacity_; }
inline TIndex count(const TIndex start, const TIndex end) const { inline TIndex count(const TIndex start, const TIndex end) const {
TIndex ret = 1; TIndex ret = 1;
...@@ -97,17 +97,21 @@ class Tensor { ...@@ -97,17 +97,21 @@ class Tensor {
} }
inline string dim_string() const { inline string dim_string() const {
if (ndim() == 0) return "(0,)";
std::stringstream ss; std::stringstream ss;
ss << "("; ss << "(";
for (int i = 0; i < ndim() - 1; i++) ss << dim(i) << ","; for (int i = 0; i < ndim() - 1; i++) ss << dim(i) << ",";
ss << dim(ndim() - 1) << ")"; if (ndim() == 1) ss << dim(0) << ",)";
else ss << dim(ndim() - 1) << ")";
return ss.str(); return ss.str();
} }
inline bool is_corrupted() const { return is_corrupted_; } inline bool is_corrupted() const { return is_corrupted_; }
inline void Corrupt() { is_corrupted_ = true; } inline void Corrupt() { is_corrupted_ = true; }
inline bool has_memory() const { return memory_ || ex_memory_ != nullptr; }
MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; } MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; }
void set_memory(MixedMemory* mem) { memory_.reset(mem); capacity_ = mem->nbytes(); }
MixedMemory::State memory_state() const { MixedMemory::State memory_state() const {
MixedMemory* mem = memory(); MixedMemory* mem = memory();
CHECK(mem) << "\nMemory access before allowcating."; CHECK(mem) << "\nMemory access before allowcating.";
...@@ -142,13 +146,13 @@ class Tensor { ...@@ -142,13 +146,13 @@ class Tensor {
template <class Context> template <class Context>
const void* const_data_ptr() const { const void* const_data_ptr() const {
MixedMemory* mem = memory(); MixedMemory* mem = memory();
CHECK(mem) << "memory access before allowcating."; CHECK(mem) << "\nMemory access before allowcating.";
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) { if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) {
return mem->cpu_data(); return mem->cpu_data();
} else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) { } else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) {
return mem->cuda_data(); return mem->cuda_data();
} else { } else {
LOG(FATAL) << "unknown memory type access. only CPU or CUDA are supported."; LOG(FATAL) << "Unknown memory type. Only CPU or CUDA are supported.";
return nullptr; return nullptr;
} }
} }
...@@ -192,14 +196,25 @@ class Tensor { ...@@ -192,14 +196,25 @@ class Tensor {
template <typename T, class Context> template <typename T, class Context>
const T* data() const { const T* data() const {
CHECK(meta_ == TypeMeta::Make<T>())
<< "\nThe DType of Tensor(" << name() << ") is "
<< TypeMetaToString(meta_) << ", while required "
<< TypeMetaToString(TypeMeta::Make<T>());
return static_cast<const T*>(raw_data<Context>()); return static_cast<const T*>(raw_data<Context>());
} }
inline void Share(const Tensor& other) { template <class DstCTX, class SrcCTX>
inline void Copy(const Tensor& other) {
CHECK_EQ(size_, other.size_); CHECK_EQ(size_, other.size_);
memory_ = other.memory_;
meta_ = other.meta_; meta_ = other.meta_;
capacity_ = other.capacity_; auto* src = other.template raw_data<SrcCTX>();
auto* dst = raw_mutable_data<DstCTX>();
if (dst == src) return;
if (TypeMeta::Id<DstCTX>() == TypeMeta::Id<CPUContext>()) {
CPUContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src);
} else if (TypeMeta::Id<DstCTX>() == TypeMeta::Id<CUDAContext>()) {
CUDAContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src);
}
} }
inline void Move(MixedMemory* mem) { inline void Move(MixedMemory* mem) {
...@@ -213,14 +228,18 @@ class Tensor { ...@@ -213,14 +228,18 @@ class Tensor {
meta_ = TypeMeta(); meta_ = TypeMeta();
dims_.clear(); dims_.clear();
memory_.reset(); memory_.reset();
if (DECREFPyArray) DECREFPyArray();
} }
std::function<void()> DECREFPyArray;
~Tensor() { /* DO NOT CALL DECREFARRAY */ }
private: private:
vector<TIndex> dims_; vector<TIndex> dims_;
TIndex size_ = 0, capacity_ = 0; TIndex size_ = 0, capacity_ = 0;
TypeMeta meta_; TypeMeta meta_;
string name_; string name_;
shared_ptr<MixedMemory> memory_, host_memory_; shared_ptr<MixedMemory> memory_;
MixedMemory* ex_memory_ = nullptr; MixedMemory* ex_memory_ = nullptr;
bool is_corrupted_ = false, own_mem_ = true; bool is_corrupted_ = false, own_mem_ = true;
}; };
......
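The new `Copy<DstCTX, SrcCTX>()` replaces the removed `Share()` with an explicit deep copy, and `dim_string()` now prints one-dimensional shapes as `(n,)`. A minimal usage sketch under the assumption of a CPU-only build with the Dragon headers on the include path (the header path is a guess):

```cpp
#include <iostream>
#include "core/tensor.h"

int main() {
  using namespace dragon;
  Tensor a("a"), b("b");
  a.Reshape(vector<TIndex>(1, 4));                      // shape (4,)
  b.Reshape(vector<TIndex>(1, 4));
  float* adata = a.mutable_data<float, CPUContext>();
  for (int i = 0; i < 4; i++) adata[i] = float(i);
  b.Copy<CPUContext, CPUContext>(a);                    // deep copy, unlike the old Share()
  std::cout << b.dim_string() << " "                    // "(4,)" with the fixed 1-D printing
            << b.data<float, CPUContext>()[3]           // 3
            << std::endl;
  return 0;
}
```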
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -12,6 +12,10 @@ ...@@ -12,6 +12,10 @@
#ifndef DRAGON_CORE_TYPES_H_ #ifndef DRAGON_CORE_TYPES_H_
#define DRAGON_CORE_TYPES_H_ #define DRAGON_CORE_TYPES_H_
#include <unordered_map>
#include "core/typeid.h"
namespace dragon { namespace dragon {
#ifdef _MSC_VER #ifdef _MSC_VER
...@@ -36,6 +40,31 @@ typedef struct { ...@@ -36,6 +40,31 @@ typedef struct {
#endif #endif
inline const TypeMeta& TypeStringToMeta(const std::string& str_type) {
static std::unordered_map<std::string, TypeMeta> s2m_type_map {
{ "float32", TypeMeta::Make<float>() },
{ "int32", TypeMeta::Make<int>() },
{ "int64", TypeMeta::Make<int64_t>() },
{ "float64", TypeMeta::Make<double>() },
{ "float16", TypeMeta::Make<float16>() },
{ "uint8", TypeMeta::Make<uint8_t>() }
};
static TypeMeta unknown_type;
return s2m_type_map.count(str_type) ? s2m_type_map[str_type] : unknown_type;
}
inline const std::string TypeMetaToString(const TypeMeta& meta) {
static std::unordered_map<TypeId, std::string> m2s_type_map {
{ TypeMeta::Id<float>(), "float32" },
{ TypeMeta::Id<int>(), "int32" },
{ TypeMeta::Id<int64_t>(), "int64" },
{ TypeMeta::Id<double>(), "float64", },
{ TypeMeta::Id<float16>(), "float16" },
{ TypeMeta::Id<uint8_t>(), "uint8" }
};
return m2s_type_map.count(meta.id()) ? m2s_type_map[meta.id()] : "unknown";
}
} // namespace dragon } // namespace dragon
#endif // DRAGON_CORE_TYPES_H_ #endif // DRAGON_CORE_TYPES_H_
\ No newline at end of file
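A tiny round-trip sketch for the new `TypeStringToMeta` / `TypeMetaToString` helpers above; the include path is an assumption:

```cpp
#include <cstdint>
#include <iostream>
#include "core/types.h"

int main() {
  const auto& meta = dragon::TypeStringToMeta("float32");    // == TypeMeta::Make<float>()
  std::cout << dragon::TypeMetaToString(meta) << std::endl;  // "float32"
  std::cout << dragon::TypeMetaToString(
      dragon::TypeMeta::Make<int64_t>()) << std::endl;       // "int64"
  return 0;
}
```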
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
namespace dragon { namespace dragon {
#define WORKSPACE_COMMON_BUFFER_SIZE 2 #define WORKSPACE_COMMON_BUFFER_SIZE 2
#define WORKSPACE_GRAD_BUFFER_SIZE 1
#define WORKSPACE_MAX_CORRUPTED_SIZE 2 #define WORKSPACE_MAX_CORRUPTED_SIZE 2
class Workspace { class Workspace {
...@@ -28,10 +27,10 @@ class Workspace { ...@@ -28,10 +27,10 @@ class Workspace {
typedef Map<string, unique_ptr<Tensor> > TensorMap; typedef Map<string, unique_ptr<Tensor> > TensorMap;
typedef Map<string, stack<string> > BufferMap; typedef Map<string, stack<string> > BufferMap;
typedef Map<string, unique_ptr<mutex> > LockMap; typedef Map<string, unique_ptr<mutex> > LockMap;
typedef Map<string, unique_ptr<OperatorBase> > OperatorMap;
typedef Map<string, unique_ptr<GraphBase> > GraphMap; typedef Map<string, unique_ptr<GraphBase> > GraphMap;
typedef Map<string, TensorFiller> FillerMap; typedef Map<string, TensorFiller> FillerMap;
typedef Map<string, string> RenameMap; typedef Map<string, string> RenameMap;
typedef Map<string, string> AvatarMap;
Workspace(const string& name) : name_(name) { Init(); } Workspace(const string& name) : name_(name) { Init(); }
~Workspace(); ~Workspace();
...@@ -39,7 +38,16 @@ class Workspace { ...@@ -39,7 +38,16 @@ class Workspace {
void Init() { void Init() {
CreateTensor("ignore"); CreateTensor("ignore");
CreateBuffer("Common", WORKSPACE_COMMON_BUFFER_SIZE); CreateBuffer("Common", WORKSPACE_COMMON_BUFFER_SIZE);
CreateBuffer("Grad", WORKSPACE_GRAD_BUFFER_SIZE); Tensor* head = CreateTensor("/opt/mirror_stage/head");
head->Reshape(vector<TIndex>(1, WORKSPACE_MAX_CORRUPTED_SIZE));
Tensor* recompute_flag = CreateTensor("/opt/mirror_stage/recompute_flag");
recompute_flag->Reshape(vector<TIndex>(1, 1));
recompute_flag->mutable_data<bool, CPUContext>()[0] = false;
for (int i = 0; i < WORKSPACE_MAX_CORRUPTED_SIZE; i++) {
string name = "/opt/mirror_stage/buffer_" + dragon_cast<string, int>(i);
Tensor* buffer = CreateTensor(name);
head->mutable_data<string, CPUContext>()[i] = "";
}
} }
inline const string& name() { return name_; } inline const string& name() { return name_; }
...@@ -54,13 +62,11 @@ class Workspace { ...@@ -54,13 +62,11 @@ class Workspace {
} }
inline void ClearWorkspace() { inline void ClearWorkspace() {
// clear the relationship of avatars // clear tensors & buffers
avatar_map_.clear();
// clear the buffers
ResetBuffer("Common", WORKSPACE_COMMON_BUFFER_SIZE);
ResetBuffer("Grad", WORKSPACE_GRAD_BUFFER_SIZE);
// clear tenosrs
for (auto& kv : tensor_map_) kv.second->Reset(); for (auto& kv : tensor_map_) kv.second->Reset();
ResetBuffers("Common");
// Re-Initialization
Init();
} }
/******************** Tensor ********************/ /******************** Tensor ********************/
...@@ -71,26 +77,7 @@ class Workspace { ...@@ -71,26 +77,7 @@ class Workspace {
} else { return name; } } else { return name; }
} }
bool HasTensor(const string& name, bool use_remote=true) { inline Tensor* TryGetTensor(const string& name, bool use_remote=true) {
// search local workspace
string query = GetTensorName(name);
bool result = tensor_map_.count(query) > 0;
if (!use_remote) return result;
// search remote workspace
for (auto& it : workspace_map_)
result |= it.second->HasTensor(query);
return result;
}
inline Tensor* CreateTensor(const string& name) {
string query = GetTensorName(name);
if (!HasTensor(query))
tensor_map_[query] = unique_ptr<Tensor>(new Tensor(query));
return GetTensor(query);
}
Tensor* GetTensor(const string& name, bool use_remote=true) {
string query = GetTensorName(name); string query = GetTensorName(name);
// search local workspace // search local workspace
if (tensor_map_.count(query) > 0) if (tensor_map_.count(query) > 0)
...@@ -102,11 +89,29 @@ class Workspace { ...@@ -102,11 +89,29 @@ class Workspace {
return it.second->GetTensor(query); return it.second->GetTensor(query);
} }
} }
LOG(FATAL) << "Tensor(" << name << ") does not exist "
<< "in current workspace and it's sub-workspace.";
return nullptr; return nullptr;
} }
inline bool HasTensor(const string& name, bool use_remote=true) {
return TryGetTensor(name, use_remote) ? true : false;
}
inline Tensor* CreateTensor(const string& name) {
Tensor* tensor = TryGetTensor(name);
if (!tensor) {
tensor_map_[name] = unique_ptr<Tensor>(new Tensor(name));
return tensor_map_[name].get();
}
return tensor;
}
inline Tensor* GetTensor(const string& name, bool use_remote=true) {
Tensor* tensor = TryGetTensor(name, use_remote);
CHECK(tensor) << "\nTensor(" << name << ") does not exist "
<< "in current workspace or sub-workspace.";
return tensor;
}
inline void LockTensor(const string& name) { inline void LockTensor(const string& name) {
string query = GetTensorName(name); string query = GetTensorName(name);
if (!lock_map_.count(query)) if (!lock_map_.count(query))
...@@ -121,12 +126,11 @@ class Workspace { ...@@ -121,12 +126,11 @@ class Workspace {
lock_map_[query]->unlock(); lock_map_[query]->unlock();
} }
inline void ReleaseTensor(const string& name) { inline void ResetTensor(const string& name) {
CHECK(HasTensor(name, false)) Tensor* tensor = TryGetTensor(name, false);
<< "\nTensor(" << name << ") does not " CHECK(tensor) << "\nTensor(" << name << ") does not "
<< "belong to current workspace, could not release it."; << "belong to current workspace, could not be reset.";
string query = GetTensorName(name); tensor->Reset();
tensor_map_[query]->Reset();
} }
vector<string> GetTensors() { vector<string> GetTensors() {
...@@ -144,7 +148,7 @@ class Workspace { ...@@ -144,7 +148,7 @@ class Workspace {
/******************** Filler ********************/ /******************** Filler ********************/
bool HasFiller(const string& name, bool use_remote=true) { inline bool HasFiller(const string& name, bool use_remote=true) {
// search local workspace // search local workspace
bool result = filler_map_.count(name) > 0; bool result = filler_map_.count(name) > 0;
if (!use_remote) return result; if (!use_remote) return result;
...@@ -175,24 +179,9 @@ class Workspace { ...@@ -175,24 +179,9 @@ class Workspace {
return nullptr; return nullptr;
} }
/******************** Avatar ********************/
inline void CreateAvatar(Tensor* orig, Tensor* avatar) {
CHECK(tensor_map_.count(orig->name()) > 0)
<< "\nFailed to create avatar for Tensor(" << orig->name() << ")."
<< "\nAs it has not been registered in the current workspace.";
avatar_map_[orig->name()] = avatar->name();
}
inline Tensor* SearchAvatar(Tensor* orig) {
if (avatar_map_.count(orig->name()) > 0)
return GetTensor(avatar_map_[orig->name()]);
return orig;
}
/******************** Buffer ********************/ /******************** Buffer ********************/
void CreateBuffer(string category, int num) { inline void CreateBuffer(string category, int num) {
if (!buffer_map_.count(category)) if (!buffer_map_.count(category))
buffer_map_[category] = stack<string>(); buffer_map_[category] = stack<string>();
for (int i = 1; i <= num; i++) { for (int i = 1; i <= num; i++) {
...@@ -213,24 +202,13 @@ class Workspace { ...@@ -213,24 +202,13 @@ class Workspace {
return nullptr; return nullptr;
} }
void ResetBuffer(string category, int num) { inline void ReleaseBuffer(Tensor* tensor,
while (!buffer_map_[category].empty()) {
string name = buffer_map_[category].top();
buffer_map_[category].pop();
tensor_map_[name]->Reset();
}
CreateBuffer(category, num);
}
void ReleaseBuffer(Tensor* tensor,
string category = "Common", string category = "Common",
bool enforce = false) { bool enforce = false) {
static Map<string, int> limits = { static Map<string, int> limits = {
{ "Common", WORKSPACE_COMMON_BUFFER_SIZE }, { "Common", WORKSPACE_COMMON_BUFFER_SIZE }};
{ "Grad", WORKSPACE_GRAD_BUFFER_SIZE }
};
if (buffer_map_[category].size() >= limits[category] || enforce) { if (buffer_map_[category].size() >= limits[category] || enforce) {
ReleaseTensor(tensor->name()); ResetTensor(tensor->name());
if (buffer_map_[category].empty()) if (buffer_map_[category].empty())
buffer_map_[category].push(tensor->name()); buffer_map_[category].push(tensor->name());
} else { } else {
...@@ -238,18 +216,68 @@ class Workspace { ...@@ -238,18 +216,68 @@ class Workspace {
} }
} }
inline void ResetBuffers(string category) {
while (!buffer_map_[category].empty()) {
string name = buffer_map_[category].top();
buffer_map_[category].pop();
tensor_map_[name]->Reset();
}
}
/******************** Operator ********************/
inline void CreatePersistentOp(const OperatorDef& meta_op) {
string persistent_key;
for (auto& arg : meta_op.arg())
if (arg.name() == "persistent_key")
persistent_key = arg.s();
CHECK(persistent_key.size() > 0) << "\nGot empty persistent key.";
if (!op_map_.count(persistent_key)) {
for (auto& input : meta_op.input()) CreateTensor(input);
op_map_[persistent_key] = unique_ptr<OperatorBase>(
CreateOperator(meta_op, this));
}
}
inline void RunPersistentOp(const string& key, const string& anchor,
const vector<string>& inputs,
const vector<string>& outputs) {
CHECK(op_map_.count(key) > 0)
<< "\nPersistentOp(" << key << ") does not exist.";
op_map_[key]->MutableOp(inputs, outputs, anchor);
op_map_[key]->Run();
}
void RunOperator(const OperatorDef& meta_op) {
string persistent_key;
for (auto& arg : meta_op.arg()) {
if (arg.name() == "persistent_key")
persistent_key = arg.s();
}
if (persistent_key.empty()) {
// run op in the "ONCE" mode
unique_ptr<OperatorBase> op(CreateOperator(meta_op, this));
op->Run();
} else {
// run op in the "PERSISTENT" mode
if (!op_map_.count(persistent_key))
op_map_[persistent_key] = unique_ptr<OperatorBase>(
CreateOperator(meta_op, this));
else op_map_[persistent_key]->MutableOp(meta_op);
op_map_[persistent_key]->Run();
}
}
/******************** Graph ********************/ /******************** Graph ********************/
GraphBase* CreateGraph(const GraphDef& meta_graph); GraphBase* CreateGraph(const GraphDef& meta_graph);
bool RunGraph(const string& graph_name, void RunGraph(const string& graph_name,
const string& include, const string& include,
const string& exclude) { const string& exclude) {
if (!graph_map_.count(graph_name)) { if (!graph_map_.count(graph_name))
LOG(ERROR) << "Graph(" << graph_name << ") does not exist."; LOG(FATAL) << "Graph(" << graph_name << ") does not exist.";
return false; graph_map_[graph_name]->Run(include, exclude);
}
return graph_map_[graph_name]->Run(include, exclude);
} }
vector<string> GetGraphs() { vector<string> GetGraphs() {
...@@ -271,10 +299,10 @@ class Workspace { ...@@ -271,10 +299,10 @@ class Workspace {
TensorMap tensor_map_; TensorMap tensor_map_;
BufferMap buffer_map_; BufferMap buffer_map_;
LockMap lock_map_; LockMap lock_map_;
OperatorMap op_map_;
GraphMap graph_map_; GraphMap graph_map_;
FillerMap filler_map_; FillerMap filler_map_;
RenameMap rename_map_; RenameMap rename_map_;
AvatarMap avatar_map_;
}; };
} // namespace dragon } // namespace dragon
......
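The workspace now resolves tensors through `TryGetTensor`, and `GetTensor` aborts on a missing name instead of returning `nullptr`. A minimal sketch of the reworked tensor management, assuming a CPU-only build and that `core/workspace.h` is the public header path:

```cpp
#include <iostream>
#include "core/workspace.h"

int main() {
  using namespace dragon;
  Workspace ws("sketch");
  Tensor* x = ws.CreateTensor("data");              // created on first request
  x->Reshape(vector<TIndex>(1, 8));
  x->mutable_data<float, CPUContext>()[0] = 1.f;
  std::cout << ws.HasTensor("data") << std::endl;   // 1: found via TryGetTensor
  ws.ResetTensor("data");                           // replaces the old ReleaseTensor
  // GetTensor on a missing name now CHECK-fails instead of returning nullptr:
  // ws.GetTensor("missing");                       // would abort with a fatal check
  return 0;
}
```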
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -43,7 +43,6 @@ class DropoutGradientOp final : public Operator<Context> { ...@@ -43,7 +43,6 @@ class DropoutGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) { use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5); GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -36,9 +36,7 @@ class EluGradientOp : public Operator<Context> { ...@@ -36,9 +36,7 @@ class EluGradientOp : public Operator<Context> {
public: public:
EluGradientOp(const OperatorDef& op_def, Workspace* ws) EluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) { alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {}
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -36,9 +36,7 @@ class ReluGradientOp : public Operator<Context> { ...@@ -36,9 +36,7 @@ class ReluGradientOp : public Operator<Context> {
public: public:
ReluGradientOp(const OperatorDef& op_def, Workspace* ws) ReluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) { slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -19,8 +19,7 @@ namespace dragon { ...@@ -19,8 +19,7 @@ namespace dragon {
template <class Context> template <class Context>
class SEluOp : public Operator<Context> { class SEluOp : public Operator<Context> {
public: public:
SEluOp(const OperatorDef& op_def, Workspace* ws) USE_SIMPLE_CTOR_DTOR(SEluOp);
: Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
...@@ -30,10 +29,7 @@ class SEluOp : public Operator<Context> { ...@@ -30,10 +29,7 @@ class SEluOp : public Operator<Context> {
template <class Context> template <class Context>
class SEluGradientOp : public Operator<Context> { class SEluGradientOp : public Operator<Context> {
public: public:
SEluGradientOp(const OperatorDef& op_def, Workspace* ws) USE_SIMPLE_CTOR_DTOR(SEluGradientOp);
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -29,10 +29,7 @@ class SigmoidOp : public Operator<Context> { ...@@ -29,10 +29,7 @@ class SigmoidOp : public Operator<Context> {
template <class Context> template <class Context>
class SigmoidGradientOp : public Operator<Context> { class SigmoidGradientOp : public Operator<Context> {
public: public:
SigmoidGradientOp(const OperatorDef& op_def, Workspace* ws) USE_SIMPLE_CTOR_DTOR(SigmoidGradientOp);
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -38,9 +38,7 @@ class SoftmaxGradientOp final : public Operator<Context> { ...@@ -38,9 +38,7 @@ class SoftmaxGradientOp final : public Operator<Context> {
public: public:
SoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws) SoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) { axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -29,10 +29,7 @@ class TanhOp : public Operator<Context> { ...@@ -29,10 +29,7 @@ class TanhOp : public Operator<Context> {
template <class Context> template <class Context>
class TanhGradientOp : public Operator<Context> { class TanhGradientOp : public Operator<Context> {
public: public:
TanhGradientOp(const OperatorDef& op_def, Workspace* ws) USE_SIMPLE_CTOR_DTOR(TanhGradientOp);
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -36,7 +36,6 @@ class AddGradientOp final : public Operator<Context> { ...@@ -36,7 +36,6 @@ class AddGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(AddGradientOp); USE_SIMPLE_CTOR_DTOR(AddGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
...@@ -65,7 +64,6 @@ class RAddGradientOp final : public Operator<Context> { ...@@ -65,7 +64,6 @@ class RAddGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(RAddGradientOp); USE_SIMPLE_CTOR_DTOR(RAddGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -38,9 +38,7 @@ class BiasAddGradientOp final : public Operator<Context> { ...@@ -38,9 +38,7 @@ class BiasAddGradientOp final : public Operator<Context> {
public: public:
BiasAddGradientOp(const OperatorDef& op_def, Workspace* ws) BiasAddGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) { data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -36,7 +36,6 @@ class DivGradientOp final : public Operator<Context> { ...@@ -36,7 +36,6 @@ class DivGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(DivGradientOp); USE_SIMPLE_CTOR_DTOR(DivGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
...@@ -65,7 +64,6 @@ class RDivGradientOp final : public Operator<Context> { ...@@ -65,7 +64,6 @@ class RDivGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(RDivGradientOp); USE_SIMPLE_CTOR_DTOR(RDivGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -44,7 +44,6 @@ class DotGradientOp final : public Operator<Context> { ...@@ -44,7 +44,6 @@ class DotGradientOp final : public Operator<Context> {
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void DotRunWithType(); template <typename T> void DotRunWithType();
template <typename T> void GemmRunWithType(); template <typename T> void GemmRunWithType();
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -55,7 +55,6 @@ class EltwiseGradientOp final : public Operator<Context> { ...@@ -55,7 +55,6 @@ class EltwiseGradientOp final : public Operator<Context> {
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void SumRunWithType(); template <typename T> void SumRunWithType();
template <typename T> void ProdRunWithType(); template <typename T> void ProdRunWithType();
......
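For the two branches above: an elementwise sum y = sum_i c_i * x_i (optionally with per-input coefficients c_i) sends dx_i = c_i * dy to each input, while an elementwise product y = prod_i x_i sends dx_i = dy * y / x_i, i.e. the product of the other inputs. A rough sketch of the product case, assuming non-zero inputs (illustrative only):

#include <vector>

// Product rule for y = x1 * x2 * ... * xk (elementwise):
// dxj = dy * (product of the other inputs) = dy * y / xj when xj != 0.
void EltwiseProdGrad(const std::vector<std::vector<float>>& xs,
                     const std::vector<float>& y,
                     const std::vector<float>& dy,
                     std::vector<std::vector<float>>* dxs) {
    dxs->assign(xs.size(), std::vector<float>(y.size(), 0.f));
    for (size_t j = 0; j < xs.size(); ++j)
        for (size_t i = 0; i < y.size(); ++i)
            (*dxs)[j][i] = dy[i] * y[i] / xs[j][i];  // assumes xs[j][i] != 0
}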
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -43,7 +43,6 @@ class MatmulGradientOp final : public Operator<Context> { ...@@ -43,7 +43,6 @@ class MatmulGradientOp final : public Operator<Context> {
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
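For the no-transpose case handled by gradient ops like the one above, C = A * B gives dA = dC * B^T and dB = A^T * dC; the TransA/TransB flags only change which operand gets transposed. A naive row-major sketch of dA = dC * B^T (illustrative, not the framework's GEMM path):

#include <vector>

// dA (M x K) = dC (M x N) * B^T, with B stored as K x N; all row-major.
void MatmulGradA(const std::vector<float>& dC,
                 const std::vector<float>& B,
                 int M, int N, int K,
                 std::vector<float>* dA) {
    dA->assign(M * K, 0.f);
    for (int m = 0; m < M; ++m)
        for (int k = 0; k < K; ++k)
            for (int n = 0; n < N; ++n)
                (*dA)[m * K + k] += dC[m * N + n] * B[k * N + n];
}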
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -36,7 +36,6 @@ class MulGradientOp final : public Operator<Context> { ...@@ -36,7 +36,6 @@ class MulGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(MulGradientOp); USE_SIMPLE_CTOR_DTOR(MulGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
...@@ -65,7 +64,6 @@ class RMulGradientOp final : public Operator<Context> { ...@@ -65,7 +64,6 @@ class RMulGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(RMulGradientOp); USE_SIMPLE_CTOR_DTOR(RMulGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -36,7 +36,6 @@ class SubGradientOp final : public Operator<Context> { ...@@ -36,7 +36,6 @@ class SubGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(SubGradientOp); USE_SIMPLE_CTOR_DTOR(SubGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
...@@ -65,7 +64,6 @@ class RSubGradientOp final : public Operator<Context> { ...@@ -65,7 +64,6 @@ class RSubGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(RSubGradientOp); USE_SIMPLE_CTOR_DTOR(RSubGradientOp);
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -67,16 +67,13 @@ class ScanGradientOp final: public Operator<Context> { ...@@ -67,16 +67,13 @@ class ScanGradientOp final: public Operator<Context> {
// handle GI(x) // handle GI(x)
for (int i = 0; i < forward_inputs.size(); i++) for (int i = 0; i < forward_inputs.size(); i++)
terms[forward_inputs[i] + "_grad"] = Output(i)->name(); terms[forward_inputs[i] + "_grad"] = Output(i)->name();
DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
void MakeGradientOps(); void MakeOps(const GraphDef& forward_def, GraphDef& new_def);
protected: protected:
GraphDef forward_def, new_def;
Map<string, string> terms; Map<string, string> terms;
Map<int, unique_ptr<Graph>> graphs; Map<int, unique_ptr<Graph>> graphs;
vector<string> forward_inputs, forward_outputs; vector<string> forward_inputs, forward_outputs;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -40,7 +40,6 @@ class L1LossGradientOp final : public Operator<Context> { ...@@ -40,7 +40,6 @@ class L1LossGradientOp final : public Operator<Context> {
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -40,7 +40,6 @@ class L2LossGradientOp final : public Operator<Context> { ...@@ -40,7 +40,6 @@ class L2LossGradientOp final : public Operator<Context> {
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
@@ -21,17 +21,15 @@ class SmoothL1LossOp final : public Operator<Context> {
  public:
     SmoothL1LossOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
-          sigma2(OperatorBase::GetSingleArg<float>("sigma", 1.0)),
-          normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {
-        sigma2 *= sigma2;
-    }
+          beta(OperatorBase::GetSingleArg<float>("beta", 1.0)),
+          normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
     USE_OPERATOR_FUNCTIONS(Context);
     void RunOnDevice() override;
     template <typename T> void RunWithType();

  protected:
-    float sigma2;
+    float beta;
     Tensor* diff, *error;
     string normalization;
 };

@@ -41,17 +39,15 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
  public:
     SmoothL1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
-          sigma2(OperatorBase::GetSingleArg<float>("sigma", 1.0)),
-          normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {
-        sigma2 *= sigma2;
-    }
+          beta(OperatorBase::GetSingleArg<float>("beta", 1.0)),
+          normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
     USE_OPERATOR_FUNCTIONS(Context);
     void RunOnDevice() override;
     template <typename T> void RunWithType();

  protected:
-    float sigma2;
+    float beta;
     Tensor* diff;
     string normalization;
 };
......
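The switch from sigma to beta above matches the usual parameterization of the smooth L1 (Huber-style) loss: per element, loss = 0.5 * d * d / beta when |d| < beta and |d| - 0.5 * beta otherwise, so the old sigma^2 plays roughly the role of 1 / beta. A minimal sketch (plain C++, not the operator's kernel):

#include <cmath>
#include <vector>

// Elementwise smooth L1: 0.5 * d * d / beta   if |d| < beta
//                        |d| - 0.5 * beta     otherwise
float SmoothL1(const std::vector<float>& diff, float beta) {
    float loss = 0.f;
    for (float d : diff) {
        float ad = std::fabs(d);
        loss += (ad < beta) ? 0.5f * d * d / beta : ad - 0.5f * beta;
    }
    return loss;  // normalization (e.g. by batch size) applied separately
}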
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
@@ -23,16 +23,10 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 1)),
           normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {
-        OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
-            vector<string>({ Input(0).name() }),
-            vector<string>({ "/mnt/" + Anchor() + "/softmax/prob" }));
-        softmax_def.add_arg()->CopyFrom(this->arg("axis"));
-        if (op_def.has_device_option())
-            softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
-        softmax_op.reset(CreateOperator(softmax_def, ws));
     }
     USE_OPERATOR_FUNCTIONS(Context);
-    void SoftmaxRun();
     void RunOnDevice() override;
     template <typename T> void RunWithType();
......
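Whatever produces the probabilities, the loss itself is the usual softmax cross-entropy: p = softmax(x) along the class axis and L = -sum_c t_c * log(p_c) for a target distribution t. A small single-row sketch using the max-subtraction trick (illustrative only):

#include <algorithm>
#include <cmath>
#include <vector>

// L = -sum_c t[c] * log(softmax(x)[c]), computed in log space for stability.
float SoftmaxCrossEntropy(const std::vector<float>& x,
                          const std::vector<float>& t) {
    float x_max = *std::max_element(x.begin(), x.end());
    float sum = 0.f;
    for (float v : x) sum += std::exp(v - x_max);
    float log_sum = std::log(sum);
    float loss = 0.f;
    for (size_t c = 0; c < x.size(); ++c)
        loss -= t[c] * ((x[c] - x_max) - log_sum);
    return loss;
}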
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
@@ -29,18 +29,12 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
             int* ignore_data = ignore.mutable_data<int, CPUContext>();
             for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
         }
-        OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
-            vector<string>({ Input(0).name() }),
-            vector<string>({ "/mnt/" + Anchor() + "/softmax/prob" }));
-        softmax_def.add_arg()->CopyFrom(this->arg("axis"));
-        if (op_def.has_device_option())
-            softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
-        softmax_op.reset(CreateOperator(softmax_def, ws));
     }
     USE_OPERATOR_FUNCTIONS(Context);
-    void SoftmaxRun();
     void RunOnDevice() override;
-    template <typename T> void RunWithType();
+    template <typename Tx, typename Ty> void RunWithType();

  protected:
     TIndex axis, outer_dim, inner_dim;

@@ -67,7 +61,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
     USE_OPERATOR_FUNCTIONS(Context);
     void RunOnDevice() override;
-    template <typename T> void RunWithType();
+    template <typename Tx, typename Ty> void RunWithType();

  protected:
     TIndex axis, outer_dim, inner_dim;
......
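The second template parameter added above reflects that labels may be stored in a different type than the logits (for example float logits with integer labels). With integer labels and an ignore list, the per-example loss is simply -log(p[label]), skipped when the label is ignored. A sketch under those assumptions (illustrative only):

#include <algorithm>
#include <cmath>
#include <vector>

// -log(softmax(x)[label]); returns 0 when the label is in the ignore list.
float SparseSoftmaxCE(const std::vector<float>& x, int label,
                      const std::vector<int>& ignore) {
    if (std::find(ignore.begin(), ignore.end(), label) != ignore.end())
        return 0.f;
    float x_max = *std::max_element(x.begin(), x.end());
    float sum = 0.f;
    for (float v : x) sum += std::exp(v - x_max);
    return -((x[label] - x_max) - std::log(sum));
}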
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -24,7 +24,6 @@ class GradientGenerateOp final: public Operator<Context> { ...@@ -24,7 +24,6 @@ class GradientGenerateOp final: public Operator<Context> {
defaults(OperatorBase::GetRepeatedArg<float>("defaults")) { defaults(OperatorBase::GetRepeatedArg<float>("defaults")) {
CHECK_EQ(InputSize(), OutputSize()); CHECK_EQ(InputSize(), OutputSize());
CHECK_EQ(defaults.size(), OutputSize()); CHECK_EQ(defaults.size(), OutputSize());
DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
...@@ -42,7 +41,6 @@ class GradientGatherOp final : public Operator<Context> { ...@@ -42,7 +41,6 @@ class GradientGatherOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
for (int i = 0; i < InputSize(); i++) for (int i = 0; i < InputSize(); i++)
if (Input(i).name() != "ignore") indices.push_back(i); if (Input(i).name() != "ignore") indices.push_back(i);
DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
...@@ -56,10 +54,7 @@ class GradientGatherOp final : public Operator<Context> { ...@@ -56,10 +54,7 @@ class GradientGatherOp final : public Operator<Context> {
template <class Context> template <class Context>
class StopGradientOp final : public Operator<Context> { class StopGradientOp final : public Operator<Context> {
public: public:
StopGradientOp(const OperatorDef& op_def, Workspace* ws) USE_SIMPLE_CTOR_DTOR(StopGradientOp);
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -45,9 +45,7 @@ template <class Context> ...@@ -45,9 +45,7 @@ template <class Context>
class TemplateGradientOp : public TemplateOp<Context> { class TemplateGradientOp : public TemplateOp<Context> {
public: public:
TemplateGradientOp(const OperatorDef& op_def, Workspace* ws) TemplateGradientOp(const OperatorDef& op_def, Workspace* ws)
: TemplateOp<Context>(op_def, ws) { : TemplateOp<Context>(op_def, ws) {}
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -34,9 +34,7 @@ template <class Context> ...@@ -34,9 +34,7 @@ template <class Context>
class MPIBroadcastGradientOp final : public ModelMPIBase<Context> { class MPIBroadcastGradientOp final : public ModelMPIBase<Context> {
public: public:
MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws) MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) { : ModelMPIBase<Context>(op_def, ws) {}
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context); USE_MPIMODEL_FUNCTIONS(Context);
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -34,9 +34,7 @@ template <class Context> ...@@ -34,9 +34,7 @@ template <class Context>
class MPIGatherGradientOp final : public ModelMPIBase<Context> { class MPIGatherGradientOp final : public ModelMPIBase<Context> {
public: public:
MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws) MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) { : ModelMPIBase<Context>(op_def, ws) {}
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context); USE_MPIMODEL_FUNCTIONS(Context);
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
@@ -40,9 +40,7 @@ class ConcatGradientOp : public Operator<Context> {
     ConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 1)),
-          nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {
-        DISABLE_SHARE_GRADIENT;
-    }
+          nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
     USE_OPERATOR_FUNCTIONS(Context);
     void RunOnDevice() override;
......
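A concat gradient operator of this kind undoes the concatenation: each of the nin inputs receives the matching slice of the output gradient along axis. A 1-D sketch where the inputs were concatenated along the only axis (illustrative):

#include <vector>

// Split dY back into per-input gradients, given each input's length.
std::vector<std::vector<float>> ConcatGrad(
        const std::vector<float>& dY, const std::vector<int>& sizes) {
    std::vector<std::vector<float>> dXs;
    size_t offset = 0;
    for (int s : sizes) {
        dXs.emplace_back(dY.begin() + offset, dY.begin() + offset + s);
        offset += s;
    }
    return dXs;
}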
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -21,12 +21,13 @@ class CropOp: public Operator<Context> { ...@@ -21,12 +21,13 @@ class CropOp: public Operator<Context> {
public: public:
CropOp(const OperatorDef& op_def, Workspace* ws) CropOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
starts(OperatorBase::GetRepeatedArg<int>("starts")),
ends(OperatorBase::GetRepeatedArg<int>("ends")),
start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)), start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)),
offsets(OperatorBase::GetRepeatedArg<int>("offsets")), offsets(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")), shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {} shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, starts);
GET_ARGUMENTS_WITH_DESC(int, ends);
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -36,25 +37,21 @@ class CropOp: public Operator<Context> { ...@@ -36,25 +37,21 @@ class CropOp: public Operator<Context> {
protected: protected:
TIndex start_axis; TIndex start_axis;
string shape_like; string shape_like;
vector<int> starts, ends, offsets, shape; vector<int> st, ed, offsets, shape, keep_dims;
DECLARE_ARGUMENTS_WITH_DESC(int, starts);
DECLARE_ARGUMENTS_WITH_DESC(int, ends);
vector< pair<int, int> > process_axes; vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim; TIndex axis, inner_dim, dim;
Tensor* dest, *source; Tensor* dest, *source;
}; };
DEFINE_ARGUMENTS_WITH_DESC(int, CropOp, starts);
DEFINE_ARGUMENTS_WITH_DESC(int, CropOp, ends);
template <class Context> template <class Context>
class CropGradientOp final : public Operator<Context > { class CropGradientOp final : public Operator<Context > {
public: public:
CropGradientOp(const OperatorDef& op_def, Workspace* ws) USE_SIMPLE_CTOR_DTOR(CropGradientOp);
: Operator<Context>(op_def, ws),
starts(OperatorBase::GetRepeatedArg<int>("starts")),
ends(OperatorBase::GetRepeatedArg<int>("ends")),
start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)),
offsets(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
...@@ -62,9 +59,7 @@ class CropGradientOp final : public Operator<Context > { ...@@ -62,9 +59,7 @@ class CropGradientOp final : public Operator<Context > {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
TIndex start_axis; vector<int> st, ed, offsets, keep_dims;
string shape_like;
vector<int> starts, ends, offsets, shape;
vector< pair<int, int> > process_axes; vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim; TIndex axis, inner_dim, dim;
Tensor* dest, *source; Tensor* dest, *source;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -33,10 +33,7 @@ class ExpandDimsOp final : public Operator<Context> { ...@@ -33,10 +33,7 @@ class ExpandDimsOp final : public Operator<Context> {
template <class Context> template <class Context>
class ExpandDimsGradientOp final : public Operator<Context> { class ExpandDimsGradientOp final : public Operator<Context> {
public: public:
ExpandDimsGradientOp(const OperatorDef& op_def, Workspace* ws) USE_SIMPLE_CTOR_DTOR(ExpandDimsGradientOp);
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -37,10 +37,7 @@ class FlattenOp final : public Operator<Context> { ...@@ -37,10 +37,7 @@ class FlattenOp final : public Operator<Context> {
template <class Context> template <class Context>
class FlattenGradientOp final : public Operator<Context> { class FlattenGradientOp final : public Operator<Context> {
public: public:
FlattenGradientOp(const OperatorDef& op_def, Workspace* ws) USE_SIMPLE_CTOR_DTOR(FlattenGradientOp);
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -69,7 +69,6 @@ class PadGradientOp final : public Operator<Context> { ...@@ -69,7 +69,6 @@ class PadGradientOp final : public Operator<Context> {
} }
std::sort(process_axes.begin(), process_axes.end()); std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end()); std::reverse(process_axes.begin(), process_axes.end());
DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -40,9 +40,7 @@ class RandomPickGradientOp final : public Operator<Context> { ...@@ -40,9 +40,7 @@ class RandomPickGradientOp final : public Operator<Context> {
public: public:
RandomPickGradientOp(const OperatorDef& op_def, Workspace* ws) RandomPickGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) { axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -37,10 +37,7 @@ class ReshapeOp final : public Operator<Context> { ...@@ -37,10 +37,7 @@ class ReshapeOp final : public Operator<Context> {
template <class Context> template <class Context>
class ReshapeGradientOp final : public Operator<Context> { class ReshapeGradientOp final : public Operator<Context> {
public: public:
ReshapeGradientOp(const OperatorDef& op_def, Workspace* ws) USE_SIMPLE_CTOR_DTOR(ReshapeGradientOp);
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -41,9 +41,7 @@ class SliceGradientOp final : public Operator<Context> { ...@@ -41,9 +41,7 @@ class SliceGradientOp final : public Operator<Context> {
SliceGradientOp(const OperatorDef& op_def, Workspace* ws): SliceGradientOp(const OperatorDef& op_def, Workspace* ws):
Operator<Context>(op_def, ws), Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) { nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override; void RunOnDevice() override;
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -43,7 +43,6 @@ class StackGradientOp : public Operator<Context> { ...@@ -43,7 +43,6 @@ class StackGradientOp : public Operator<Context> {
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {} nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -40,7 +40,6 @@ class TileGradientOp : public Operator<Context> { ...@@ -40,7 +40,6 @@ class TileGradientOp : public Operator<Context> {
TileGradientOp(const OperatorDef& op_def, Workspace* ws) TileGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples); GET_ARGUMENTS_WITH_DESC(int, multiples);
DISABLE_SHARE_GRADIENT;
} }
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -118,8 +118,6 @@ class FusedBatchNormGradientOp : public Operator<Context> { ...@@ -118,8 +118,6 @@ class FusedBatchNormGradientOp : public Operator<Context> {
void Setup(); void Setup();
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void TrainingRunWithType(); template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType(); template <typename T> void InferenceRunWithType();
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
...@@ -23,10 +23,7 @@ class GroupNormOp : public Operator<Context> { ...@@ -23,10 +23,7 @@ class GroupNormOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)), group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))) {
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
...@@ -36,18 +33,15 @@ class GroupNormOp : public Operator<Context> { ...@@ -36,18 +33,15 @@ class GroupNormOp : public Operator<Context> {
void Setup(); void Setup();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void TrainingRunWithType(); template <typename T> void RunWithType();
template <typename T> void InferenceRunWithType();
protected: protected:
float momentum, eps; float eps;
Tensor mean, num_by_chans; Tensor mean, num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier; Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var; Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS; TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format, mode; string data_format;
int use_stats;
bool use_global_stats, is_recomputing;
}; };
template <class Context> template <class Context>
...@@ -56,8 +50,7 @@ class GroupNormGradientOp final : public Operator<Context> { ...@@ -56,8 +50,7 @@ class GroupNormGradientOp final : public Operator<Context> {
GroupNormGradientOp(const OperatorDef& op_def, Workspace *ws) GroupNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)), group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {
if (axis != -1) if (axis != -1)
CHECK_EQ(axis, 1) CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1."; << "\nThe axis can only be set to 1.";
...@@ -67,8 +60,7 @@ class GroupNormGradientOp final : public Operator<Context> { ...@@ -67,8 +60,7 @@ class GroupNormGradientOp final : public Operator<Context> {
void Setup(); void Setup();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void TrainingRunWithType(); template <typename T> void RunWithType();
template <typename T> void InferenceRunWithType();
protected: protected:
Tensor num_by_chans; Tensor num_by_chans;
...@@ -76,8 +68,6 @@ class GroupNormGradientOp final : public Operator<Context> { ...@@ -76,8 +68,6 @@ class GroupNormGradientOp final : public Operator<Context> {
Tensor* stddev, *var; Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS; TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format; string data_format;
int use_stats;
bool use_global_stats;
}; };
template <class Context> template <class Context>
...@@ -87,26 +77,21 @@ class FusedGroupNormOp : public Operator<Context> { ...@@ -87,26 +77,21 @@ class FusedGroupNormOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)), group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))) {}
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void TrainingRunWithType(); template <typename T> void RunWithType();
template <typename T> void InferenceRunWithType();
protected: protected:
float momentum, eps; float eps;
Tensor num_by_chans; Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier; Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm; Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS; TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format; string data_format;
int use_stats;
bool use_global_stats, is_recomputing;
}; };
template <class Context> template <class Context>
...@@ -116,17 +101,13 @@ class FusedGroupNormGradientOp : public Operator<Context> { ...@@ -116,17 +101,13 @@ class FusedGroupNormGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)), group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))) {}
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS(Context); USE_OPERATOR_FUNCTIONS(Context);
void Setup(); void Setup();
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void TrainingRunWithType(); template <typename T> void RunWithType();
template <typename T> void InferenceRunWithType();
protected: protected:
float eps; float eps;
...@@ -135,8 +116,6 @@ class FusedGroupNormGradientOp : public Operator<Context> { ...@@ -135,8 +116,6 @@ class FusedGroupNormGradientOp : public Operator<Context> {
Tensor* mean, *var, *stddev, *x_norm; Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS; TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format; string data_format;
int use_stats;
bool use_global_stats;
}; };
} // namespace dragon } // namespace dragon
......
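With use_stats and momentum gone, the operator above reduces to plain group normalization: split the C channels into group groups, compute mean and variance per (sample, group) over that group's channels and spatial positions, and normalize with eps. A compact NCHW sketch, assuming C is divisible by group and omitting the scale/shift terms (not the operator's actual kernel):

#include <cmath>
#include <vector>

// y = (x - mean_g) / sqrt(var_g + eps), statistics taken per (n, group)
// over that group's channels and all spatial positions; NCHW layout.
void GroupNorm(const std::vector<float>& x, int N, int C, int H, int W,
               int group, float eps, std::vector<float>* y) {
    y->resize(x.size());
    int cpg = C / group, S = H * W, gsize = cpg * S;
    for (int n = 0; n < N; ++n)
        for (int g = 0; g < group; ++g) {
            int base = (n * C + g * cpg) * S;  // group block is contiguous
            float mean = 0.f, var = 0.f;
            for (int i = 0; i < gsize; ++i) mean += x[base + i];
            mean /= gsize;
            for (int i = 0; i < gsize; ++i) {
                float d = x[base + i] - mean;
                var += d * d;
            }
            var /= gsize;
            float inv_std = 1.f / std::sqrt(var + eps);
            for (int i = 0; i < gsize; ++i)
                (*y)[base + i] = (x[base + i] - mean) * inv_std;
        }
}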
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------ // ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd. // Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
// //
// Licensed under the BSD 2-Clause License. // Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License // You should have received a copy of the BSD 2-Clause License
@@ -37,9 +37,7 @@ template <class Context>
class LSTMUnitGradientOp : public Operator<Context> {
 public:
    LSTMUnitGradientOp(const OperatorDef& op_def, Workspace* ws)
-       : Operator<Context>(op_def, ws) {
-       this->allow_share_grads_ = false;
-   }
+       : Operator<Context>(op_def, ws) {}
    USE_OPERATOR_FUNCTIONS(Context);
    void RunOnDevice() override;
...
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
@@ -20,20 +20,14 @@ template <class Context>
class AdamUpdateOp final : public UpdateOpBase<Context> {
 public:
    AdamUpdateOp(const OperatorDef& op_def, Workspace* ws)
-       : UpdateOpBase<Context>(op_def, ws),
-         t(0),
-         eps(Param("eps")),
-         beta1(Param("beta1")),
-         beta2(Param("beta2")) {}
+       : UpdateOpBase<Context>(op_def, ws), t(0) {}
    USE_OPERATOR_FUNCTIONS(Context);
    USE_UPDATER_FUNCTIONS(Context);
    void ComputeRunWithFloat() override;
 protected:
-   float lr, beta1, beta2, eps, coeff;
-   int t;
-   Tensor* m, *v, *tmp;
+   int t;
+   float lr, beta1, beta2, eps;
};
} // namespace dragon
...
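Dropping eps(Param("eps")) and the beta reads from the constructor suggests the hyper-parameters are now fetched when ComputeRunWithFloat() executes, so they can change between steps; only the step counter t remains a member. For reference, the bias-corrected Adam step, whose correction factor is most likely what the removed coeff member cached (standalone sketch, not the framework's kernel):

```cpp
#include <cmath>
#include <vector>

// Standalone Adam step for illustration; lr/beta1/beta2/eps mirror the members
// in the hunk above, m/v are the per-parameter moment buffers, t starts at 1.
void AdamStep(std::vector<float>& w, const std::vector<float>& g,
              std::vector<float>& m, std::vector<float>& v,
              int t, float lr, float beta1, float beta2, float eps) {
    // Bias-correction factor; plausibly the role the removed "coeff" played.
    const float coeff = std::sqrt(1.f - std::pow(beta2, (float)t)) /
                        (1.f - std::pow(beta1, (float)t));
    for (size_t i = 0; i < w.size(); ++i) {
        m[i] = beta1 * m[i] + (1.f - beta1) * g[i];
        v[i] = beta2 * v[i] + (1.f - beta2) * g[i] * g[i];
        w[i] -= lr * coeff * m[i] / (std::sqrt(v[i]) + eps);
    }
}
```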
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
...
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
...
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
@@ -20,8 +20,7 @@ template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> {
 public:
    NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws)
-       : UpdateOpBase<Context>(op_def, ws),
-         momentum(Param("momentum")) {}
+       : UpdateOpBase<Context>(op_def, ws) {}
    USE_OPERATOR_FUNCTIONS(Context);
    USE_UPDATER_FUNCTIONS(Context);
@@ -29,7 +28,6 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> {
 protected:
    float lr, momentum;
-   Tensor* h, *tmp;
};
} // namespace dragon
...
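The same pattern applies here: momentum is no longer cached in the constructor, and the history tensor h (plus tmp) is no longer a member, presumably being pulled from the workspace at run time. For reference, a standalone Caffe-style Nesterov momentum step:

```cpp
#include <vector>

// Standalone Nesterov-momentum step for illustration; h is the per-parameter
// history buffer that the removed Tensor* h member pointed at.
void NesterovStep(std::vector<float>& w, const std::vector<float>& g,
                  std::vector<float>& h, float lr, float momentum) {
    for (size_t i = 0; i < w.size(); ++i) {
        const float h_prev = h[i];
        h[i] = momentum * h[i] + lr * g[i];
        // Look-ahead update: (1 + momentum) * h_new - momentum * h_old.
        w[i] -= (1.f + momentum) * h[i] - momentum * h_prev;
    }
}
```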
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
@@ -20,9 +20,7 @@ template <class Context>
class RMSPropUpdateOp final : public UpdateOpBase<Context> {
 public:
    RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws)
-       : UpdateOpBase<Context>(op_def, ws),
-         eps(Param("eps")),
-         decay(Param("decay")) {}
+       : UpdateOpBase<Context>(op_def, ws) {}
    USE_OPERATOR_FUNCTIONS(Context);
    USE_UPDATER_FUNCTIONS(Context);
@@ -30,7 +28,6 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
 protected:
    float lr, decay, eps;
-   Tensor* h, *tmp;
};
} // namespace dragon
...
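Likewise for RMSProp; a standalone sketch of the step that the decay/eps hyper-parameters drive:

```cpp
#include <cmath>
#include <vector>

// Standalone RMSProp step for illustration; h is the running average of
// squared gradients that the removed history member pointed at.
void RMSPropStep(std::vector<float>& w, const std::vector<float>& g,
                 std::vector<float>& h, float lr, float decay, float eps) {
    for (size_t i = 0; i < w.size(); ++i) {
        h[i] = decay * h[i] + (1.f - decay) * g[i] * g[i];
        w[i] -= lr * g[i] / (std::sqrt(h[i]) + eps);
    }
}
```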
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
@@ -20,8 +20,7 @@ template <class Context>
class SGDUpdateOp final : public UpdateOpBase<Context> {
 public:
    SGDUpdateOp(const OperatorDef& op_def, Workspace* ws)
-       : UpdateOpBase<Context>(op_def, ws),
-         momentum(Param("momentum")) {}
+       : UpdateOpBase<Context>(op_def, ws) {}
    USE_OPERATOR_FUNCTIONS(Context);
    USE_UPDATER_FUNCTIONS(Context);
@@ -29,8 +28,6 @@ class SGDUpdateOp final : public UpdateOpBase<Context> {
 protected:
    float lr, momentum;
-   Tensor* h;
};
} // namespace dragon
...
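And for plain SGD with momentum, where the removed Tensor* h held the velocity buffer; a standalone sketch of the step:

```cpp
#include <vector>

// Standalone SGD-with-momentum step for illustration; h is the velocity
// buffer formerly held as a member of the op.
void SGDMomentumStep(std::vector<float>& w, const std::vector<float>& g,
                     std::vector<float>& h, float lr, float momentum) {
    for (size_t i = 0; i < w.size(); ++i) {
        h[i] = momentum * h[i] + lr * g[i];
        w[i] -= h[i];
    }
}
```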
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
@@ -23,7 +23,10 @@ class UpdateOpBase : public Operator<Context> {
        : Operator<Context>(op_def, ws),
          lr_mult(OperatorBase::GetSingleArg<float>("lr_mult", 1.0)),
          decay_mult(OperatorBase::GetSingleArg<float>("decay_mult", 1.0)),
-         domain(OperatorBase::GetSingleArg<string>("domain", "_")) {}
+         slot(OperatorBase::GetSingleArg<string>("slot", "")),
+         zero_grad(OperatorBase::GetSingleArg<bool>("zero_grad", true)) {
+       CHECK(!slot.empty()) << "\nRequired a non-empty slot";
+   }
    USE_OPERATOR_FUNCTIONS(Context);
    float Param(const string& name) const;
@@ -37,7 +40,8 @@ class UpdateOpBase : public Operator<Context> {
 protected:
    float lr_mult, decay_mult;
    float l2_decay, clip_thresh, scale_factor;
-   string domain;
+   string slot;
+   bool zero_grad;
};
#define USE_UPDATER_FUNCTIONS(context) \
...
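The required slot argument replaces domain, and zero_grad controls whether gradients are cleared after the update. One plausible use of slot is to namespace each updater's mutable state (history buffers, shared hyper-parameters) in the workspace so that several updaters can coexist; the sketch below only illustrates that keying idea with a hypothetical helper, not Dragon's actual lookup scheme:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical illustration of slot-based namespacing; MakeSlotKey and the
// key layout are assumptions for this sketch, not Dragon's real naming scheme.
std::string MakeSlotKey(const std::string& slot,
                        const std::string& param,
                        const std::string& what) {
    // Mirrors the CHECK(!slot.empty()) added in the hunk above.
    if (slot.empty()) throw std::invalid_argument("Required a non-empty slot");
    return "/" + slot + "/" + param + "/" + what;
}

// e.g. MakeSlotKey("sgd", "conv1/weight", "history")
//   -> "/sgd/conv1/weight/history"
```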
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
...
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
@@ -95,6 +95,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
    }
    void RunOnDevice() override;
+   template <typename T> void ResetDesc();
    template <typename T> void RunWithType();
 protected:
@@ -107,6 +108,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
    cudnnFilterDescriptor_t filter_desc;
    size_t workspace_fwd_data_size;
    TIndex bias_offset, cudnn_group;
+   vector<TIndex> input_dims;
};
template <class Context>
@@ -151,6 +153,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
    }
    void RunOnDevice() override;
+   template <typename T> void ResetDesc();
    template <typename T> void RunWithType();
 protected:
@@ -164,6 +167,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
    cudnnFilterDescriptor_t filter_desc;
    size_t workspace_bwd_filter_size, workspace_bwd_data_size;
    TIndex bias_offset, cudnn_group;
+   vector<TIndex> input_dims;
};
#endif // WITH_CUDNN
...
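Both the forward and gradient ops gain a ResetDesc<T>() and a cached input_dims. The usual purpose of such a pair is to rebuild the cuDNN descriptors (and re-query workspace sizes) only when the input shape actually changes, rather than on every run. A minimal shape-caching sketch under that assumption (SetupDescriptors is a hypothetical placeholder, not a cuDNN call):

```cpp
#include <cstdint>
#include <vector>

using TIndex = int64_t;

// Hypothetical sketch of the "reset descriptors only on shape change" pattern
// suggested by the cached input_dims member added in the hunks above.
struct ConvDescCache {
    std::vector<TIndex> input_dims;  // last shape the descriptors were built for

    // Placeholder for the expensive work: descriptor setup, algorithm
    // selection, workspace-size queries, etc.
    void SetupDescriptors(const std::vector<TIndex>& dims) { input_dims = dims; }

    // Call before each run; rebuilds descriptors only if the shape changed.
    void ResetDescIfNeeded(const std::vector<TIndex>& dims) {
        if (dims != input_dims) SetupDescriptors(dims);
    }
};
```

The transposed-convolution ops below follow the same pattern.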
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
...
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
@@ -98,6 +98,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
    }
    void RunOnDevice() override;
+   template <typename T> void ResetDesc();
    template <typename T> void RunWithType();
 protected:
@@ -110,6 +111,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
    cudnnFilterDescriptor_t filter_desc;
    size_t workspace_fwd_data_size;
    TIndex bias_offset, cudnn_group;
+   vector<TIndex> input_dims;
};
template <class Context>
@@ -154,6 +156,7 @@ public:
    }
    void RunOnDevice() override;
+   template <typename T> void ResetDesc();
    template <typename T> void RunWithType();
 protected:
@@ -167,6 +170,7 @@ public:
    cudnnFilterDescriptor_t filter_desc;
    size_t workspace_bwd_filter_size, workspace_bwd_data_size;
    TIndex bias_offset, cudnn_group;
+   vector<TIndex> input_dims;
};
#endif // WITH_CUDNN
...
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
@@ -39,7 +39,6 @@ class DenseConcatGradientOp : public ConcatGradientOp<Context> {
    TIndex growth_rate;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_
\ No newline at end of file
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
...
// ------------------------------------------------------------
-// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
+// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
...