Commit 176b7bbb by Ting PAN

Dragon 0.2.2 Preview

1 parent 310bcb5f
Showing with 403 additions and 367 deletions
# DragonDocker
**Note: use [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) to run all GPU builds.**
This folder lists some basic Dockerfiles.
You can modify them in order to be compatible with your environment.
For the built official images, See https://hub.docker.com/r/seetaresearch/dragon
\ No newline at end of file
FROM ubuntu:16.04
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
wget \
unzip \
ssh \
vim \
libprotobuf-dev \
protobuf-compiler \
libopenblas-dev \
python3-pip \
python3-dev \
python3-pyqt4 \
python3-tk \
&& rm -rf /var/lib/apt/lists/*
RUN pip3 install --no-cache-dir --upgrade pip setuptools wheel && \
pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
numpy \
protobuf \
lmdb \
opencv-python \
six \
Pillow
matplotlib \
pyyaml
RUN wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cpu-openblas/3rdparty.zip && \
unzip 3rdparty.zip && rm 3rdparty.zip && cd 3rdparty && bash ./setup_mpi.sh && rm -f *.gz && \
cd openmpi && ls | grep -v install | xargs rm -r && cp install/bin/mpirun /usr/bin
RUN git clone https://github.com/seetaresearch/Dragon.git && \
cd Dragon/Dragon && rm CMakeLists.txt && \
wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cpu-openblas/CMakeLists.txt && \
mkdir build && cd build && cmake .. && make install -j8 && cd .. && rm -rf build && \
cd python && python3 setup.py install
\ No newline at end of file
FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
wget \
unzip \
ssh \
vim \
libprotobuf-dev \
protobuf-compiler \
libopenblas-dev \
libnccl2 \
libnccl-dev \
python3-pip \
python3-dev \
python3-pyqt4 \
python3-tk \
&& rm -rf /var/lib/apt/lists/*
RUN pip3 install --no-cache-dir --upgrade pip setuptools wheel && \
pip3 install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
numpy \
protobuf \
lmdb \
opencv-python \
six \
Pillow \
matplotlib \
pyyaml
RUN wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cuda9.0-cudnn7/3rdparty.zip && \
unzip 3rdparty.zip && rm 3rdparty.zip && cd 3rdparty && bash ./setup_mpi.sh && rm -f *.gz && \
cd openmpi && ls | grep -v install | xargs rm -r && cp install/bin/mpirun /usr/bin
RUN git clone https://github.com/seetaresearch/Dragon.git && \
cd Dragon/Dragon && rm CMakeLists.txt && \
wget http://dragon.seetatech.com/download/docker/ubuntu-16.04-cuda9.0-cudnn7/CMakeLists.txt && \
mkdir build && cd build && cmake .. && make install -j8 && cd .. && rm -rf build && \
cd python && python3 setup.py install
\ No newline at end of file
......@@ -30,13 +30,15 @@ set(3RDPARTY_DIR ${PROJECT_SOURCE_DIR}/../3rdparty)
# set(PYTHON_EXECUTABLE X:/Anaconda/python) # Win, Anaconda
# Set CUDA compiling architecture
set(CUDA_ARCH -gencode arch=compute_30,code=sm_30
-gencode arch=compute_35,code=sm_35
-gencode arch=compute_50,code=sm_50
-gencode arch=compute_60,code=sm_60
-gencode arch=compute_70,code=sm_70)
# Remove "compute_70/sm_70" if using CUDA 8.0
set(CUDA_ARCH -gencode arch=compute_30,code=sm_30
-gencode arch=compute_35,code=sm_35
-gencode arch=compute_50,code=sm_50
-gencode arch=compute_60,code=sm_60
-gencode arch=compute_70,code=sm_70)
# Set CUDNN Library Dir if necessary (Linux Only)
# For Win, Recommend to use ``3RDPARTY_DIR/lib``
set(CUDNN_LIBRARY_DIR /usr/local/cuda/lib64)
# ---------------- User Config ----------------
......@@ -97,7 +99,7 @@ if (WITH_MPI)
include_directories(${3RDPARTY_DIR}/include/mpi)
endif()
# ---[ libs
# ---[ Lib Directories
set(3RDPARTY_LIBS ${3RDPARTY_DIR}/lib)
link_directories(${3RDPARTY_LIBS})
if (WITH_CUDNN)
......@@ -108,7 +110,7 @@ endif()
set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix" FORCE)
set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${3RDPARTY_LIBS})
# ---[ defines
# ---[ Defines
if (WITH_PYTHON)
ADD_DEFINITIONS(-DWITH_PYTHON)
if (${PYTHON_VERSION_MAJOR} STREQUAL "2")
......@@ -184,7 +186,7 @@ endif()
# ---[ Warnings
# ---[ execute
# ---[ Commands
set (PROTOS_DIR ${PROJECT_SOURCE_DIR}/src/protos)
message(STATUS "Generate Protobuf Files")
execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS_DIR}/caffemodel.proto)
......@@ -192,7 +194,7 @@ execute_process(COMMAND protoc -I=${PROTOS_DIR} --cpp_out=${PROTOS_DIR} ${PROTOS
# ---[ Subdirectories
add_subdirectory(modules/python)
#add_subdirectory(modules/cc) # Compile CC module if necessary
#add_subdirectory(modules/cxx) # Compile CXX module if necessary
# ---[ Utils
file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/../lib)
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -15,8 +15,7 @@
#include <random>
#include <ctime>
#include "common.h"
#include "utils/logging.h"
#include "core/common.h"
#ifdef WITH_CUDA
#include "utils/cuda_device.h"
......@@ -38,7 +37,7 @@ class CPUContext {
virtual ~CPUContext() {}
inline void SwitchToDevice() {}
inline void FinishDeviceCompution() { return; }
inline void static FinishDeviceCompution() { return; }
inline static void* New(size_t nbytes) {
void* data;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -12,8 +12,8 @@
#ifndef DRAGON_CORE_CONTEXT_CUDA_H_
#define DRAGON_CORE_CONTEXT_CUDA_H_
#include "common.h"
#include "context.h"
#include "core/common.h"
#include "core/context.h"
#include "utils/cuda_device.h"
#include "utils/cudnn_device.h"
......@@ -21,8 +21,6 @@ namespace dragon {
#ifdef WITH_CUDA
#define MAX_GPUS 8
/**************************************************************************
* cuXXX libraries wrapper "Context" as "Handle".
* It's well known that each "Context" binds to some "Devices" in OpenCL.
......@@ -66,7 +64,7 @@ class CUDAObject {
class CUDAContext {
public:
CUDAContext(const DeviceOption& option)
: gpu_id_(option.gpu_id()),
: gpu_id_(option.device_id()),
random_seed_(option.has_random_seed() ? option.random_seed() : 3) {
CPUContext context(option);
CHECK_EQ(option.device_type(), CUDA);
......@@ -92,11 +90,10 @@ class CUDAContext {
cuda_object_.cur_gpu = gpu_id_;
}
void FinishDeviceCompution() {
inline static void FinishDeviceCompution() {
cudaStreamSynchronize(cudaStreamDefault);
cudaError_t error = cudaGetLastError();
CHECK_EQ(error, cudaSuccess)
<< "CUDA Error: " << cudaGetErrorString(error);
CHECK_EQ(error, cudaSuccess) << "CUDA Error: " << cudaGetErrorString(error);
}
inline static void* New(size_t nbytes) {
......@@ -210,4 +207,4 @@ class CUDAContext {
} // namespace dragon
#endif // DRAGON_CORE_CONTEXT_CUDA_H_
#endif // DRAGON_CORE_CONTEXT_CUDA_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -51,6 +51,8 @@ class Graph final : public GraphBase {
GraphDef Prune(const GraphDef& meta_graph);
GraphDef MakeUpdate(const GraphDef& meta_graph);
GraphDef Share(const GraphDef& optimized_graph);
void ShareGrads(GraphDef& optimized_graph);
void RecomputingAware(const GraphDef& optimized_graph, Workspace* ws);
inline Workspace* ws() const { return ws_; }
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -16,31 +16,30 @@
namespace dragon {
typedef pair<bool, vector<pair<string, int> > > CheckTuple;
class GraphGradientMaker {
public:
GraphGradientMaker(const GraphDef& forward_def,
const vector<string>& targets)
: cur_op_idx_(0),
forward_def_(forward_def) {
for (auto& target : targets) targets_set_.insert(target);
}
GraphGradientMaker() : cur_op_idx_(0) {}
void Make(const GraphDef& forward_def,
const vector<string>& targets,
GraphDef& new_def);
GraphDef Make();
void Share(const string& grads_prefix, GraphDef& graph);
inline void SetTerms(const Map<string, string>& terms) { terms_ = terms; }
inline void SetOperatorPrefix(const string& prefix) { op_prefix_ = prefix; }
inline void SetOperatorSuffix(const string& suffix) { op_suffix_ = suffix; }
inline void AddExternalGrad(const string& name) { external_grads_.insert(name); }
inline void AddIgnoreGrad(const string& name) { ignore_grads_.insert(name); }
private:
CheckTuple CheckMissingGrad(OperatorDef* forward_op);
bool CheckGrad(const OperatorDef& forward_op,
const Set<string>& targets,
vector< pair<string, int> >& gen_grads);
string GetOperatorName();
GraphDef forward_def_, new_def_;
Map<string, string> terms_, inputs_to_grads_;
Set<string> targets_set_, blacklist_set_, external_grads_;
Set<string> blacklist_set_, external_grads_, ignore_grads_;
string op_prefix_, op_suffix_;
int cur_op_idx_;
};
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -12,7 +12,6 @@
#ifndef DRAGON_CORE_MIXEDMEM_H_
#define DRAGON_CORE_MIXEDMEM_H_
#include "typeid.h"
#include "context.h"
#include "context_cuda.h"
......@@ -21,43 +20,43 @@ namespace dragon {
class MixedMemory {
public:
enum State { UNINITIALIZED, STATE_AT_CPU, STATE_AT_CUDA, SWITCHED, SYNCED };
MixedMemory()
: state_(UNINITIALIZED),
cpu_ptr_(nullptr), cuda_ptr_(nullptr),
nbytes_(0) {}
MixedMemory() : cpu_ptr_(nullptr), cuda_ptr_(nullptr) {}
MixedMemory(const TypeMeta& meta, const size_t nbytes)
: state_(UNINITIALIZED), meta_(meta),
cpu_ptr_(nullptr), cuda_ptr_(nullptr),
nbytes_(nbytes) {}
: meta_(meta), nbytes_(nbytes),
cpu_ptr_(nullptr), cuda_ptr_(nullptr) {}
~MixedMemory();
const void* cpu_data();
const void* cuda_data();
void* mutable_cpu_data();
void* mutable_cuda_data();
void set_cpu_data(void* cpu_ptr, size_t nbytes);
#ifdef WITH_CUDA
void async_cuda_data(const cudaStream_t& stream);
#endif
void SwitchToDevice();
void SwitchToCUDADevice(int device_id);
inline size_t nbytes() const { return nbytes_; }
inline void* cpu_ptr() { state_ = STATE_AT_CPU; return cpu_ptr_; }
inline void* cuda_ptr() { state_ = STATE_AT_CUDA; return cuda_ptr_; }
inline State state() { return state_; }
inline State state() const { return state_; }
const Map<string, string> info() const;
private:
void ToCUDA();
void ToCPU();
private:
void* cpu_ptr_, *cuda_ptr_;
State state_;
size_t nbytes_;
bool own_cpu_ptr_ = true;
State state_ = UNINITIALIZED;
size_t nbytes_ = 0;
TypeMeta meta_;
};
} // namespace dragon
#endif
#endif
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -38,12 +38,18 @@ class OperatorBase {
inline size_t InputSize() { return inputs_.size(); }
inline size_t OutputSize() { return outputs_.size(); }
void MutableOp(const OperatorDef& op_def);
void MutableOp(const vector<string>& inputs,
const vector<string>& outputs,
const string& anchor);
inline void SwitchToPhase(const string& phase) { this->phase_ = phase; }
virtual void Run() { NOT_IMPLEMENTED; }
inline const string& name() const { return op_def_.name(); }
inline const string& type() const { return op_def_.type(); }
inline const string& phase() const { return phase_; }
inline const string& anchor() { return anchor_; }
inline Workspace* ws() const { return ws_; }
template <typename T>
......@@ -60,10 +66,11 @@ class OperatorBase {
void set_recompute_map(RecomputeMap recompute_map) { recompute_map_ = recompute_map; }
inline const OperatorDef& op_def() const { return op_def_; }
inline const string DebugString() const { return op_def_.DebugString(); }
inline string DebugString() const { return op_def_.DebugString(); }
string DTypeHelper(const Tensor& tensor, const Set<string>& dtypes) const;
protected:
string phase_;
string phase_, anchor_;
Map<std::string, const Argument*> args_;
Map<string, vector<OperatorBase*> > recompute_map_;
vector<Tensor*> inputs_, outputs_;
......@@ -75,47 +82,41 @@ template <class Context>
class Operator : public OperatorBase {
public:
Operator(const OperatorDef& op_def, Workspace* ws)
: OperatorBase(op_def, ws), ctx_(op_def.device_option()) {
: OperatorBase(op_def, ws), ctx_(op_def.device_option()),
do_synchronize_(Operator::GetSingleArg<bool>("do_synchronize", false)),
recomputing_aware_(Operator::GetSingleArg<bool>("recomputing_aware", false)) {
allow_run_ = true;
allow_run_ &= _MPICheck();
allow_run_ &= (!(OutputSize() == 1 && Output(0)->name() == "ignore"));
allow_share_grads_ = (!op_def.debug_mode());
allow_share_grads_ &= op_def.share_grads();
allow_share_grads_ &= (type().find("Gradient") != string::npos);
}
virtual void Run() final {
if (!allow_run_) return;
MakeResource();
if (recomputing_aware_) MakeResource();
ctx_.SwitchToDevice();
MemorySwitch();
RunOnDevice();
ctx_.FinishDeviceCompution();
CleanResource();
if (do_synchronize_) ctx_.FinishDeviceCompution();
if (recomputing_aware_) CleanResource();
}
virtual void ElimateCorruption();
virtual void ShareGradient();
virtual void MakeResource();
virtual void CleanResource();
void MemorySwitch() {
for (int i = 0; i < InputSize(); i++)
if (Input(i).name() != "ignore") Input(i).SwitchToDevice();
for (int i = 0; i < OutputSize(); i++)
if (Output(i)->name() != "ignore") Output(i)->SwitchToDevice();
for (auto* I : inputs_) if(I->name() != "ignore") I->SwitchToDevice();
for (auto* O : outputs_) if(O->name() != "ignore") O->SwitchToDevice();
}
virtual void RunOnDevice() = 0;
inline Context& ctx() { return ctx_; }
inline string Anchor() { return GetSingleArg("anchor", name()); }
inline bool AllowRun() { return allow_run_; }
protected:
Context ctx_;
bool allow_run_, allow_share_grads_;
bool allow_run_, recomputing_aware_, do_synchronize_;
private:
bool _MPICheck() {
......@@ -147,15 +148,16 @@ OperatorBase* CreateOperator(const OperatorDef& op_def, Workspace* ws);
using OperatorBase::name; \
using OperatorBase::type; \
using OperatorBase::phase; \
using OperatorBase::anchor; \
using OperatorBase::op_def; \
using OperatorBase::InputSize; \
using OperatorBase::OutputSize; \
using OperatorBase::DebugString \
using OperatorBase::DebugString; \
using OperatorBase::DTypeHelper \
#define USE_OPERATOR_FUNCTIONS(context) \
USE_OPERATOR_BASE_FUNCTIONS; \
using Operator<context>::ctx; \
using Operator<context>::Anchor; \
using Operator<context>::AllowRun
DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*);
......@@ -238,8 +240,11 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
return argument##_tensor->template data<type, CPUContext>()[0]; \
}
#define DISABLE_SHARE_GRADIENT \
this->allow_share_grads_ = false
#define GET_ARGUMENTS_SIZE(argument) \
std::max(argument##_value.size(), argument##_desc.size())
#define XIsType(x, dtype) \
x.template IsType<dtype>()
#define INSTANTIATE_OPERATOR(name, context) \
template class name##Op<context>;
......@@ -272,6 +277,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
#define DEPLOY_CUDNN(name) \
REGISTER_CUDNN_OPERATOR(name, CuDNN##name##Op<CUDAContext>); \
INSTANTIATE_CUDNN_OPERATOR(name);
} // namespace dragon
#endif // DRAGON_CORE_OPERATOR_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -15,7 +15,6 @@
#include <functional>
#include "core/common.h"
#include "utils/logging.h"
namespace dragon {
......@@ -24,11 +23,11 @@ class Registry {
public:
typedef std::function<ObjType*(Args ...)> Creator;
void Register(const SrcType& key, Creator creator) {
CHECK(!registry_.count(key)) << "Key(" << key << ") has already registered.";
CHECK(!registry_.count(key)) << "\nKey(" << key << ") has already registered.";
registry_[key] = creator;
}
ObjType* Create(const SrcType& key, Args ... args) {
CHECK(registry_.count(key)) << "Key(" << key << ") has not registered yet.";
CHECK(registry_.count(key)) << "\nKey(" << key << ") has not registered yet.";
return registry_[key](args...);
}
bool Has(const SrcType& key) { return (registry_.count(key)) != 0; }
......@@ -80,4 +79,4 @@ public:
} // namepsace dragon
#endif //DRAGON_CORE_REGISTRY_H_
#endif //DRAGON_CORE_REGISTRY_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -15,7 +15,6 @@
#include <vector>
#include "core/common.h"
#include "core/typeid.h"
#include "core/mixedmem.h"
namespace dragon {
......@@ -68,6 +67,7 @@ class Tensor {
inline const vector<TIndex>& dims() const { return dims_; }
inline TSize nbytes() const { return size_ * meta_.itemsize(); }
inline TSize capacity() const { return capacity_; }
inline TIndex count(const TIndex start, const TIndex end) const {
TIndex ret = 1;
......@@ -97,17 +97,21 @@ class Tensor {
}
inline string dim_string() const {
if (ndim() == 0) return "(0,)";
std::stringstream ss;
ss << "(";
for (int i = 0; i < ndim() - 1; i++) ss << dim(i) << ",";
ss << dim(ndim() - 1) << ")";
if (ndim() == 1) ss << dim(0) << ",)";
else ss << dim(ndim() - 1) << ")";
return ss.str();
}
inline bool is_corrupted() const { return is_corrupted_; }
inline void Corrupt() { is_corrupted_ = true; }
inline bool has_memory() const { return memory_ || ex_memory_ != nullptr; }
MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; }
void set_memory(MixedMemory* mem) { memory_.reset(mem); capacity_ = mem->nbytes(); }
MixedMemory::State memory_state() const {
MixedMemory* mem = memory();
CHECK(mem) << "\nMemory access before allowcating.";
......@@ -142,13 +146,13 @@ class Tensor {
template <class Context>
const void* const_data_ptr() const {
MixedMemory* mem = memory();
CHECK(mem) << "memory access before allowcating.";
CHECK(mem) << "\nMemory access before allowcating.";
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) {
return mem->cpu_data();
} else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) {
return mem->cuda_data();
} else {
LOG(FATAL) << "unknown memory type access. only CPU or CUDA are supported.";
LOG(FATAL) << "Unknown memory type. Only CPU or CUDA are supported.";
return nullptr;
}
}
......@@ -192,14 +196,25 @@ class Tensor {
template <typename T, class Context>
const T* data() const {
CHECK(meta_ == TypeMeta::Make<T>())
<< "\nThe DType of Tensor(" << name() << ") is "
<< TypeMetaToString(meta_) << ", while required "
<< TypeMetaToString(TypeMeta::Make<T>());
return static_cast<const T*>(raw_data<Context>());
}
inline void Share(const Tensor& other) {
template <class DstCTX, class SrcCTX>
inline void Copy(const Tensor& other) {
CHECK_EQ(size_, other.size_);
memory_ = other.memory_;
meta_ = other.meta_;
capacity_ = other.capacity_;
auto* src = other.template raw_data<SrcCTX>();
auto* dst = raw_mutable_data<DstCTX>();
if (dst == src) return;
if (TypeMeta::Id<DstCTX>() == TypeMeta::Id<CPUContext>()) {
CPUContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src);
} else if (TypeMeta::Id<DstCTX>() == TypeMeta::Id<CUDAContext>()) {
CUDAContext::Memcpy<DstCTX, SrcCTX>(nbytes(), dst, src);
}
}
inline void Move(MixedMemory* mem) {
......@@ -213,18 +228,22 @@ class Tensor {
meta_ = TypeMeta();
dims_.clear();
memory_.reset();
if (DECREFPyArray) DECREFPyArray();
}
std::function<void()> DECREFPyArray;
~Tensor() { /* DO NOT CALL DECREFARRAY */ }
private:
vector<TIndex> dims_;
TIndex size_ = 0, capacity_ = 0;
TypeMeta meta_;
string name_;
shared_ptr<MixedMemory> memory_, host_memory_;
shared_ptr<MixedMemory> memory_;
MixedMemory* ex_memory_ = nullptr;
bool is_corrupted_ = false, own_mem_ = true;
};
} // namespace dragon
#endif // DRAONG_CORE_TENSOR_H_
#endif // DRAONG_CORE_TENSOR_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -12,10 +12,14 @@
#ifndef DRAGON_CORE_TYPES_H_
#define DRAGON_CORE_TYPES_H_
#include <unordered_map>
#include "core/typeid.h"
namespace dragon {
#ifdef _MSC_VER
typedef struct __declspec(align(2)) {
unsigned short x;
} float16;
......@@ -24,7 +28,7 @@ typedef struct __declspec(align(4)) {
unsigned int x;
} float32;
#else
#else
typedef struct {
unsigned short x;
......@@ -36,6 +40,31 @@ typedef struct {
#endif
inline const TypeMeta& TypeStringToMeta(const std::string& str_type) {
static std::unordered_map<std::string, TypeMeta> s2m_type_map {
{ "float32", TypeMeta::Make<float>() },
{ "int32", TypeMeta::Make<int>() },
{ "int64", TypeMeta::Make<int64_t>() },
{ "float64", TypeMeta::Make<double>() },
{ "float16", TypeMeta::Make<float16>() },
{ "uint8", TypeMeta::Make<uint8_t>() }
};
static TypeMeta unknown_type;
return s2m_type_map.count(str_type) ? s2m_type_map[str_type] : unknown_type;
}
inline const std::string TypeMetaToString(const TypeMeta& meta) {
static std::unordered_map<TypeId, std::string> m2s_type_map {
{ TypeMeta::Id<float>(), "float32" },
{ TypeMeta::Id<int>(), "int32" },
{ TypeMeta::Id<int64_t>(), "int64" },
{ TypeMeta::Id<double>(), "float64", },
{ TypeMeta::Id<float16>(), "float16" },
{ TypeMeta::Id<uint8_t>(), "uint8" }
};
return m2s_type_map.count(meta.id()) ? m2s_type_map[meta.id()] : "unknown";
}
} // namespace dragon
#endif // DRAGON_CORE_TYPES_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -43,7 +43,6 @@ class DropoutGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
use_scale(OperatorBase::GetSingleArg<bool>("scale", true)) {
GET_ARGUMENT_WITH_DESC(float, prob, 0.5);
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -36,9 +36,7 @@ class EluGradientOp : public Operator<Context> {
public:
EluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {
DISABLE_SHARE_GRADIENT;
}
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -36,9 +36,7 @@ class ReluGradientOp : public Operator<Context> {
public:
ReluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {
DISABLE_SHARE_GRADIENT;
}
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -19,8 +19,7 @@ namespace dragon {
template <class Context>
class SEluOp : public Operator<Context> {
public:
SEluOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
USE_SIMPLE_CTOR_DTOR(SEluOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......@@ -30,10 +29,7 @@ class SEluOp : public Operator<Context> {
template <class Context>
class SEluGradientOp : public Operator<Context> {
public:
SEluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_SIMPLE_CTOR_DTOR(SEluGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -29,10 +29,7 @@ class SigmoidOp : public Operator<Context> {
template <class Context>
class SigmoidGradientOp : public Operator<Context> {
public:
SigmoidGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_SIMPLE_CTOR_DTOR(SigmoidGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -38,9 +38,7 @@ class SoftmaxGradientOp final : public Operator<Context> {
public:
SoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {
DISABLE_SHARE_GRADIENT;
}
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -29,10 +29,7 @@ class TanhOp : public Operator<Context> {
template <class Context>
class TanhGradientOp : public Operator<Context> {
public:
TanhGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_SIMPLE_CTOR_DTOR(TanhGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -36,7 +36,6 @@ class AddGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(AddGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
......@@ -65,7 +64,6 @@ class RAddGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(RAddGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -38,9 +38,7 @@ class BiasAddGradientOp final : public Operator<Context> {
public:
BiasAddGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
DISABLE_SHARE_GRADIENT;
}
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -36,7 +36,6 @@ class DivGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(DivGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
......@@ -65,7 +64,6 @@ class RDivGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(RDivGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -44,7 +44,6 @@ class DotGradientOp final : public Operator<Context> {
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void DotRunWithType();
template <typename T> void GemmRunWithType();
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -55,7 +55,6 @@ class EltwiseGradientOp final : public Operator<Context> {
}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void SumRunWithType();
template <typename T> void ProdRunWithType();
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -43,7 +43,6 @@ class MatmulGradientOp final : public Operator<Context> {
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -36,7 +36,6 @@ class MulGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(MulGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
......@@ -65,7 +64,6 @@ class RMulGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(RMulGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -36,7 +36,6 @@ class SubGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(SubGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
......@@ -65,7 +64,6 @@ class RSubGradientOp final : public Operator<Context> {
USE_SIMPLE_CTOR_DTOR(RSubGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type);
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -67,16 +67,13 @@ class ScanGradientOp final: public Operator<Context> {
// handle GI(x)
for (int i = 0; i < forward_inputs.size(); i++)
terms[forward_inputs[i] + "_grad"] = Output(i)->name();
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
void MakeGradientOps();
void MakeOps(const GraphDef& forward_def, GraphDef& new_def);
protected:
GraphDef forward_def, new_def;
Map<string, string> terms;
Map<int, unique_ptr<Graph>> graphs;
vector<string> forward_inputs, forward_outputs;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -40,7 +40,6 @@ class L1LossGradientOp final : public Operator<Context> {
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -40,7 +40,6 @@ class L2LossGradientOp final : public Operator<Context> {
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -21,17 +21,15 @@ class SmoothL1LossOp final : public Operator<Context> {
public:
SmoothL1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
sigma2(OperatorBase::GetSingleArg<float>("sigma", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {
sigma2 *= sigma2;
}
beta(OperatorBase::GetSingleArg<float>("beta", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float sigma2;
float beta;
Tensor* diff, *error;
string normalization;
};
......@@ -41,17 +39,15 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
public:
SmoothL1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
sigma2(OperatorBase::GetSingleArg<float>("sigma", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {
sigma2 *= sigma2;
}
beta(OperatorBase::GetSingleArg<float>("beta", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float sigma2;
float beta;
Tensor* diff;
string normalization;
};
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -23,16 +23,10 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {
OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ Input(0).name() }),
vector<string>({ "/mnt/" + Anchor() + "/softmax/prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (op_def.has_device_option())
softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
softmax_op.reset(CreateOperator(softmax_def, ws));
}
USE_OPERATOR_FUNCTIONS(Context);
void SoftmaxRun();
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -29,18 +29,12 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
int* ignore_data = ignore.mutable_data<int, CPUContext>();
for (int i = 0; i < args.size(); i++) ignore_data[i] = args[i];
}
OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ Input(0).name() }),
vector<string>({ "/mnt/" + Anchor() + "/softmax/prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (op_def.has_device_option())
softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
softmax_op.reset(CreateOperator(softmax_def, ws));
}
USE_OPERATOR_FUNCTIONS(Context);
void SoftmaxRun();
void RunOnDevice() override;
template <typename T> void RunWithType();
template <typename Tx, typename Ty> void RunWithType();
protected:
TIndex axis, outer_dim, inner_dim;
......@@ -67,7 +61,7 @@ class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
template <typename T> void RunWithType();
template <typename Tx, typename Ty> void RunWithType();
protected:
TIndex axis, outer_dim, inner_dim;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -24,7 +24,6 @@ class GradientGenerateOp final: public Operator<Context> {
defaults(OperatorBase::GetRepeatedArg<float>("defaults")) {
CHECK_EQ(InputSize(), OutputSize());
CHECK_EQ(defaults.size(), OutputSize());
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
......@@ -42,7 +41,6 @@ class GradientGatherOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) {
for (int i = 0; i < InputSize(); i++)
if (Input(i).name() != "ignore") indices.push_back(i);
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
......@@ -56,10 +54,7 @@ class GradientGatherOp final : public Operator<Context> {
template <class Context>
class StopGradientOp final : public Operator<Context> {
public:
StopGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_SIMPLE_CTOR_DTOR(StopGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -45,9 +45,7 @@ template <class Context>
class TemplateGradientOp : public TemplateOp<Context> {
public:
TemplateGradientOp(const OperatorDef& op_def, Workspace* ws)
: TemplateOp<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
: TemplateOp<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -34,9 +34,7 @@ template <class Context>
class MPIBroadcastGradientOp final : public ModelMPIBase<Context> {
public:
MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
: ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context);
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -34,9 +34,7 @@ template <class Context>
class MPIGatherGradientOp final : public ModelMPIBase<Context> {
public:
MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
: ModelMPIBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_MPIMODEL_FUNCTIONS(Context);
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -40,9 +40,7 @@ class ConcatGradientOp : public Operator<Context> {
ConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {
DISABLE_SHARE_GRADIENT;
}
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -21,12 +21,13 @@ class CropOp: public Operator<Context> {
public:
CropOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
starts(OperatorBase::GetRepeatedArg<int>("starts")),
ends(OperatorBase::GetRepeatedArg<int>("ends")),
start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)),
offsets(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {}
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, starts);
GET_ARGUMENTS_WITH_DESC(int, ends);
}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -36,25 +37,21 @@ class CropOp: public Operator<Context> {
protected:
TIndex start_axis;
string shape_like;
vector<int> starts, ends, offsets, shape;
vector<int> st, ed, offsets, shape, keep_dims;
DECLARE_ARGUMENTS_WITH_DESC(int, starts);
DECLARE_ARGUMENTS_WITH_DESC(int, ends);
vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim;
Tensor* dest, *source;
};
DEFINE_ARGUMENTS_WITH_DESC(int, CropOp, starts);
DEFINE_ARGUMENTS_WITH_DESC(int, CropOp, ends);
template <class Context>
class CropGradientOp final : public Operator<Context > {
public:
CropGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
starts(OperatorBase::GetRepeatedArg<int>("starts")),
ends(OperatorBase::GetRepeatedArg<int>("ends")),
start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)),
offsets(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
DISABLE_SHARE_GRADIENT;
}
USE_SIMPLE_CTOR_DTOR(CropGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
......@@ -62,9 +59,7 @@ class CropGradientOp final : public Operator<Context > {
template <typename T> void RunWithType();
protected:
TIndex start_axis;
string shape_like;
vector<int> starts, ends, offsets, shape;
vector<int> st, ed, offsets, keep_dims;
vector< pair<int, int> > process_axes;
TIndex axis, inner_dim, dim;
Tensor* dest, *source;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -33,10 +33,7 @@ class ExpandDimsOp final : public Operator<Context> {
template <class Context>
class ExpandDimsGradientOp final : public Operator<Context> {
public:
ExpandDimsGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_SIMPLE_CTOR_DTOR(ExpandDimsGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -37,10 +37,7 @@ class FlattenOp final : public Operator<Context> {
template <class Context>
class FlattenGradientOp final : public Operator<Context> {
public:
FlattenGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_SIMPLE_CTOR_DTOR(FlattenGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -69,7 +69,6 @@ class PadGradientOp final : public Operator<Context> {
}
std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end());
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -40,9 +40,7 @@ class RandomPickGradientOp final : public Operator<Context> {
public:
RandomPickGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {
DISABLE_SHARE_GRADIENT;
}
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -37,10 +37,7 @@ class ReshapeOp final : public Operator<Context> {
template <class Context>
class ReshapeGradientOp final : public Operator<Context> {
public:
ReshapeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
USE_SIMPLE_CTOR_DTOR(ReshapeGradientOp);
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -41,9 +41,7 @@ class SliceGradientOp final : public Operator<Context> {
SliceGradientOp(const OperatorDef& op_def, Workspace* ws):
Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {
DISABLE_SHARE_GRADIENT;
}
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -43,7 +43,6 @@ class StackGradientOp : public Operator<Context> {
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
USE_OPERATOR_FUNCTIONS(Context);
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void RunWithType();
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -40,7 +40,6 @@ class TileGradientOp : public Operator<Context> {
TileGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
GET_ARGUMENTS_WITH_DESC(int, multiples);
DISABLE_SHARE_GRADIENT;
}
USE_OPERATOR_FUNCTIONS(Context);
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -118,8 +118,6 @@ class FusedBatchNormGradientOp : public Operator<Context> {
void Setup();
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -23,10 +23,7 @@ class GroupNormOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) {
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
......@@ -36,18 +33,15 @@ class GroupNormOp : public Operator<Context> {
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
template <typename T> void RunWithType();
protected:
float momentum, eps;
float eps;
Tensor mean, num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format, mode;
int use_stats;
bool use_global_stats, is_recomputing;
string data_format;
};
template <class Context>
......@@ -56,8 +50,7 @@ class GroupNormGradientOp final : public Operator<Context> {
GroupNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
......@@ -67,8 +60,7 @@ class GroupNormGradientOp final : public Operator<Context> {
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
template <typename T> void RunWithType();
protected:
Tensor num_by_chans;
......@@ -76,8 +68,6 @@ class GroupNormGradientOp final : public Operator<Context> {
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats;
};
template <class Context>
......@@ -87,26 +77,21 @@ class FusedGroupNormOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
template <typename T> void RunWithType();
protected:
float momentum, eps;
float eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats, is_recomputing;
};
template <class Context>
......@@ -116,17 +101,13 @@ class FusedGroupNormGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))) {}
USE_OPERATOR_FUNCTIONS(Context);
void Setup();
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
template <typename T> void RunWithType();
protected:
float eps;
......@@ -135,8 +116,6 @@ class FusedGroupNormGradientOp : public Operator<Context> {
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats;
};
} // namespace dragon
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -37,9 +37,7 @@ template <class Context>
class LSTMUnitGradientOp : public Operator<Context> {
public:
LSTMUnitGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
this->allow_share_grads_ = false;
}
: Operator<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
void RunOnDevice() override;
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -20,20 +20,14 @@ template <class Context>
class AdamUpdateOp final : public UpdateOpBase<Context> {
public:
AdamUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
t(0),
eps(Param("eps")),
beta1(Param("beta1")),
beta2(Param("beta2")) {}
: UpdateOpBase<Context>(op_def, ws), t(0) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override;
protected:
float lr, beta1, beta2, eps, coeff;
int t;
Tensor* m, *v, *tmp;
int t; float lr, beta1, beta2, eps;
};
} // namespace dragon
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -20,8 +20,7 @@ template <class Context>
class NesterovUpdateOp final : public UpdateOpBase<Context> {
public:
NesterovUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
momentum(Param("momentum")) {}
: UpdateOpBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
......@@ -29,7 +28,6 @@ class NesterovUpdateOp final : public UpdateOpBase<Context> {
protected:
float lr, momentum;
Tensor* h, *tmp;
};
} // namespace dragon
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -20,9 +20,7 @@ template <class Context>
class RMSPropUpdateOp final : public UpdateOpBase<Context> {
public:
RMSPropUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
eps(Param("eps")),
decay(Param("decay")) {}
: UpdateOpBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
......@@ -30,7 +28,6 @@ class RMSPropUpdateOp final : public UpdateOpBase<Context> {
protected:
float lr, decay, eps;
Tensor* h, *tmp;
};
} // namespace dragon
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -20,17 +20,14 @@ template <class Context>
class SGDUpdateOp final : public UpdateOpBase<Context> {
public:
SGDUpdateOp(const OperatorDef& op_def, Workspace* ws)
: UpdateOpBase<Context>(op_def, ws),
momentum(Param("momentum")) {}
: UpdateOpBase<Context>(op_def, ws) {}
USE_OPERATOR_FUNCTIONS(Context);
USE_UPDATER_FUNCTIONS(Context);
void ComputeRunWithFloat() override;
protected:
protected:
float lr, momentum;
Tensor* h;
};
} // namespace dragon
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -19,11 +19,14 @@ namespace dragon {
template <class Context>
class UpdateOpBase : public Operator<Context> {
public:
UpdateOpBase(const OperatorDef& op_def, Workspace* ws)
UpdateOpBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
lr_mult(OperatorBase::GetSingleArg<float>("lr_mult", 1.0)),
decay_mult(OperatorBase::GetSingleArg<float>("decay_mult", 1.0)),
domain(OperatorBase::GetSingleArg<string>("domain", "_")) {}
lr_mult(OperatorBase::GetSingleArg<float>("lr_mult", 1.0)),
decay_mult(OperatorBase::GetSingleArg<float>("decay_mult", 1.0)),
slot(OperatorBase::GetSingleArg<string>("slot", "")),
zero_grad(OperatorBase::GetSingleArg<bool>("zero_grad", true)) {
CHECK(!slot.empty()) << "\nRequired a non-empty slot";
}
USE_OPERATOR_FUNCTIONS(Context);
float Param(const string& name) const;
......@@ -37,7 +40,8 @@ class UpdateOpBase : public Operator<Context> {
protected:
float lr_mult, decay_mult;
float l2_decay, clip_thresh, scale_factor;
string domain;
string slot;
bool zero_grad;
};
#define USE_UPDATER_FUNCTIONS(context) \
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -95,6 +95,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
}
void RunOnDevice() override;
template <typename T> void ResetDesc();
template <typename T> void RunWithType();
protected:
......@@ -107,6 +108,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
cudnnFilterDescriptor_t filter_desc;
size_t workspace_fwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
};
template <class Context>
......@@ -151,6 +153,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
}
void RunOnDevice() override;
template <typename T> void ResetDesc();
template <typename T> void RunWithType();
protected:
......@@ -164,6 +167,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
cudnnFilterDescriptor_t filter_desc;
size_t workspace_bwd_filter_size, workspace_bwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
};
#endif // WITH_CUDNN
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -98,6 +98,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
}
void RunOnDevice() override;
template <typename T> void ResetDesc();
template <typename T> void RunWithType();
protected:
......@@ -110,6 +111,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
cudnnFilterDescriptor_t filter_desc;
size_t workspace_fwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
};
template <class Context>
......@@ -154,6 +156,7 @@ public:
}
void RunOnDevice() override;
template <typename T> void ResetDesc();
template <typename T> void RunWithType();
protected:
......@@ -167,6 +170,7 @@ public:
cudnnFilterDescriptor_t filter_desc;
size_t workspace_bwd_filter_size, workspace_bwd_data_size;
TIndex bias_offset, cudnn_group;
vector<TIndex> input_dims;
};
#endif // WITH_CUDNN
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......@@ -39,7 +39,6 @@ class DenseConcatGradientOp : public ConcatGradientOp<Context> {
TIndex growth_rate;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_VISION_DENSE_CONCAT_OP_H_
\ No newline at end of file
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
// ------------------------------------------------------------
// Copyright (c) 2017-preseent, SeetaTech, Co.,Ltd.
// Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
//
// Licensed under the BSD 2-Clause License.
// You should have received a copy of the BSD 2-Clause License
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!