Commit ddb76e7b by Ting PAN

add memonger for Dragon

1 parent d64a3943
Showing with 813 additions and 410 deletions
...@@ -42,6 +42,7 @@ class Graph final : public GraphBase { ...@@ -42,6 +42,7 @@ class Graph final : public GraphBase {
GraphDef Prune(const GraphDef& graph_def); GraphDef Prune(const GraphDef& graph_def);
GraphDef Share(const GraphDef& graph_def); GraphDef Share(const GraphDef& graph_def);
GraphDef MakeUpdate(const GraphDef& graph_def); GraphDef MakeUpdate(const GraphDef& graph_def);
void RecomputingAware(const GraphDef& graph_def, Workspace* ws);
inline Workspace* ws() const { return ws_; } inline Workspace* ws() const { return ws_; }
......
...@@ -80,30 +80,35 @@ class Operator : public OperatorBase { ...@@ -80,30 +80,35 @@ class Operator : public OperatorBase {
allow_run_ = true; allow_run_ = true;
allow_run_ &= _MPICheck(); allow_run_ &= _MPICheck();
allow_run_ &= (!(OutputSize() == 1 && output(0)->name() == "ignore")); allow_run_ &= (!(OutputSize() == 1 && output(0)->name() == "ignore"));
allow_share_grads_ = (!op_def.debug_mode());
allow_share_grads_ &= op_def.share_grads();
allow_share_grads_ &= (type().find("Gradient") != string::npos);
} }
virtual void Run() final { virtual void Run() final {
if (!allow_run_) return; if (!allow_run_) return;
MakeResource();
ctx_.SwitchToDevice(); ctx_.SwitchToDevice();
if (!op_def_.debug_mode()) ShareBeforeRun();
MemorySwitch(); MemorySwitch();
RunOnDevice(); RunOnDevice();
if (!op_def_.debug_mode()) ClearAfterRun();
ctx_.FinishDeviceCompution(); ctx_.FinishDeviceCompution();
CleanResource();
} }
virtual void ElimateCorruption();
virtual void ShareGradient();
virtual void MakeResource();
virtual void CleanResource();
void MemorySwitch() { void MemorySwitch() {
for (int i = 0; i < InputSize(); i++) for (int i = 0; i < InputSize(); i++)
if (input(i).name() != "ignore") if (input(i).name() != "ignore") input(i).SwitchToDevice();
input(i).SwitchToDevice();
for (int i = 0; i < OutputSize(); i++) for (int i = 0; i < OutputSize(); i++)
if (output(i)->name() != "ignore") if (output(i)->name() != "ignore") output(i)->SwitchToDevice();
output(i)->SwitchToDevice();
} }
virtual void ShareBeforeRun() { /*** share tensors here if necessary ***/ }
virtual void RunOnDevice() = 0; virtual void RunOnDevice() = 0;
virtual void ClearAfterRun() { /*** clear tensors here if necessary ***/ }
inline Context& ctx() { return ctx_; } inline Context& ctx() { return ctx_; }
inline string anchor() { return GetSingleArg("anchor", name()); } inline string anchor() { return GetSingleArg("anchor", name()); }
...@@ -111,7 +116,7 @@ class Operator : public OperatorBase { ...@@ -111,7 +116,7 @@ class Operator : public OperatorBase {
protected: protected:
Context ctx_; Context ctx_;
bool allow_run_; bool allow_run_, allow_share_grads_;
private: private:
bool _MPICheck() { bool _MPICheck() {
...@@ -169,6 +174,9 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp ...@@ -169,6 +174,9 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
} \ } \
} }
#define DISABLE_SHARE_GRADIENT \
this->allow_share_grads_ = false
#define INSTANTIATE_OPERATOR(name, context) \ #define INSTANTIATE_OPERATOR(name, context) \
template class name##Op<context>; template class name##Op<context>;
......
...@@ -30,7 +30,7 @@ class Tensor { ...@@ -30,7 +30,7 @@ class Tensor {
CHECK_GT(d, 0); CHECK_GT(d, 0);
new_size *= d; new_size *= d;
} }
if (size_ != new_size && if (size_ != new_size && own_mem_ &&
capacity_ < TIndex(new_size * meta_.itemsize())) { capacity_ < TIndex(new_size * meta_.itemsize())) {
memory_.reset(); memory_.reset();
capacity_ = 0; capacity_ = 0;
...@@ -38,9 +38,7 @@ class Tensor { ...@@ -38,9 +38,7 @@ class Tensor {
size_ = new_size; size_ = new_size;
} }
void ReshapeLike(const Tensor& other) { void ReshapeLike(const Tensor& other) { Reshape(other.dims_); }
Reshape(other.dims_);
}
inline const string& name() const { return name_; } inline const string& name() const { return name_; }
...@@ -92,63 +90,86 @@ class Tensor { ...@@ -92,63 +90,86 @@ class Tensor {
return ss.str(); return ss.str();
} }
MixedMemory::State memory_state() const { return memory_->state(); } inline bool is_corrupted() const { return is_corrupted_; }
MixedMemory* memory() const { return memory_.get(); } inline void Corrupt() { is_corrupted_ = true; }
void SwitchToDevice() { if(memory_) memory_->SwitchToDevice(); }
MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; }
MixedMemory::State memory_state() const {
MixedMemory* mem = memory();
CHECK(mem) << "memory access before allowcating.";
return memory()->state();
}
void SwitchToDevice() {
MixedMemory* mem = own_mem_ ? memory_.get() : ex_memory_;
if (mem) mem->SwitchToDevice();
}
const TypeMeta& meta() const { return meta_; } const TypeMeta& meta() const { return meta_; }
void SetMeta(const TypeMeta& meta) { meta_ = meta; } void SetMeta(const TypeMeta& meta) { meta_ = meta; }
template <typename T> inline bool IsType() { return meta_.Match<T>(); } template <typename T> inline bool IsType() { return meta_.Match<T>(); }
template <class Context> template <class Context>
const void* raw_data() const { void mutable_data_ptr(void** data_ptr) {
CHECK(memory_.get()) << "memory access before allowcating."; MixedMemory* mem = memory();
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) if (!mem) {
return memory_->cpu_data(); *data_ptr = nullptr;
else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) } else {
return memory_->cuda_data(); if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) {
else LOG(FATAL) << "unknown memory type access. only CPU or CUDA are supported."; *data_ptr = mem->mutable_cpu_data();
return nullptr; } else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) {
*data_ptr = mem->mutable_cuda_data();
} else {
LOG(FATAL) << "unknown memory type access. only CPU or CUDA are supported.";
}
} }
template <typename T, class Context>
const T* data() const {
return static_cast<const T*>(raw_data<Context>());
} }
template <class Context> template <class Context>
void active_data_ptr(void** data_ptr) { const void* const_data_ptr() const {
if (!memory_) { MixedMemory* mem = memory();
*data_ptr = nullptr; CHECK(mem) << "memory access before allowcating.";
} else {
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) { if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>()) {
*data_ptr = memory_->mutable_cpu_data(); return mem->cpu_data();
} else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) { } else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>()) {
*data_ptr = memory_->mutable_cuda_data(); return mem->cuda_data();
} } else {
LOG(FATAL) << "unknown memory type access. only CPU or CUDA are supported.";
return nullptr;
} }
} }
template <class Context> template <class Context>
void* raw_mutable_data(const TypeMeta& meta) { void* raw_mutable_data(const TypeMeta& meta) {
void* data_ptr; void* data_ptr;
active_data_ptr<Context>(&data_ptr); if (own_mem_) {
mutable_data_ptr<Context>(&data_ptr);
if (meta_ == meta && data_ptr) { if (meta_ == meta && data_ptr) {
return data_ptr; return data_ptr;
} else { } else {
meta_ = meta; // copy-assign the meta meta_ = meta;
CHECK_GT(size_, 0); // must specify a valid size CHECK_GT(size_, 0);
memory_.reset(new MixedMemory(meta, size_* meta_.itemsize())); memory_.reset(new MixedMemory(meta, size_* meta_.itemsize()));
// malloc mutable_data_ptr<Context>(&data_ptr); // malloc
if (TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>())
data_ptr = memory_->mutable_cpu_data();
else if (TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>())
data_ptr = memory_->mutable_cuda_data();
// init for each structed element if necessary
if (meta.ctor()) meta_.ctor()(data_ptr, size_); if (meta.ctor()) meta_.ctor()(data_ptr, size_);
} }
capacity_ = size_ * meta_.itemsize(); capacity_ = size_ * meta_.itemsize();
return data_ptr; return data_ptr;
} else {
meta_ = meta;
CHECK_GT(size_, 0);
TIndex ex_capacity_ = ex_memory_->nbytes();
if (ex_capacity_ >= TIndex(size_ * meta.itemsize())) {
mutable_data_ptr<Context>(&data_ptr);
} else {
delete ex_memory_;
ex_memory_ = new MixedMemory(meta, size_* meta_.itemsize());
mutable_data_ptr<Context>(&data_ptr); // malloc
if (meta.ctor()) meta_.ctor()(data_ptr, size_);
capacity_ = size_ * meta.itemsize();
}
return data_ptr;
}
} }
template <class Context> template <class Context>
...@@ -159,22 +180,30 @@ class Tensor { ...@@ -159,22 +180,30 @@ class Tensor {
return raw_mutable_data<Context>(meta_); return raw_mutable_data<Context>(meta_);
} }
template <class Context>
const void* raw_data() const { return const_data_ptr<Context>(); }
template <typename T, class Context> template <typename T, class Context>
T* mutable_data() { T* mutable_data() {
void* data_ptr; void* data_ptr;
active_data_ptr<Context>(&data_ptr); mutable_data_ptr<Context>(&data_ptr);
if (data_ptr && meta_ == TypeMeta::Make<T>()) return static_cast<T*>(data_ptr); if (data_ptr && meta_ == TypeMeta::Make<T>()) return static_cast<T*>(data_ptr);
return static_cast<T*>(raw_mutable_data<Context>(TypeMeta::Make<T>())); return static_cast<T*>(raw_mutable_data<Context>(TypeMeta::Make<T>()));
} }
void Share(const Tensor& other) { template <typename T, class Context>
const T* data() const {
return static_cast<const T*>(raw_data<Context>());
}
inline void Share(const Tensor& other) {
CHECK_EQ(size_, other.size_); CHECK_EQ(size_, other.size_);
memory_ = other.memory_; memory_ = other.memory_;
meta_ = other.meta_; meta_ = other.meta_;
capacity_ = other.capacity_; capacity_ = other.capacity_;
} }
void Replace(const Tensor& other) { inline void Replace(const Tensor& other) {
memory_ = other.memory_; memory_ = other.memory_;
meta_ = other.meta_; meta_ = other.meta_;
capacity_ = other.capacity_; capacity_ = other.capacity_;
...@@ -182,23 +211,27 @@ class Tensor { ...@@ -182,23 +211,27 @@ class Tensor {
dims_ = other.dims_; dims_ = other.dims_;
} }
void Reset() { inline void Move(MixedMemory* mem) {
if (mem != nullptr) ex_memory_ = mem;
else ex_memory_ = new MixedMemory(TypeMeta::Make<float>(), 4);
own_mem_ = false;
}
inline void Reset() {
size_ = capacity_ = 0; size_ = capacity_ = 0;
meta_ = TypeMeta(); meta_ = TypeMeta();
dims_.clear(); dims_.clear();
memory_.reset(); memory_.reset();
} }
void Release() {
memory_.reset();
}
private: private:
vector<TIndex> dims_; vector<TIndex> dims_;
TIndex size_ = 0, capacity_ = 0; TIndex size_ = 0, capacity_ = 0;
TypeMeta meta_; TypeMeta meta_;
string name_; string name_;
shared_ptr<MixedMemory> memory_; shared_ptr<MixedMemory> memory_;
MixedMemory* ex_memory_ = nullptr;
bool is_corrupted_ = false, own_mem_ = true;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -13,23 +13,28 @@ ...@@ -13,23 +13,28 @@
namespace dragon { namespace dragon {
#define WORKSPACE_MIN_BUFFER_SIZE 3 #define WORKSPACE_COMMON_BUFFER_SIZE 2
#define WORKSPACE_MAX_BUFFER_SIZE 3 #define WORKSPACE_GRAD_BUFFER_SIZE 1
#define WORKSPACE_MAX_CORRUPTED_SIZE 2
class Workspace{ class Workspace{
public: public:
typedef Map<string, unique_ptr<Tensor> > TensorMap; typedef Map<string, unique_ptr<Tensor> > TensorMap;
typedef Map<string, stack<string> > BufferMap;
typedef Map<string, unique_ptr<mutex> > LockMap; typedef Map<string, unique_ptr<mutex> > LockMap;
typedef Map<string, unique_ptr<GraphBase> > GraphMap; typedef Map<string, unique_ptr<GraphBase> > GraphMap;
typedef Map<string, TensorFiller> FillerMap; typedef Map<string, TensorFiller> FillerMap;
typedef Map<string, string> RenameMap; typedef Map<string, string> RenameMap;
typedef Map<string, vector<OperatorBase*> > RecomputeMap;
Workspace(): root_folder_(".") { init(); } Workspace(): root_folder_(".") { init(); }
Workspace(string root_folder) : root_folder_(root_folder) { init(); } Workspace(string root_folder) : root_folder_(root_folder) { init(); }
~Workspace();
void init() { void init() {
CreateTensor("ignore"); CreateTensor("ignore");
for (int i = 0; i < WORKSPACE_MIN_BUFFER_SIZE; i++) CreateBuffer(); CreateBuffer("Common", WORKSPACE_COMMON_BUFFER_SIZE);
CreateBuffer("Grad", WORKSPACE_GRAD_BUFFER_SIZE);
} }
/******************** Tensor ********************/ /******************** Tensor ********************/
...@@ -101,33 +106,39 @@ class Workspace{ ...@@ -101,33 +106,39 @@ class Workspace{
/******************** Buffer ********************/ /******************** Buffer ********************/
inline Tensor* CreateBuffer() { inline void CreateBuffer(string category, int num) {
int buffer_idx = 1; CHECK(!buffer_map_.count(category));
string name; buffer_map_[category] = stack<string>();
while (1) { for (int i = 1; i <= num; i++) {
name = "_t_buffer_" + dragon_cast<string, int>(buffer_idx++); string name = "_t_" + category + "_buffer_" + dragon_cast<string, int>(i);
if (!HasTensor(name)) break; buffer_map_[category].push(name);
CreateTensor(name);
} }
buffer_stack_.push(name);
return CreateTensor(name);
} }
inline Tensor* GetBuffer() { inline Tensor* GetBuffer(string category = "Common") {
if (!buffer_stack_.empty()) { if (!buffer_map_[category].empty()) {
string name = buffer_stack_.top(); string name = buffer_map_[category].top();
buffer_stack_.pop(); buffer_map_[category].pop();
return GetTensor(name); return GetTensor(name);
} }
LOG(FATAL) << "buffers are not enough, add more if necessary."; LOG(FATAL) << "buffers of [" << category << "] "
<< "are not enough, add more if necessary.";
return nullptr; return nullptr;
} }
inline void ReleaseBuffer(Tensor* tensor, bool force_release=false) { inline void ReleaseBuffer(Tensor* tensor,
string category = "Common",
bool enforce = false) {
static Map<string, int> limits = {
{ "Common", WORKSPACE_COMMON_BUFFER_SIZE },
{ "Grad", WORKSPACE_GRAD_BUFFER_SIZE }};
if (buffer_map_[category].size() >= limits[category] || enforce) {
// release directly // release directly
if (buffer_stack_.size() >= WORKSPACE_MAX_BUFFER_SIZE || force_release) {
ReleaseTensor(tensor->name()); ReleaseTensor(tensor->name());
} else { // recover as a available buffer } else {
buffer_stack_.push(tensor->name()); // recover as a available buffer
buffer_map_[category].push(tensor->name());
} }
} }
...@@ -158,14 +169,30 @@ class Workspace{ ...@@ -158,14 +169,30 @@ class Workspace{
rename_map_[old_tensor] = new_tensor; rename_map_[old_tensor] = new_tensor;
} }
inline void AddRecompute(const string& tensor, OperatorBase* op) {
if (!recompute_map_.count(tensor)) {
recompute_map_[tensor] = vector<OperatorBase*>();
}
recompute_map_[tensor].push_back(op);
}
inline vector<OperatorBase*> GetRecompute(const string& tensor) {
if (recompute_map_.count(tensor)) {
return recompute_map_[tensor];
} else {
return vector<OperatorBase*>();
}
}
private: private:
TensorMap tensor_map_; TensorMap tensor_map_;
BufferMap buffer_map_;
LockMap lock_map_; LockMap lock_map_;
GraphMap graph_map_; GraphMap graph_map_;
FillerMap filler_map_; FillerMap filler_map_;
RenameMap rename_map_; RenameMap rename_map_;
RecomputeMap recompute_map_;
string root_folder_; string root_folder_;
stack<string> buffer_stack_;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -43,10 +43,11 @@ class DropoutGradientOp final : public Operator<Context> { ...@@ -43,10 +43,11 @@ class DropoutGradientOp final : public Operator<Context> {
threshold = static_cast<unsigned int>(UINT_MAX * prob); threshold = static_cast<unsigned int>(UINT_MAX * prob);
if (use_scale) scale = 1.0 / (1.0 - prob); if (use_scale) scale = 1.0 / (1.0 - prob);
else scale = 1.0; else scale = 1.0;
DISABLE_SHARE_GRADIENT;
} }
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override; void CleanResource() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -30,7 +30,9 @@ class ReluGradientOp : public Operator<Context> { ...@@ -30,7 +30,9 @@ class ReluGradientOp : public Operator<Context> {
public: public:
ReluGradientOp(const OperatorDef& op_def, Workspace* ws) ReluGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {} slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -23,7 +23,10 @@ class SigmoidOp final : public Operator<Context> { ...@@ -23,7 +23,10 @@ class SigmoidOp final : public Operator<Context> {
template <class Context> template <class Context>
class SigmoidGradientOp final : public Operator<Context> { class SigmoidGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SigmoidGradientOp); SigmoidGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -32,7 +32,9 @@ class SoftmaxGradientOp final : public Operator<Context> { ...@@ -32,7 +32,9 @@ class SoftmaxGradientOp final : public Operator<Context> {
public: public:
SoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws) SoftmaxGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::GetSingleArg<int>("axis", 1)) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -23,7 +23,10 @@ class TanhOp final : public Operator<Context> { ...@@ -23,7 +23,10 @@ class TanhOp final : public Operator<Context> {
template <class Context> template <class Context>
class TanhGradientOp final : public Operator<Context> { class TanhGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(TanhGradientOp); TanhGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -29,9 +29,8 @@ class AddGradientOp final : public Operator<Context> { ...@@ -29,9 +29,8 @@ class AddGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(AddGradientOp); USE_SIMPLE_CTOR_DTOR(AddGradientOp);
void ShareBeforeRun() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
......
...@@ -35,9 +35,7 @@ class BiasAddGradientOp final : public Operator<Context> { ...@@ -35,9 +35,7 @@ class BiasAddGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void NCHWRunWithType(); template <typename T> void NCHWRunWithType();
template <typename T> void NHWCRunWithType(); template <typename T> void NHWCRunWithType();
......
...@@ -33,9 +33,7 @@ class ClipGradientOp final : public Operator<Context> { ...@@ -33,9 +33,7 @@ class ClipGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ClipGradientOp); USE_SIMPLE_CTOR_DTOR(ClipGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -29,9 +29,8 @@ class DivGradientOp final : public Operator<Context> { ...@@ -29,9 +29,8 @@ class DivGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(DivGradientOp); USE_SIMPLE_CTOR_DTOR(DivGradientOp);
void ShareBeforeRun() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
......
...@@ -37,6 +37,7 @@ class DotGradientOp final : public Operator<Context> { ...@@ -37,6 +37,7 @@ class DotGradientOp final : public Operator<Context> {
transA(OperatorBase::GetSingleArg<bool>("TransA", false)), transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void DotRunWithType(); template <typename T> void DotRunWithType();
template <typename T> void GemmRunWithType(); template <typename T> void GemmRunWithType();
......
...@@ -48,9 +48,8 @@ class EltwiseGradientOp final : public Operator<Context> { ...@@ -48,9 +48,8 @@ class EltwiseGradientOp final : public Operator<Context> {
} else coeffs.resize(InputSize(), float(1)); } else coeffs.resize(InputSize(), float(1));
} }
void ShareBeforeRun() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void SumRunWithType(); template <typename T> void SumRunWithType();
template <typename T> void ProdRunWithType(); template <typename T> void ProdRunWithType();
......
...@@ -25,9 +25,7 @@ class ExpGradientOp final : public Operator<Context> { ...@@ -25,9 +25,7 @@ class ExpGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ExpGradientOp); USE_SIMPLE_CTOR_DTOR(ExpGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
......
...@@ -33,9 +33,7 @@ class GramMatrixGradientOp final : public Operator<Context> { ...@@ -33,9 +33,7 @@ class GramMatrixGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)) {} axis(OperatorBase::GetSingleArg<int>("axis", 1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -39,9 +39,7 @@ class InnerProductGradientOp final : public Operator<Context> { ...@@ -39,9 +39,7 @@ class InnerProductGradientOp final : public Operator<Context> {
num_output(OperatorBase::GetSingleArg<int>("num_output", 0)), num_output(OperatorBase::GetSingleArg<int>("num_output", 0)),
transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {} transW(OperatorBase::GetSingleArg<bool>("TransW", true)) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -25,9 +25,7 @@ class LogGradientOp final : public Operator<Context> { ...@@ -25,9 +25,7 @@ class LogGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(LogGradientOp); USE_SIMPLE_CTOR_DTOR(LogGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
......
...@@ -36,9 +36,8 @@ class MatmulGradientOp final : public Operator<Context> { ...@@ -36,9 +36,8 @@ class MatmulGradientOp final : public Operator<Context> {
transA(OperatorBase::GetSingleArg<bool>("TransA", false)), transA(OperatorBase::GetSingleArg<bool>("TransA", false)),
transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {} transB(OperatorBase::GetSingleArg<bool>("TransB", false)) {}
void ShareBeforeRun() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -29,9 +29,8 @@ class MulGradientOp final : public Operator<Context> { ...@@ -29,9 +29,8 @@ class MulGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(MulGradientOp); USE_SIMPLE_CTOR_DTOR(MulGradientOp);
void ShareBeforeRun() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
......
...@@ -40,9 +40,7 @@ class PowGradientOp final : public Operator<Context> { ...@@ -40,9 +40,7 @@ class PowGradientOp final : public Operator<Context> {
power_scale = power * scale; power_scale = power * scale;
} }
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -35,9 +35,7 @@ class ScaleGradientOp final : public Operator<Context> { ...@@ -35,9 +35,7 @@ class ScaleGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {} num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void BiasRunWithType(); template <typename T> void BiasRunWithType();
template <typename T> void ScaleRunWithType(); template <typename T> void ScaleRunWithType();
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -25,9 +25,7 @@ class SquareGradientOp final : public Operator<Context> { ...@@ -25,9 +25,7 @@ class SquareGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SquareGradientOp); USE_SIMPLE_CTOR_DTOR(SquareGradientOp);
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
......
...@@ -29,9 +29,8 @@ class SubGradientOp final : public Operator<Context> { ...@@ -29,9 +29,8 @@ class SubGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(SubGradientOp); USE_SIMPLE_CTOR_DTOR(SubGradientOp);
void ShareBeforeRun() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void EltwiseRunWithType(); template <typename T> void EltwiseRunWithType();
template <typename T> void BroadcastRunWithType(int type); template <typename T> void BroadcastRunWithType(int type);
......
...@@ -34,9 +34,7 @@ class AtGradientOp final : public Operator<Context> { ...@@ -34,9 +34,7 @@ class AtGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {} acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -36,9 +36,8 @@ class ConcatGradientOp : public Operator<Context> { ...@@ -36,9 +36,8 @@ class ConcatGradientOp : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {} nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {}
void ShareBeforeRun() override; void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -57,9 +57,7 @@ class CropGradientOp final : public Operator<Context > { ...@@ -57,9 +57,7 @@ class CropGradientOp final : public Operator<Context > {
} }
void ComputeOutputShape(); void ComputeOutputShape();
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
template <typename T> void RecursiveRunWithType(vector<TIndex> idxs, template <typename T> void RecursiveRunWithType(vector<TIndex> idxs,
const vector<TIndex>& offsets, const vector<TIndex>& offsets,
......
...@@ -27,7 +27,10 @@ class ExpandDimsOp final : public Operator<Context> { ...@@ -27,7 +27,10 @@ class ExpandDimsOp final : public Operator<Context> {
template <class Context> template <class Context>
class ExpandDimsGradientOp final : public Operator<Context> { class ExpandDimsGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ExpandDimsGradientOp); ExpandDimsGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -28,7 +28,10 @@ class FlattenOp final : public Operator<Context> { ...@@ -28,7 +28,10 @@ class FlattenOp final : public Operator<Context> {
template <class Context> template <class Context>
class FlattenGradientOp final : public Operator<Context> { class FlattenGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(FlattenGradientOp); FlattenGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -43,7 +43,9 @@ template <class Context> ...@@ -43,7 +43,9 @@ template <class Context>
class TemplateGradientOp : public TemplateOp<Context> { class TemplateGradientOp : public TemplateOp<Context> {
public: public:
TemplateGradientOp(const OperatorDef& op_def, Workspace* ws) TemplateGradientOp(const OperatorDef& op_def, Workspace* ws)
: TemplateOp<Context>(op_def, ws) {} : TemplateOp<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -39,9 +39,7 @@ class ReduceGradientOp final : public Operator<Context> { ...@@ -39,9 +39,7 @@ class ReduceGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", -1)), axis(OperatorBase::GetSingleArg<int>("axis", -1)),
operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {} operation(OperatorBase::GetSingleArg<string>("operation", "NONE")) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void SumRunWithType(); template <typename T> void SumRunWithType();
template <typename T> void MeanRunWithType(); template <typename T> void MeanRunWithType();
......
...@@ -30,7 +30,10 @@ class ReshapeOp final : public Operator<Context> { ...@@ -30,7 +30,10 @@ class ReshapeOp final : public Operator<Context> {
template <class Context> template <class Context>
class ReshapeGradientOp final : public Operator<Context> { class ReshapeGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(ReshapeGradientOp); ReshapeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -61,6 +61,8 @@ class ScanGradientOp final: public Operator<Context> { ...@@ -61,6 +61,8 @@ class ScanGradientOp final: public Operator<Context> {
// handle GI(x) // handle GI(x)
for (int i = 0; i < forward_inputs.size(); i++) for (int i = 0; i < forward_inputs.size(); i++)
terms[forward_inputs[i] + "_grad"] = output(i)->name(); terms[forward_inputs[i] + "_grad"] = output(i)->name();
DISABLE_SHARE_GRADIENT;
} }
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -35,7 +35,9 @@ class SliceGradientOp final : public Operator<Context> { ...@@ -35,7 +35,9 @@ class SliceGradientOp final : public Operator<Context> {
SliceGradientOp(const OperatorDef& op_def, Workspace* ws): SliceGradientOp(const OperatorDef& op_def, Workspace* ws):
Operator<Context>(op_def, ws), Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {} nout(OperatorBase::GetSingleArg<int>("num_output", 1)) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -43,9 +43,7 @@ class TileGradientOp : public Operator<Context> { ...@@ -43,9 +43,7 @@ class TileGradientOp : public Operator<Context> {
process_axes.push_back({ i, multiples[i] }); process_axes.push_back({ i, multiples[i] });
} }
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template<typename T> void TileRunWithType(); template<typename T> void TileRunWithType();
protected: protected:
......
...@@ -33,9 +33,7 @@ class TransposeGradientOp final : public Operator<Context> { ...@@ -33,9 +33,7 @@ class TransposeGradientOp final : public Operator<Context> {
TransposeGradientOp(const OperatorDef& op_def, Workspace* ws) TransposeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -16,14 +16,12 @@ class L1LossOp : public Operator<Context> { ...@@ -16,14 +16,12 @@ class L1LossOp : public Operator<Context> {
public: public:
L1LossOp(const OperatorDef& op_def, Workspace* ws) L1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
coeff(OperatorBase::GetSingleArg<float>("coeff", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
float coeff;
Tensor* diff; Tensor* diff;
string normalization; string normalization;
}; };
...@@ -33,14 +31,13 @@ class L1LossGradientOp final : public Operator<Context> { ...@@ -33,14 +31,13 @@ class L1LossGradientOp final : public Operator<Context> {
public: public:
L1LossGradientOp(const OperatorDef& op_def, Workspace* ws) L1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
coeff(OperatorBase::GetSingleArg<float>("coeff", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
float coeff;
Tensor* diff; Tensor* diff;
string normalization; string normalization;
}; };
......
...@@ -16,14 +16,12 @@ class L2LossOp : public Operator<Context> { ...@@ -16,14 +16,12 @@ class L2LossOp : public Operator<Context> {
public: public:
L2LossOp(const OperatorDef& op_def, Workspace* ws) L2LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
coeff(OperatorBase::GetSingleArg<float>("coeff", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
float coeff;
Tensor* diff; Tensor* diff;
string normalization; string normalization;
}; };
...@@ -33,14 +31,13 @@ class L2LossGradientOp final : public Operator<Context> { ...@@ -33,14 +31,13 @@ class L2LossGradientOp final : public Operator<Context> {
public: public:
L2LossGradientOp(const OperatorDef& op_def, Workspace* ws) L2LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
coeff(OperatorBase::GetSingleArg<float>("coeff", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {}
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
float coeff;
Tensor* diff; Tensor* diff;
string normalization; string normalization;
}; };
......
...@@ -4,19 +4,20 @@ ...@@ -4,19 +4,20 @@
// Written by Ting Pan // Written by Ting Pan
// -------------------------------------------------------- // --------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
#include "core/operator.h" #include "core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SigmoidCrossEntropyLossOp final : public Operator<Context> { class SigmoidCrossEntropyOp final : public Operator<Context> {
public: public:
SigmoidCrossEntropyLossOp(const OperatorDef& op_def, Workspace* ws) SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -27,9 +28,9 @@ class SigmoidCrossEntropyLossOp final : public Operator<Context> { ...@@ -27,9 +28,9 @@ class SigmoidCrossEntropyLossOp final : public Operator<Context> {
}; };
template <class Context> template <class Context>
class SigmoidCrossEntropyLossGradientOp final : public Operator<Context> { class SigmoidCrossEntropyGradientOp final : public Operator<Context> {
public: public:
SigmoidCrossEntropyLossGradientOp(const OperatorDef& op_def, Workspace* ws) SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
...@@ -43,4 +44,4 @@ class SigmoidCrossEntropyLossGradientOp final : public Operator<Context> { ...@@ -43,4 +44,4 @@ class SigmoidCrossEntropyLossGradientOp final : public Operator<Context> {
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SIGMOID_CROSS_ENTROPY_OP_H_
\ No newline at end of file \ No newline at end of file
...@@ -4,17 +4,17 @@ ...@@ -4,17 +4,17 @@
// Written by Ting Pan // Written by Ting Pan
// -------------------------------------------------------- // --------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
#include "core/operator.h" #include "core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SoftmaxCrossEntropyLossOp final : public Operator<Context> { class SoftmaxCrossEntropyOp final : public Operator<Context> {
public: public:
SoftmaxCrossEntropyLossOp(const OperatorDef& op_def, Workspace* ws) SoftmaxCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) { normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {
...@@ -39,9 +39,9 @@ class SoftmaxCrossEntropyLossOp final : public Operator<Context> { ...@@ -39,9 +39,9 @@ class SoftmaxCrossEntropyLossOp final : public Operator<Context> {
}; };
template <class Context> template <class Context>
class SoftmaxCrossEntropyLossGradientOp final : public Operator<Context> { class SoftmaxCrossEntropyGradientOp final : public Operator<Context> {
public: public:
SoftmaxCrossEntropyLossGradientOp(const OperatorDef& op_def, Workspace* ws) SoftmaxCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
...@@ -57,4 +57,4 @@ class SoftmaxCrossEntropyLossGradientOp final : public Operator<Context> { ...@@ -57,4 +57,4 @@ class SoftmaxCrossEntropyLossGradientOp final : public Operator<Context> {
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_LOSS_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SOFTMAX_CROSS_ENTROPY_OP_H_
\ No newline at end of file \ No newline at end of file
...@@ -4,17 +4,17 @@ ...@@ -4,17 +4,17 @@
// Written by Ting Pan // Written by Ting Pan
// -------------------------------------------------------- // --------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SOFTMAX_LOSS_OP_H_ #ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
#define DRAGON_OPERATORS_LOSS_SOFTMAX_LOSS_OP_H_ #define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
#include "core/operator.h" #include "core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class SoftmaxLossOp final : public Operator<Context> { class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
public: public:
SoftmaxLossOp(const OperatorDef& op_def, Workspace* ws) SparseSoftmaxCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) { normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {
...@@ -45,9 +45,9 @@ class SoftmaxLossOp final : public Operator<Context> { ...@@ -45,9 +45,9 @@ class SoftmaxLossOp final : public Operator<Context> {
}; };
template <class Context> template <class Context>
class SoftmaxLossGradientOp final : public Operator<Context> { class SparseSoftmaxCrossEntropyGradientOp : public Operator<Context> {
public: public:
SoftmaxLossGradientOp(const OperatorDef& op_def, Workspace* ws) SparseSoftmaxCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) { normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {
...@@ -71,4 +71,4 @@ class SoftmaxLossGradientOp final : public Operator<Context> { ...@@ -71,4 +71,4 @@ class SoftmaxLossGradientOp final : public Operator<Context> {
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SOFTMAX_LOSS_OP_H_ #endif // DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_CROSS_ENTROPY_OP_H_
\ No newline at end of file \ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_FOCAL_LOSS_OP_H_
#define DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_FOCAL_LOSS_OP_H_
#include "operators/loss/sparse_softmax_cross_entropy_op.h"
namespace dragon {
template <class Context>
class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Context> {
public:
SparseSoftmaxFocalLossOp(const OperatorDef& op_def, Workspace* ws)
: SparseSoftmaxCrossEntropyOp<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
alpha(OperatorBase::GetSingleArg<float>("alpha", 1.0)),
gamma(OperatorBase::GetSingleArg<float>("gamma", 2.0)),
use_pseudo_metric(OperatorBase::GetSingleArg<bool>("use_pseudo_metric", true)) {
if (alpha == 1.0) use_pseudo_metric = false;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float alpha, gamma;
bool use_pseudo_metric;
TIndex axis, outer_dim, inner_dim;
Tensor* scale;
string normalization;
};
template <class Context>
class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyGradientOp<Context> {
public:
SparseSoftmaxFocalLossGradientOp(const OperatorDef& op_def, Workspace* ws)
: SparseSoftmaxCrossEntropyGradientOp<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
gamma(OperatorBase::GetSingleArg<float>("gamma", 2.0)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
float gamma, eps;
TIndex axis, outer_dim, inner_dim;
Tensor* scale;
string normalization;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_LOSS_SPARSE_SOFTMAX_FOCAL_LOSS_OP_H_
\ No newline at end of file
...@@ -27,7 +27,9 @@ template <class Context> ...@@ -27,7 +27,9 @@ template <class Context>
class MPIBroadcastGradientOp final : public ModelMPIBase<Context> { class MPIBroadcastGradientOp final : public ModelMPIBase<Context> {
public: public:
MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws) MPIBroadcastGradientOp(const OperatorDef& op_def, Workspace* ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -27,7 +27,9 @@ template <class Context> ...@@ -27,7 +27,9 @@ template <class Context>
class MPIGatherGradientOp final : public ModelMPIBase<Context> { class MPIGatherGradientOp final : public ModelMPIBase<Context> {
public: public:
MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws) MPIGatherGradientOp(const OperatorDef& op_def, Workspace *ws)
: ModelMPIBase<Context>(op_def, ws) {} : ModelMPIBase<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -30,7 +30,7 @@ class BatchNormOp : public Operator<Context> { ...@@ -30,7 +30,7 @@ class BatchNormOp : public Operator<Context> {
Tensor* num_multiplier, *spatial_multiplier, *stddev, *var; Tensor* num_multiplier, *spatial_multiplier, *stddev, *var;
TIndex num, channels, spatial_dim, nbychans; TIndex num, channels, spatial_dim, nbychans;
int use_stats; int use_stats;
bool use_global_stats, inplace; bool use_global_stats, inplace, is_recomputing;
}; };
template <class Context> template <class Context>
...@@ -40,9 +40,7 @@ class BatchNormGradientOp final : public Operator<Context> { ...@@ -40,9 +40,7 @@ class BatchNormGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
...@@ -68,7 +66,7 @@ class BNOp : public Operator<Context> { ...@@ -68,7 +66,7 @@ class BNOp : public Operator<Context> {
protected: protected:
float momentum, eps; float momentum, eps;
int use_stats; int use_stats;
bool use_global_stats; bool use_global_stats, is_recomputing;
}; };
template <class Context> template <class Context>
...@@ -79,9 +77,8 @@ class BNGradientOp : public Operator<Context> { ...@@ -79,9 +77,8 @@ class BNGradientOp : public Operator<Context> {
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { } use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { }
void ShareBeforeRun() override; void ShareGradient() override;
void RunOnDevice() override { NOT_IMPLEMENTED; } void RunOnDevice() override { NOT_IMPLEMENTED; }
void ClearAfterRun() override;
template <typename T> void RunWithType() { NOT_IMPLEMENTED; } template <typename T> void RunWithType() { NOT_IMPLEMENTED; }
protected: protected:
...@@ -115,7 +112,7 @@ class CuDNNBNOp final : public BNOp<Context> { ...@@ -115,7 +112,7 @@ class CuDNNBNOp final : public BNOp<Context> {
cudnnTensorDescriptor_t input_desc, output_desc, bn_desc; cudnnTensorDescriptor_t input_desc, output_desc, bn_desc;
TIndex num, channels, spatial_dim; TIndex num, channels, spatial_dim;
Tensor* mean, *var; Tensor* mean, *var;
bool use_global_stats; bool use_global_stats, is_recomputing;
}; };
template <class Context> template <class Context>
......
...@@ -36,7 +36,7 @@ class BatchRenormOp : public Operator<Context> { ...@@ -36,7 +36,7 @@ class BatchRenormOp : public Operator<Context> {
Tensor* stddev, *r, *var, *x_norm; Tensor* stddev, *r, *var, *x_norm;
TIndex num, channels, spatial_dim, nbychans; TIndex num, channels, spatial_dim, nbychans;
int use_stats; int use_stats;
bool use_global_stats, inplace; bool use_global_stats, inplace, is_recomputing;
}; };
template <class Context> template <class Context>
...@@ -46,9 +46,7 @@ class BatchRenormGradientOp final : public Operator<Context> { ...@@ -46,9 +46,7 @@ class BatchRenormGradientOp final : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -36,9 +36,7 @@ class InstanceNormGradientOp final : public Operator<Context> { ...@@ -36,9 +36,7 @@ class InstanceNormGradientOp final : public Operator<Context> {
InstanceNormGradientOp(const OperatorDef& op_def, Workspace *ws) InstanceNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -40,9 +40,7 @@ class L2NormGradientOp final : public Operator<Context> { ...@@ -40,9 +40,7 @@ class L2NormGradientOp final : public Operator<Context> {
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {} num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -30,7 +30,10 @@ class LSTMUnitOp : public Operator<Context> { ...@@ -30,7 +30,10 @@ class LSTMUnitOp : public Operator<Context> {
template <class Context> template <class Context>
class LSTMUnitGradientOp : public Operator<Context> { class LSTMUnitGradientOp : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(LSTMUnitGradientOp); LSTMUnitGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
this->allow_share_grads_ = false;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -4,23 +4,14 @@ ...@@ -4,23 +4,14 @@
// Written by Ting Pan // Written by Ting Pan
// -------------------------------------------------------- // --------------------------------------------------------
#ifndef DRAGON_OPERATORS_COMMON_UTILS_OP_H_ #ifndef DRAGON_OPERATORS_UTILS_ACCURACY_OP_H_
#define DRAGON_OPERATORS_COMMON_UTILS_OP_H_ #define DRAGON_OPERATORS_UTILS_ACCURACY_OP_H_
#include "core/operator.h" #include "core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class CopyOp final: public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(CopyOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class AccuracyOp final: public Operator<Context> { class AccuracyOp final: public Operator<Context> {
public: public:
AccuracyOp(const OperatorDef& op_def, Workspace* ws) AccuracyOp(const OperatorDef& op_def, Workspace* ws)
...@@ -42,22 +33,6 @@ class AccuracyOp final: public Operator<Context> { ...@@ -42,22 +33,6 @@ class AccuracyOp final: public Operator<Context> {
Tensor ignore_labels; Tensor ignore_labels;
}; };
template <class Context>
class OneHotOp final : public Operator < Context > {
public:
OneHotOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
depth(OperatorBase::GetSingleArg<int>("depth", -1)),
on_value(OperatorBase::GetSingleArg<int>("on_value", 1)),
off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex depth, on_value, off_value;
};
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_COMMON_UTILS_OP_H_ #endif // DRAGON_OPERATORS_UTILS_ACCURACY_OP_H_
\ No newline at end of file \ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UTILS_COPY_OP_H_
#define DRAGON_OPERATORS_UTILS_COPY_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class CopyOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(CopyOp);
void RunOnDevice() override;
template <typename T> void RunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_UTILS_COPY_OP_H_
\ No newline at end of file
...@@ -19,6 +19,7 @@ class GradientGenerateOp final: public Operator<Context> { ...@@ -19,6 +19,7 @@ class GradientGenerateOp final: public Operator<Context> {
defaults(OperatorBase::GetRepeatedArg<float>("defaults")) { defaults(OperatorBase::GetRepeatedArg<float>("defaults")) {
CHECK_EQ(InputSize(), OutputSize()); CHECK_EQ(InputSize(), OutputSize());
CHECK_EQ(defaults.size(), OutputSize()); CHECK_EQ(defaults.size(), OutputSize());
DISABLE_SHARE_GRADIENT;
} }
void RunOnDevice() override; void RunOnDevice() override;
...@@ -35,6 +36,7 @@ class GradientGatherOp final : public Operator<Context> { ...@@ -35,6 +36,7 @@ class GradientGatherOp final : public Operator<Context> {
: Operator<Context>(op_def, ws) { : Operator<Context>(op_def, ws) {
for (int i = 0; i < InputSize(); i++) for (int i = 0; i < InputSize(); i++)
if (input(i).name() != "ignore") indices.push_back(i); if (input(i).name() != "ignore") indices.push_back(i);
DISABLE_SHARE_GRADIENT;
} }
void RunOnDevice() override; void RunOnDevice() override;
...@@ -47,7 +49,11 @@ class GradientGatherOp final : public Operator<Context> { ...@@ -47,7 +49,11 @@ class GradientGatherOp final : public Operator<Context> {
template <class Context> template <class Context>
class StopGradientOp final : public Operator<Context> { class StopGradientOp final : public Operator<Context> {
public: public:
USE_SIMPLE_CTOR_DTOR(StopGradientOp); StopGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_UTILS_ONE_HOT_OP_H_
#define DRAGON_OPERATORS_UTILS_ONE_HOT_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class OneHotOp final : public Operator < Context > {
public:
OneHotOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
depth(OperatorBase::GetSingleArg<int>("depth", -1)),
on_value(OperatorBase::GetSingleArg<int>("on_value", 1)),
off_value(OperatorBase::GetSingleArg<int>("off_value", 0)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex depth, on_value, off_value;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_UTILS_ONE_HOT_OP_H_
\ No newline at end of file
...@@ -30,9 +30,7 @@ class ConvGradientOp : public ConvOp<Context> { ...@@ -30,9 +30,7 @@ class ConvGradientOp : public ConvOp<Context> {
ConvGradientOp(const OperatorDef& def, Workspace* ws) ConvGradientOp(const OperatorDef& def, Workspace* ws)
: ConvOp<Context>(def, ws) {} : ConvOp<Context>(def, ws) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
......
...@@ -31,9 +31,7 @@ class DeConvGradientOp : public DeConvOp<Context> { ...@@ -31,9 +31,7 @@ class DeConvGradientOp : public DeConvOp<Context> {
DeConvGradientOp(const OperatorDef& def, Workspace* ws) : DeConvGradientOp(const OperatorDef& def, Workspace* ws) :
DeConvOp<Context>(def, ws) {} DeConvOp<Context>(def, ws) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
......
...@@ -15,21 +15,21 @@ template <class Context> ...@@ -15,21 +15,21 @@ template <class Context>
class DenseConcatOp final : public ConcatOp<Context> { class DenseConcatOp final : public ConcatOp<Context> {
public: public:
DenseConcatOp(const OperatorDef& op_def, Workspace* ws) DenseConcatOp(const OperatorDef& op_def, Workspace* ws)
: ConcatOp<Context>(op_def, ws) { } : ConcatOp<Context>(op_def, ws) {}
void RunOnDevice() override;
}; };
template <class Context> template <class Context>
class DenseConcatGradientOp : public ConcatGradientOp<Context> { class DenseConcatGradientOp : public ConcatGradientOp<Context> {
public: public:
DenseConcatGradientOp(const OperatorDef& op_def, Workspace* ws) DenseConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: ConcatGradientOp<Context>(op_def, ws) {} : ConcatGradientOp<Context>(op_def, ws),
growth_rate(OperatorBase::GetSingleArg<int>("growth_rate", 0)) {}
void ElimateCorruption() override;
template <typename T> void RestoreX1();
void ShareBeforeRun() override; protected:
void RunOnDevice() override; TIndex growth_rate;
void ClearAfterRun() override;
template <typename T> void RunWithType();
}; };
......
...@@ -35,9 +35,7 @@ class NNResizeGradientOp : public Operator<Context> { ...@@ -35,9 +35,7 @@ class NNResizeGradientOp : public Operator<Context> {
NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws) NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
......
...@@ -69,9 +69,7 @@ class PoolingGradientOp: public Operator<Context> { ...@@ -69,9 +69,7 @@ class PoolingGradientOp: public Operator<Context> {
} }
void Reshape(); void Reshape();
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override;
template <typename T> void MaxRunWithType(); template <typename T> void MaxRunWithType();
template <typename T> void AvgRunWithType(); template <typename T> void AvgRunWithType();
......
...@@ -44,9 +44,8 @@ class ROIAlignGradientOp : public Operator<Context> { ...@@ -44,9 +44,8 @@ class ROIAlignGradientOp : public Operator<Context> {
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override; void CleanResource() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -41,9 +41,8 @@ class ROIPoolingGradientOp final : public Operator<Context> { ...@@ -41,9 +41,8 @@ class ROIPoolingGradientOp final : public Operator<Context> {
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)), pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {} spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {}
void ShareBeforeRun() override;
void RunOnDevice() override; void RunOnDevice() override;
void ClearAfterRun() override; void CleanResource() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
......
...@@ -286,21 +286,12 @@ void TransposeGrad(const int count, ...@@ -286,21 +286,12 @@ void TransposeGrad(const int count,
const T* dy, const T* dy,
T* dx); T* dx);
/******************** common.utils ********************/
template <typename T, class Context>
void OneHot(const int count,
const int depth,
const int on_value,
const T* x,
T* y);
/******************** loss.l1_loss ********************/ /******************** loss.l1_loss ********************/
template <typename T, class Context> template <typename T, class Context>
void AbsGrad(const int count, const T* dy, T* dx); void AbsGrad(const int count, const T* dy, T* dx);
/******************** loss.sigmoid_cross_entropy_loss ********************/ /******************** loss.sigmoid_cross_entropy ********************/
template <typename T, class Context> template <typename T, class Context>
void SigmoidCrossEntropy(const int count, const T* x, const T* target, T* loss); void SigmoidCrossEntropy(const int count, const T* x, const T* target, T* loss);
...@@ -313,12 +304,12 @@ void SmoothL1(const int count, const float sigma2, const T* x, T* y); ...@@ -313,12 +304,12 @@ void SmoothL1(const int count, const float sigma2, const T* x, T* y);
template <typename T, class Context> template <typename T, class Context>
void SmoothL1Grad(const int count, const float sigma2, const T* dy, T* dx); void SmoothL1Grad(const int count, const float sigma2, const T* dy, T* dx);
/******************** loss.softmax_cross_entropy_loss ********************/ /******************** loss.softmax_cross_entropy ********************/
template <typename T, class Context> template <typename T, class Context>
void SoftmaxCrossEntropy(const int count, const T* prob, const T* target, T* loss); void SoftmaxCrossEntropy(const int count, const T* prob, const T* target, T* loss);
/******************** loss.softmax_loss ********************/ /******************** loss.sparse_softmax_cross_entropy ********************/
template <typename T, class Context> template <typename T, class Context>
void SparseSoftmaxCrossEntropy(const int count, void SparseSoftmaxCrossEntropy(const int count,
...@@ -332,12 +323,42 @@ void SparseSoftmaxCrossEntropy(const int count, ...@@ -332,12 +323,42 @@ void SparseSoftmaxCrossEntropy(const int count,
Tensor* ignore); Tensor* ignore);
template <typename T, class Context> template <typename T, class Context>
void SoftmaxLossGrad(const int count, void SparseSoftmaxCrossEntropyGrad(const int count,
const int classes, const int classes,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const T* prob,
const T* labels, const T* labels,
T* valid,
Tensor* ignore,
T* dXdata);
/******************** loss.sparse_softmax_focal_loss ********************/
template <typename T, class Context>
void SparseSoftmaxFocalLoss(const int count,
const int classes,
const int outer_dim,
const int inner_dim,
const float alpha,
const float gamma,
const T* prob, const T* prob,
const T* labels,
T* scale,
T* loss,
T* valid,
Tensor* ignore);
template <typename T, class Context>
void SparseSoftmaxFocalLossGrad(const int count,
const int classes,
const int outer_dim,
const int inner_dim,
const float gamma,
const float eps,
const T* scale,
const T* prob,
const T* labels,
T* valid, T* valid,
Tensor* ignore, Tensor* ignore,
T* dXdata); T* dXdata);
...@@ -422,6 +443,15 @@ void MemoryData(const int count, ...@@ -422,6 +443,15 @@ void MemoryData(const int count,
const Tx* x, const Tx* x,
Ty* y); Ty* y);
/******************** utils.one_hot ********************/
template <typename T, class Context>
void OneHot(const int count,
const int depth,
const int on_value,
const T* x,
T* y);
/******************** vision.conv ********************/ /******************** vision.conv ********************/
template <typename T, class Context> template <typename T, class Context>
......
...@@ -20,7 +20,11 @@ option['device'] = 'CPU' ...@@ -20,7 +20,11 @@ option['device'] = 'CPU'
option['gpu_id'] = 0 option['gpu_id'] = 0
option['use_cudnn'] = False option['use_cudnn'] = False
option['random_seed'] = 3 option['random_seed'] = 3
option['debug_mode'] = True
# if True, disable Dragon-Memonger
option['debug_mode'] = False
option['share_grads'] = False # set it by Dragon-Memonger
option['allow_mirrow_stage'] = True # default
def EnableCPU(): def EnableCPU():
global option global option
...@@ -32,8 +36,8 @@ def EnableCUDA(gpu_id=0, use_cudnn=True): ...@@ -32,8 +36,8 @@ def EnableCUDA(gpu_id=0, use_cudnn=True):
option['gpu_id'] = gpu_id option['gpu_id'] = gpu_id
option['use_cudnn'] = use_cudnn option['use_cudnn'] = use_cudnn
# TODO(Pan): please not use @setter # TODO(PhyscalX): please not use @setter
# TODO(Pan): seems that it can't change the global value # TODO(PhyscalX): seems that it can't change the global value
def SetRandomSeed(seed): def SetRandomSeed(seed):
global option global option
......
...@@ -176,6 +176,6 @@ def Restore(filename, format=0): ...@@ -176,6 +176,6 @@ def Restore(filename, format=0):
FeedTensor(key, ndarray) FeedTensor(key, ndarray)
elif format is 1: elif format is 1:
# TODO(pan): caffemodel can't save the tensor name # TODO(PhyscalX): caffemodel can't save the tensor name
# TODO(pan): we simply use 'Scope + LayerName + @paramX' # TODO(PhyscalX): we simply use 'Scope + LayerName + @paramX'
RestoreCC(filename, '', format) RestoreCC(filename, '', format)
\ No newline at end of file
# --------------------------------------------------------
# Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
def share_grads(enabled=True):
from dragon.config import option
option['share_grads'] = enabled
def drop(op_func, *args, **kwargs):
kwargs['mirrow_stage'] = True
return op_func(*args, **kwargs)
\ No newline at end of file
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import print_function
import numpy as np import numpy as np
import dragon.core.workspace as ws import dragon.core.workspace as ws
import dragon.ops as ops import dragon.ops as ops
import dragon.vm.theano as theano import dragon.vm.theano as theano
from multiprocessing import Process, Queue from multiprocessing import Process, Queue
from dragon.config import logger
""" How to custom a RunOp in Dragon """ """ How to custom a RunOp in Dragon """
...@@ -32,7 +32,7 @@ class Fetcher(Process): ...@@ -32,7 +32,7 @@ class Fetcher(Process):
self.daemon = True self.daemon = True
def cleanup(): def cleanup():
logger.info('Terminating Fetcher......') print('Terminating Fetcher......')
self.terminate() self.terminate()
self.join() self.join()
...@@ -104,4 +104,4 @@ if __name__ == '__main__': ...@@ -104,4 +104,4 @@ if __name__ == '__main__':
foo() foo()
# fetch # fetch
logger.info('y \n-------------- \n', y.get_value(), '\n') print('y \n-------------- \n', y.get_value(), '\n')
\ No newline at end of file \ No newline at end of file
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import print_function
import numpy as np import numpy as np
import dragon.core.workspace as ws import dragon.core.workspace as ws
import dragon.ops as ops import dragon.ops as ops
from dragon.core.tensor import Tensor from dragon.core.tensor import Tensor
import dragon.vm.theano.tensor as T import dragon.vm.theano.tensor as T
import dragon.vm.theano as theano import dragon.vm.theano as theano
from dragon.config import logger
""" How to custom a TemplateOp in Dragon """ """ How to custom a TemplateOp in Dragon """
...@@ -91,14 +91,14 @@ if __name__ == '__main__': ...@@ -91,14 +91,14 @@ if __name__ == '__main__':
foo = theano.function(outputs=y) foo = theano.function(outputs=y)
# feed # feed
ws.FeedTensor(x1, np.ones((5, 3))) ws.FeedTensor(x1, np.ones((5, 3), dtype=np.float32))
ws.FeedTensor(x2, np.ones((5, 3)) * 5.0) ws.FeedTensor(x2, np.ones((5, 3), dtype=np.float32) * 5.0)
# run # run
foo() foo()
# fetch # fetch
logger.info('y \n-------------- \n', y.get_value(), '\n') print('y \n-------------- \n', y.get_value(), '\n')
logger.info('dx1 \n-------------- \n', dx1.get_value(), '\n') print('dx1 \n-------------- \n', dx1.get_value(), '\n')
logger.info('dx2 \n-------------- \n', dx2.get_value(), '\n') print('dx2 \n-------------- \n', dx2.get_value(), '\n')
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
from dragon.core.tensor import Tensor from dragon.core.tensor import Tensor
import numpy as np import numpy as np
def SoftmaxLoss(inputs, axis=1, normalization='VALID', ignore_labels=(), **kwargs): def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labels=(), **kwargs):
""" """
:param inputs: a list of Tensor contains [input, label] :param inputs: a list of Tensor contains [input, label]
:param axis a int of using which axis to compute softmax :param axis a int of using which axis to compute softmax
...@@ -17,12 +17,12 @@ def SoftmaxLoss(inputs, axis=1, normalization='VALID', ignore_labels=(), **kwarg ...@@ -17,12 +17,12 @@ def SoftmaxLoss(inputs, axis=1, normalization='VALID', ignore_labels=(), **kwarg
""" """
if not isinstance(inputs, list) or len(inputs) is not 2: if not isinstance(inputs, list) or len(inputs) is not 2:
raise RuntimeError('SoftmaxLoss Operator accpets a list of 2 Tensors') raise RuntimeError('SparseSoftmaxCrossEntropy Operator accpets a list of 2 Tensors')
args = locals(); kwargs = args['kwargs'] args = locals(); kwargs = args['kwargs']
del args['kwargs']; kwargs = dict(args, **kwargs) del args['kwargs']; kwargs = dict(args, **kwargs)
output = Tensor.CreateOperator(nout=1, op_type='SoftmaxLoss', **kwargs) output = Tensor.CreateOperator(nout=1, op_type='SparseSoftmaxCrossEntropy', **kwargs)
if inputs[0].shape is not None: if inputs[0].shape is not None:
if normalization != 'UNIT': output.shape = [1] if normalization != 'UNIT': output.shape = [1]
...@@ -35,7 +35,7 @@ def SoftmaxLoss(inputs, axis=1, normalization='VALID', ignore_labels=(), **kwarg ...@@ -35,7 +35,7 @@ def SoftmaxLoss(inputs, axis=1, normalization='VALID', ignore_labels=(), **kwarg
return output return output
def SigmoidCrossEntropyLoss(inputs, normalization='FULL', **kwargs): def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs):
""" """
:param inputs: a list of Tensor contains [input, label] :param inputs: a list of Tensor contains [input, label]
:param normalization: a str of (UNIT, FULL, BATCH_SIZE, NONE) :param normalization: a str of (UNIT, FULL, BATCH_SIZE, NONE)
...@@ -43,12 +43,12 @@ def SigmoidCrossEntropyLoss(inputs, normalization='FULL', **kwargs): ...@@ -43,12 +43,12 @@ def SigmoidCrossEntropyLoss(inputs, normalization='FULL', **kwargs):
""" """
if not isinstance(inputs, list) or len(inputs) is not 2: if not isinstance(inputs, list) or len(inputs) is not 2:
raise RuntimeError('SigmoidCrossEntropyLoss Operator accpets a list of 2 Tensors') raise RuntimeError('SigmoidCrossEntropy Operator accpets a list of 2 Tensors')
args = locals(); kwargs = args['kwargs'] args = locals(); kwargs = args['kwargs']
del args['kwargs']; kwargs = dict(args, **kwargs) del args['kwargs']; kwargs = dict(args, **kwargs)
output = Tensor.CreateOperator(nout=1, op_type='SigmoidCrossEntropyLoss', **kwargs) output = Tensor.CreateOperator(nout=1, op_type='SigmoidCrossEntropy', **kwargs)
if inputs[0].shape is not None: if inputs[0].shape is not None:
if normalization != 'UNIT': output.shape = [1] if normalization != 'UNIT': output.shape = [1]
...@@ -57,7 +57,7 @@ def SigmoidCrossEntropyLoss(inputs, normalization='FULL', **kwargs): ...@@ -57,7 +57,7 @@ def SigmoidCrossEntropyLoss(inputs, normalization='FULL', **kwargs):
return output return output
def SoftmaxCrossEntropyLoss(inputs, axis=1, normalization='FULL', **kwargs): def SoftmaxCrossEntropy(inputs, axis=1, normalization='FULL', **kwargs):
""" """
:param inputs: a list of Tensor contains [input, label] :param inputs: a list of Tensor contains [input, label]
:param normalization: a str of (UNIT, FULL, BATCH_SIZE, NONE) :param normalization: a str of (UNIT, FULL, BATCH_SIZE, NONE)
...@@ -65,12 +65,12 @@ def SoftmaxCrossEntropyLoss(inputs, axis=1, normalization='FULL', **kwargs): ...@@ -65,12 +65,12 @@ def SoftmaxCrossEntropyLoss(inputs, axis=1, normalization='FULL', **kwargs):
""" """
if not isinstance(inputs, list) or len(inputs) is not 2: if not isinstance(inputs, list) or len(inputs) is not 2:
raise RuntimeError('SoftmaxCrossEntropyLoss Operator accpets a list of 2 Tensors') raise RuntimeError('SoftmaxCrossEntropy Operator accpets a list of 2 Tensors')
args = locals(); kwargs = args['kwargs'] args = locals(); kwargs = args['kwargs']
del args['kwargs']; kwargs = dict(args, **kwargs) del args['kwargs']; kwargs = dict(args, **kwargs)
output = Tensor.CreateOperator(nout=1, op_type='SoftmaxCrossEntropyLoss', **kwargs) output = Tensor.CreateOperator(nout=1, op_type='SoftmaxCrossEntropy', **kwargs)
if inputs[0].shape is not None: if inputs[0].shape is not None:
if normalization != 'UNIT': output.shape = [1] if normalization != 'UNIT': output.shape = [1]
...@@ -96,7 +96,7 @@ def SmoothL1Loss(inputs, sigma=1.0, **kwargs): ...@@ -96,7 +96,7 @@ def SmoothL1Loss(inputs, sigma=1.0, **kwargs):
return output return output
def L1Loss(inputs, normalization='BATCH_SIZE', coeff=1.0, **kwargs): def L1Loss(inputs, normalization='BATCH_SIZE', **kwargs):
if not isinstance(inputs, list) or len(inputs) < 2: if not isinstance(inputs, list) or len(inputs) < 2:
raise RuntimeError('L1Loss Operator accpets a list of at least 2 Tensors') raise RuntimeError('L1Loss Operator accpets a list of at least 2 Tensors')
...@@ -109,7 +109,7 @@ def L1Loss(inputs, normalization='BATCH_SIZE', coeff=1.0, **kwargs): ...@@ -109,7 +109,7 @@ def L1Loss(inputs, normalization='BATCH_SIZE', coeff=1.0, **kwargs):
return output return output
def L2Loss(inputs, normalization='BATCH_SIZE', coeff=1.0, **kwargs): def L2Loss(inputs, normalization='BATCH_SIZE', **kwargs):
if not isinstance(inputs, list) or len(inputs) < 2: if not isinstance(inputs, list) or len(inputs) < 2:
raise RuntimeError('L2Loss Operator accpets a list of at least 2 Tensors') raise RuntimeError('L2Loss Operator accpets a list of at least 2 Tensors')
...@@ -120,3 +120,35 @@ def L2Loss(inputs, normalization='BATCH_SIZE', coeff=1.0, **kwargs): ...@@ -120,3 +120,35 @@ def L2Loss(inputs, normalization='BATCH_SIZE', coeff=1.0, **kwargs):
output = Tensor.CreateOperator(nout=1, op_type='L2Loss', **kwargs) output = Tensor.CreateOperator(nout=1, op_type='L2Loss', **kwargs)
if inputs[0].shape is not None: output.shape = [1] if inputs[0].shape is not None: output.shape = [1]
return output return output
def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=(),
alpha=0.25, gamma=2.0, eps=1e-10, use_pseudo_metric=True, **kwargs):
"""
:param inputs: a list of Tensor contains [input, label]
:param axis a int of using which axis to compute softmax
:param normalization: a str of (UNIT, FULL, VALID, BATCH_SIZE, NONE)
:param ignore_labels: a list of int contatins the labels to ignore
:param alpha a float of the alpha value
:param gamma a float of the gamma value
:param eps a float of the eps value
:return: a Tensor of loss with the shape (1,)
"""
if not isinstance(inputs, list) or len(inputs) is not 2:
raise RuntimeError('SoftmaxFocalLoss Operator accpets a list of 2 Tensors')
args = locals(); kwargs = args['kwargs']
del args['kwargs']; kwargs = dict(args, **kwargs)
output = Tensor.CreateOperator(nout=1, op_type='SparseSoftmaxFocalLoss', **kwargs)
if inputs[0].shape is not None:
if normalization != 'UNIT': output.shape = [1]
elif all(dim is not None for dim in inputs[0].shape):
outer_dim = int(np.prod(inputs[0].shape[0 : axis]))
inner_dim = int(np.prod(inputs[0].shape[axis + 1 :]))
output.shape = [outer_dim * inner_dim]
else: output.shape = [None]
return output
\ No newline at end of file
...@@ -197,7 +197,7 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs): ...@@ -197,7 +197,7 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
return output return output
def DenseConcat(inputs, axis=1, **kwargs): def DenseConcat(inputs, growth_rate, axis=1, **kwargs):
if not isinstance(inputs, list) or len(inputs) != 2: if not isinstance(inputs, list) or len(inputs) != 2:
raise RuntimeError('DenseConcat Operator accepts 2 Tensors as inputs') raise RuntimeError('DenseConcat Operator accepts 2 Tensors as inputs')
...@@ -207,6 +207,7 @@ def DenseConcat(inputs, axis=1, **kwargs): ...@@ -207,6 +207,7 @@ def DenseConcat(inputs, axis=1, **kwargs):
kwargs['num_input'] = len(inputs) kwargs['num_input'] = len(inputs)
output = Tensor.CreateOperator(nout=1, op_type='DenseConcat', **kwargs) output = Tensor.CreateOperator(nout=1, op_type='DenseConcat', **kwargs)
if all(input.shape is not None for input in inputs): if all(input.shape is not None for input in inputs):
if all(input.shape[axis] is not None for input in inputs): if all(input.shape[axis] is not None for input in inputs):
output.shape = inputs[0].shape[:] output.shape = inputs[0].shape[:]
......
...@@ -52,12 +52,13 @@ Softmax = act.Softmax ...@@ -52,12 +52,13 @@ Softmax = act.Softmax
Dropout = act.Dropout Dropout = act.Dropout
# loss # loss
SoftmaxLoss = loss.SoftmaxLoss SparseSoftmaxCrossEntropy = loss.SparseSoftmaxCrossEntropy
SigmoidCrossEntropyLoss = loss.SigmoidCrossEntropyLoss SigmoidCrossEntropy = loss.SigmoidCrossEntropy
SoftmaxCrossEntropyLoss = loss.SoftmaxCrossEntropyLoss SoftmaxCrossEntropy = loss.SoftmaxCrossEntropy
SmoothL1Loss = loss.SmoothL1Loss SmoothL1Loss = loss.SmoothL1Loss
L1Loss = loss.L1Loss L1Loss = loss.L1Loss
L2Loss = loss.L2Loss L2Loss = loss.L2Loss
SparseSoftmaxFocalLoss = loss.SparseSoftmaxFocalLoss
# arithmetic # arithmetic
Add = math.Add Add = math.Add
......
...@@ -50,6 +50,7 @@ message OperatorDef { ...@@ -50,6 +50,7 @@ message OperatorDef {
repeated Argument arg= 5; repeated Argument arg= 5;
optional DeviceOption device_option = 6; optional DeviceOption device_option = 6;
optional bool debug_mode = 7 [default = false]; optional bool debug_mode = 7 [default = false];
optional bool share_grads = 8 [default = false];
} }
message GradientTarget { message GradientTarget {
...@@ -65,7 +66,6 @@ message UpdateTarget { ...@@ -65,7 +66,6 @@ message UpdateTarget {
repeated Argument arg = 4; repeated Argument arg = 4;
} }
// simply copy from caffe1
message TensorFiller { message TensorFiller {
optional string tensor = 1; optional string tensor = 1;
optional string type = 2 [default = 'constant']; optional string type = 2 [default = 'constant'];
...@@ -89,4 +89,5 @@ message GraphDef { ...@@ -89,4 +89,5 @@ message GraphDef {
repeated GradientTarget g_target = 8; repeated GradientTarget g_target = 8;
repeated UpdateTarget u_target = 9; repeated UpdateTarget u_target = 9;
optional bool debug_mode = 10 [default = false]; optional bool debug_mode = 10 [default = false];
optional bool share_grads = 11 [default = false];
} }
\ No newline at end of file
...@@ -19,7 +19,7 @@ _sym_db = _symbol_database.Default() ...@@ -19,7 +19,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='dragon.proto', name='dragon.proto',
package='', package='',
serialized_pb=_b('\n\x0c\x64ragon.proto\"\xf7\x01\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x05\x12/\n\tdata_type\x18\x02 \x01(\x0e\x32\x15.TensorProto.DataType:\x05\x46LOAT\x12\x16\n\nfloat_data\x18\x03 \x03(\x02\x42\x02\x10\x01\x12\x16\n\nint32_data\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x11\n\tbyte_data\x18\x05 \x01(\x0c\x12\x13\n\x0bstring_data\x18\x06 \x03(\x0c\x12\x0c\n\x04name\x18\x07 \x01(\t\"C\n\x08\x44\x61taType\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05INT32\x10\x02\x12\x08\n\x04\x42YTE\x10\x03\x12\n\n\x06STRING\x10\x04\x12\x0b\n\x07\x46LOAT16\x10\x0c\",\n\x0cTensorProtos\x12\x1c\n\x06protos\x18\x01 \x03(\x0b\x32\x0c.TensorProto\"\x80\x01\n\x08\x41rgument\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x05\x12\x0b\n\x03i64\x18\t \x01(\x03\x12\t\n\x01s\x18\x04 \x01(\t\x12\t\n\x01\x62\x18\x08 \x01(\x08\x12\x0e\n\x06\x66loats\x18\x05 \x03(\x02\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0f\n\x07strings\x18\x07 \x03(\t\"p\n\x0c\x44\x65viceOption\x12%\n\x0b\x64\x65vice_type\x18\x01 \x01(\x0e\x32\x0b.DeviceType:\x03\x43PU\x12\x11\n\x06gpu_id\x18\x02 \x01(\x05:\x01\x30\x12\x16\n\x0brandom_seed\x18\x03 \x01(\r:\x01\x33\x12\x0e\n\x06\x65ngine\x18\x04 \x01(\t\"\xa1\x01\n\x0bOperatorDef\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x16\n\x03\x61rg\x18\x05 \x03(\x0b\x32\t.Argument\x12$\n\rdevice_option\x18\x06 \x01(\x0b\x32\r.DeviceOption\x12\x19\n\ndebug_mode\x18\x07 \x01(\x08:\x05\x66\x61lse\"=\n\x0eGradientTarget\x12\x0c\n\x04\x63ost\x18\x01 \x01(\t\x12\x0b\n\x03wrt\x18\x02 \x01(\t\x12\x10\n\x08\x65xternal\x18\x03 \x01(\t\"R\n\x0cUpdateTarget\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x0e\n\x06tensor\x18\x03 \x03(\t\x12\x16\n\x03\x61rg\x18\x04 \x03(\x0b\x32\t.Argument\"\x8d\x02\n\x0cTensorFiller\x12\x0e\n\x06tensor\x18\x01 \x01(\t\x12\x16\n\x04type\x18\x02 \x01(\t:\x08\x63onstant\x12\x10\n\x05value\x18\x03 \x01(\x02:\x01\x30\x12\x0e\n\x03low\x18\x04 \x01(\x02:\x01\x30\x12\x0f\n\x04high\x18\x05 \x01(\x02:\x01\x31\x12\x0f\n\x04mean\x18\x06 \x01(\x02:\x01\x30\x12\x0e\n\x03std\x18\x07 \x01(\x02:\x01\x31\x12\x10\n\x05scale\x18\x08 \x01(\x02:\x01\x33\x12\x39\n\rvariance_norm\x18\t \x01(\x0e\x32\x1a.TensorFiller.VarianceNorm:\x06\x46\x41N_IN\"4\n\x0cVarianceNorm\x12\n\n\x06\x46\x41N_IN\x10\x00\x12\x0b\n\x07\x46\x41N_OUT\x10\x01\x12\x0b\n\x07\x46\x41N_AVG\x10\x02\"\xf3\x01\n\x08GraphDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x18\n\x02op\x18\x02 \x03(\x0b\x32\x0c.OperatorDef\x12\x12\n\ngraph_type\x18\x03 \x01(\t\x12$\n\rdevice_option\x18\x05 \x01(\x0b\x32\r.DeviceOption\x12\x16\n\x03\x61rg\x18\x06 \x03(\x0b\x32\t.Argument\x12\x0e\n\x06target\x18\x07 \x03(\t\x12!\n\x08g_target\x18\x08 \x03(\x0b\x32\x0f.GradientTarget\x12\x1f\n\x08u_target\x18\t \x03(\x0b\x32\r.UpdateTarget\x12\x19\n\ndebug_mode\x18\n \x01(\x08:\x05\x66\x61lse*+\n\nDeviceType\x12\x07\n\x03\x43PU\x10\x00\x12\x08\n\x04\x43UDA\x10\x01\x12\n\n\x06OPENCL\x10\x02') serialized_pb=_b('\n\x0c\x64ragon.proto\"\xf7\x01\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x05\x12/\n\tdata_type\x18\x02 \x01(\x0e\x32\x15.TensorProto.DataType:\x05\x46LOAT\x12\x16\n\nfloat_data\x18\x03 \x03(\x02\x42\x02\x10\x01\x12\x16\n\nint32_data\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x11\n\tbyte_data\x18\x05 \x01(\x0c\x12\x13\n\x0bstring_data\x18\x06 \x03(\x0c\x12\x0c\n\x04name\x18\x07 \x01(\t\"C\n\x08\x44\x61taType\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05INT32\x10\x02\x12\x08\n\x04\x42YTE\x10\x03\x12\n\n\x06STRING\x10\x04\x12\x0b\n\x07\x46LOAT16\x10\x0c\",\n\x0cTensorProtos\x12\x1c\n\x06protos\x18\x01 \x03(\x0b\x32\x0c.TensorProto\"\x80\x01\n\x08\x41rgument\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x05\x12\x0b\n\x03i64\x18\t \x01(\x03\x12\t\n\x01s\x18\x04 \x01(\t\x12\t\n\x01\x62\x18\x08 \x01(\x08\x12\x0e\n\x06\x66loats\x18\x05 \x03(\x02\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0f\n\x07strings\x18\x07 \x03(\t\"p\n\x0c\x44\x65viceOption\x12%\n\x0b\x64\x65vice_type\x18\x01 \x01(\x0e\x32\x0b.DeviceType:\x03\x43PU\x12\x11\n\x06gpu_id\x18\x02 \x01(\x05:\x01\x30\x12\x16\n\x0brandom_seed\x18\x03 \x01(\r:\x01\x33\x12\x0e\n\x06\x65ngine\x18\x04 \x01(\t\"\xbd\x01\n\x0bOperatorDef\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x16\n\x03\x61rg\x18\x05 \x03(\x0b\x32\t.Argument\x12$\n\rdevice_option\x18\x06 \x01(\x0b\x32\r.DeviceOption\x12\x19\n\ndebug_mode\x18\x07 \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x0bshare_grads\x18\x08 \x01(\x08:\x05\x66\x61lse\"=\n\x0eGradientTarget\x12\x0c\n\x04\x63ost\x18\x01 \x01(\t\x12\x0b\n\x03wrt\x18\x02 \x01(\t\x12\x10\n\x08\x65xternal\x18\x03 \x01(\t\"R\n\x0cUpdateTarget\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x0e\n\x06tensor\x18\x03 \x03(\t\x12\x16\n\x03\x61rg\x18\x04 \x03(\x0b\x32\t.Argument\"\x8d\x02\n\x0cTensorFiller\x12\x0e\n\x06tensor\x18\x01 \x01(\t\x12\x16\n\x04type\x18\x02 \x01(\t:\x08\x63onstant\x12\x10\n\x05value\x18\x03 \x01(\x02:\x01\x30\x12\x0e\n\x03low\x18\x04 \x01(\x02:\x01\x30\x12\x0f\n\x04high\x18\x05 \x01(\x02:\x01\x31\x12\x0f\n\x04mean\x18\x06 \x01(\x02:\x01\x30\x12\x0e\n\x03std\x18\x07 \x01(\x02:\x01\x31\x12\x10\n\x05scale\x18\x08 \x01(\x02:\x01\x33\x12\x39\n\rvariance_norm\x18\t \x01(\x0e\x32\x1a.TensorFiller.VarianceNorm:\x06\x46\x41N_IN\"4\n\x0cVarianceNorm\x12\n\n\x06\x46\x41N_IN\x10\x00\x12\x0b\n\x07\x46\x41N_OUT\x10\x01\x12\x0b\n\x07\x46\x41N_AVG\x10\x02\"\x8f\x02\n\x08GraphDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x18\n\x02op\x18\x02 \x03(\x0b\x32\x0c.OperatorDef\x12\x12\n\ngraph_type\x18\x03 \x01(\t\x12$\n\rdevice_option\x18\x05 \x01(\x0b\x32\r.DeviceOption\x12\x16\n\x03\x61rg\x18\x06 \x03(\x0b\x32\t.Argument\x12\x0e\n\x06target\x18\x07 \x03(\t\x12!\n\x08g_target\x18\x08 \x03(\x0b\x32\x0f.GradientTarget\x12\x1f\n\x08u_target\x18\t \x03(\x0b\x32\r.UpdateTarget\x12\x19\n\ndebug_mode\x18\n \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x0bshare_grads\x18\x0b \x01(\x08:\x05\x66\x61lse*+\n\nDeviceType\x12\x07\n\x03\x43PU\x10\x00\x12\x08\n\x04\x43UDA\x10\x01\x12\n\n\x06OPENCL\x10\x02')
) )
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -44,8 +44,8 @@ _DEVICETYPE = _descriptor.EnumDescriptor( ...@@ -44,8 +44,8 @@ _DEVICETYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=1386, serialized_start=1442,
serialized_end=1429, serialized_end=1485,
) )
_sym_db.RegisterEnumDescriptor(_DEVICETYPE) _sym_db.RegisterEnumDescriptor(_DEVICETYPE)
...@@ -110,8 +110,8 @@ _TENSORFILLER_VARIANCENORM = _descriptor.EnumDescriptor( ...@@ -110,8 +110,8 @@ _TENSORFILLER_VARIANCENORM = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=1086, serialized_start=1114,
serialized_end=1138, serialized_end=1166,
) )
_sym_db.RegisterEnumDescriptor(_TENSORFILLER_VARIANCENORM) _sym_db.RegisterEnumDescriptor(_TENSORFILLER_VARIANCENORM)
...@@ -412,6 +412,13 @@ _OPERATORDEF = _descriptor.Descriptor( ...@@ -412,6 +412,13 @@ _OPERATORDEF = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor(
name='share_grads', full_name='OperatorDef.share_grads', index=7,
number=8, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
], ],
extensions=[ extensions=[
], ],
...@@ -424,7 +431,7 @@ _OPERATORDEF = _descriptor.Descriptor( ...@@ -424,7 +431,7 @@ _OPERATORDEF = _descriptor.Descriptor(
oneofs=[ oneofs=[
], ],
serialized_start=558, serialized_start=558,
serialized_end=719, serialized_end=747,
) )
...@@ -467,8 +474,8 @@ _GRADIENTTARGET = _descriptor.Descriptor( ...@@ -467,8 +474,8 @@ _GRADIENTTARGET = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=721, serialized_start=749,
serialized_end=782, serialized_end=810,
) )
...@@ -518,8 +525,8 @@ _UPDATETARGET = _descriptor.Descriptor( ...@@ -518,8 +525,8 @@ _UPDATETARGET = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=784, serialized_start=812,
serialized_end=866, serialized_end=894,
) )
...@@ -605,8 +612,8 @@ _TENSORFILLER = _descriptor.Descriptor( ...@@ -605,8 +612,8 @@ _TENSORFILLER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=869, serialized_start=897,
serialized_end=1138, serialized_end=1166,
) )
...@@ -680,6 +687,13 @@ _GRAPHDEF = _descriptor.Descriptor( ...@@ -680,6 +687,13 @@ _GRAPHDEF = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor(
name='share_grads', full_name='GraphDef.share_grads', index=9,
number=11, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
], ],
extensions=[ extensions=[
], ],
...@@ -691,8 +705,8 @@ _GRAPHDEF = _descriptor.Descriptor( ...@@ -691,8 +705,8 @@ _GRAPHDEF = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=1141, serialized_start=1169,
serialized_end=1384, serialized_end=1440,
) )
_TENSORPROTO.fields_by_name['data_type'].enum_type = _TENSORPROTO_DATATYPE _TENSORPROTO.fields_by_name['data_type'].enum_type = _TENSORPROTO_DATATYPE
......
...@@ -11,7 +11,7 @@ from .vision import ConvolutionLayer, DeconvolutionLayer, PoolingLayer, \ ...@@ -11,7 +11,7 @@ from .vision import ConvolutionLayer, DeconvolutionLayer, PoolingLayer, \
from .neuron import ReLULayer, DropoutLayer, TanhLayer, PowerLayer from .neuron import ReLULayer, DropoutLayer, TanhLayer, PowerLayer
from .loss import SoftmaxWithLossLayer, SigmoidCrossEntropyLossLayer, \ from .loss import SoftmaxWithLossLayer, SigmoidCrossEntropyLossLayer, \
L2LossLayer, SmoothL1LossLayer L2LossLayer, SmoothL1LossLayer, SoftmaxWithFocalLossLayer
from .mpi import MPIBroadcastLayer, MPIGatherLayer from .mpi import MPIBroadcastLayer, MPIGatherLayer
......
...@@ -93,8 +93,9 @@ class ConcatLayer(Layer): ...@@ -93,8 +93,9 @@ class ConcatLayer(Layer):
class DenseConcatLayer(Layer): class DenseConcatLayer(Layer):
def __init__(self, LayerParameter): def __init__(self, LayerParameter):
super(DenseConcatLayer, self).__init__(LayerParameter) super(DenseConcatLayer, self).__init__(LayerParameter)
param = LayerParameter.concat_param param = LayerParameter.dense_concat_param
self._param = {'axis': param.axis} self._param = {'axis': param.axis,
'growth_rate': param.growth_rate}
def Setup(self, bottom): def Setup(self, bottom):
super(DenseConcatLayer, self).Setup(bottom) super(DenseConcatLayer, self).Setup(bottom)
...@@ -268,7 +269,7 @@ class BNLayer(Layer): ...@@ -268,7 +269,7 @@ class BNLayer(Layer):
if scale_param.HasField('filler'): if scale_param.HasField('filler'):
self.Fill(scale, scale_param, 'filler') self.Fill(scale, scale_param, 'filler')
else: scale.Constant(value=1.0) else: scale.Uniform(low=0.0, high=1.0)
self.Fill(bias, scale_param, 'bias_filler') self.Fill(bias, scale_param, 'bias_filler')
self.norm_blobs = [{'data': mean, 'diff': None}, self.norm_blobs = [{'data': mean, 'diff': None},
{'data': var, 'diff': None}] {'data': var, 'diff': None}]
......
...@@ -19,14 +19,17 @@ class Layer(object): ...@@ -19,14 +19,17 @@ class Layer(object):
self._name = LayerParameter.name self._name = LayerParameter.name
self._blobs = [] self._blobs = []
self._param = {} self._param = {}
self._mpi_param = {} self._common_param = {}
for include in LayerParameter.include: for include in LayerParameter.include:
mpi_rank = [int(rank) for rank in include.mpi_rank] mpi_rank = [int(rank) for rank in include.mpi_rank]
if len(mpi_rank) > 0: self._mpi_param['mpi_rank'] = mpi_rank if len(mpi_rank) > 0: self._common_param['mpi_rank'] = mpi_rank
if LayerParameter.HasField('mirrow_stage'):
self._common_param['mirrow_stage'] = LayerParameter.mirrow_stage
def Setup(self, bottom): def Setup(self, bottom):
self._param = dict(self._param, **self._mpi_param) self._param = dict(self._param, **self._common_param)
def Fill(self, tensor, param, filler): def Fill(self, tensor, param, filler):
""" wrapper for caffe filler """ """ wrapper for caffe filler """
......
...@@ -24,7 +24,7 @@ class SoftmaxWithLossLayer(Layer): ...@@ -24,7 +24,7 @@ class SoftmaxWithLossLayer(Layer):
def Setup(self, bottom): def Setup(self, bottom):
super(SoftmaxWithLossLayer, self).Setup(bottom) super(SoftmaxWithLossLayer, self).Setup(bottom)
return ops.SoftmaxLoss(bottom, **self._param) return ops.SparseSoftmaxCrossEntropy(bottom, **self._param)
class SigmoidCrossEntropyLossLayer(Layer): class SigmoidCrossEntropyLossLayer(Layer):
...@@ -40,7 +40,7 @@ class SigmoidCrossEntropyLossLayer(Layer): ...@@ -40,7 +40,7 @@ class SigmoidCrossEntropyLossLayer(Layer):
def Setup(self, bottom): def Setup(self, bottom):
super(SigmoidCrossEntropyLossLayer, self).Setup(bottom) super(SigmoidCrossEntropyLossLayer, self).Setup(bottom)
return ops.SigmoidCrossEntropyLoss(bottom, **self._param) return ops.SigmoidCrossEntropy(bottom, **self._param)
class L2LossLayer(Layer): class L2LossLayer(Layer):
...@@ -64,3 +64,27 @@ class SmoothL1LossLayer(Layer): ...@@ -64,3 +64,27 @@ class SmoothL1LossLayer(Layer):
def Setup(self, bottom): def Setup(self, bottom):
super(SmoothL1LossLayer, self).Setup(bottom) super(SmoothL1LossLayer, self).Setup(bottom)
return ops.SmoothL1Loss(bottom, **self._param) return ops.SmoothL1Loss(bottom, **self._param)
class SoftmaxWithFocalLossLayer(Layer):
def __init__(self, LayerParameter):
super(SoftmaxWithFocalLossLayer, self).__init__(LayerParameter)
param = LayerParameter.loss_param
softmax_param = LayerParameter.softmax_param
focal_loss_param = LayerParameter.focal_loss_param
norm_mode = {0: 'FULL', 1: 'VALID', 2: 'BATCH_SIZE', 3: 'NONE'}
normalization = 'VALID'
if param.HasField('normalize'):
if not param.normalize: normalization='BATCH_SIZE'
else: normalization = norm_mode[param.normalization]
self._param = {'axis': softmax_param.axis,
'normalization': normalization,
'ignore_labels': [param.ignore_label] if param.HasField('ignore_label') else [],
'alpha': float(focal_loss_param.alpha),
'gamma': float(focal_loss_param.gamma),
'eps': float(focal_loss_param.eps),
'use_pseudo_metric': focal_loss_param.use_pseudo_metric}
def Setup(self, bottom):
super(SoftmaxWithFocalLossLayer, self).Setup(bottom)
return ops.SparseSoftmaxFocalLoss(bottom, **self._param)
...@@ -20,7 +20,6 @@ def make_if_not_exist(path): ...@@ -20,7 +20,6 @@ def make_if_not_exist(path):
os.makedirs(path) os.makedirs(path)
def UnpackVariable(var, num): def UnpackVariable(var, num):
assert len > 0
if type(var) is list and len(var) == num: if type(var) is list and len(var) == num:
return var return var
else: else:
...@@ -277,7 +276,7 @@ def VGGNetBody(net, from_layer, need_fc=True, fully_conv=False, reduced=False, ...@@ -277,7 +276,7 @@ def VGGNetBody(net, from_layer, need_fc=True, fully_conv=False, reduced=False,
dilation = 1 dilation = 1
kernel_size = 3 kernel_size = 3
pad = int((kernel_size + (dilation - 1) * (kernel_size - 1)) - 1) / 2 pad = int(int((kernel_size + (dilation - 1) * (kernel_size - 1)) - 1) / 2)
net.conv5_1 = L.Convolution(net[name], num_output=512, pad=pad, kernel_size=kernel_size, dilation=dilation, **kwargs) net.conv5_1 = L.Convolution(net[name], num_output=512, pad=pad, kernel_size=kernel_size, dilation=dilation, **kwargs)
net.relu5_1 = L.ReLU(net.conv5_1, in_place=True) net.relu5_1 = L.ReLU(net.conv5_1, in_place=True)
net.conv5_2 = L.Convolution(net.relu5_1, num_output=512, pad=pad, kernel_size=kernel_size, dilation=dilation, **kwargs) net.conv5_2 = L.Convolution(net.relu5_1, num_output=512, pad=pad, kernel_size=kernel_size, dilation=dilation, **kwargs)
...@@ -319,7 +318,7 @@ def VGGNetBody(net, from_layer, need_fc=True, fully_conv=False, reduced=False, ...@@ -319,7 +318,7 @@ def VGGNetBody(net, from_layer, need_fc=True, fully_conv=False, reduced=False,
else: else:
kernel_size = 7 kernel_size = 7
num_output = 4096 num_output = 4096
pad = int((kernel_size + (dilation - 1) * (kernel_size - 1)) - 1) / 2 pad = int(int((kernel_size + (dilation - 1) * (kernel_size - 1)) - 1) / 2)
net.fc6 = L.Convolution(net[name], num_output=num_output, pad=pad, kernel_size=kernel_size, dilation=dilation, **kwargs) net.fc6 = L.Convolution(net[name], num_output=num_output, pad=pad, kernel_size=kernel_size, dilation=dilation, **kwargs)
net.relu6 = L.ReLU(net.fc6, in_place=True) net.relu6 = L.ReLU(net.fc6, in_place=True)
......
...@@ -318,6 +318,9 @@ message LayerParameter { ...@@ -318,6 +318,9 @@ message LayerParameter {
repeated string bottom = 3; // the name of each bottom blob repeated string bottom = 3; // the name of each bottom blob
repeated string top = 4; // the name of each top blob repeated string top = 4; // the name of each top blob
// The mirrow stage optimization
optional bool mirrow_stage = 162 [default = false];
// The train / test phase for computation. // The train / test phase for computation.
optional Phase phase = 10; optional Phase phase = 10;
...@@ -418,6 +421,8 @@ message LayerParameter { ...@@ -418,6 +421,8 @@ message LayerParameter {
optional ExpandDimsParameter expand_dims_param = 159; optional ExpandDimsParameter expand_dims_param = 159;
optional ProposalParameter proposal_param = 160; optional ProposalParameter proposal_param = 160;
optional BatchRenormParameter batch_renorm_param = 161; optional BatchRenormParameter batch_renorm_param = 161;
optional DenseConcatParameter dense_concat_param = 163;
optional FocalLossParameter focal_loss_param = 164;
} }
// Message that stores parameters used to apply transformation // Message that stores parameters used to apply transformation
...@@ -1494,3 +1499,16 @@ message BatchRenormParameter { ...@@ -1494,3 +1499,16 @@ message BatchRenormParameter {
optional float d_max = 5 [default = 5.0]; optional float d_max = 5 [default = 5.0];
optional float t_delta = 6 [default = 1.0]; optional float t_delta = 6 [default = 1.0];
} }
message DenseConcatParameter {
optional int32 axis = 1 [default = 1];
optional int32 growth_rate = 2 [default = 0];
}
message FocalLossParameter {
optional float alpha = 1 [default = 1.0];
optional float gamma = 2 [default = 0.25];
optional float eps = 3 [default = 1e-10];
optional bool use_pseudo_metric = 4 [default = true];
}
This diff could not be displayed because it is too large.
...@@ -119,7 +119,7 @@ def bias_add(value, bias, data_format='NCHW', name=None): ...@@ -119,7 +119,7 @@ def bias_add(value, bias, data_format='NCHW', name=None):
def sigmoid_cross_entropy_with_logits(logits, targets, name=None): def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
return ops.SigmoidCrossEntropyLoss([logits, targets], normalization='UNIT', name=None) return ops.SigmoidCrossEntropy([logits, targets], normalization='UNIT', name=None)
def softmax_cross_entropy_with_logits(_sentinel=None, def softmax_cross_entropy_with_logits(_sentinel=None,
...@@ -131,13 +131,13 @@ def softmax_cross_entropy_with_logits(_sentinel=None, ...@@ -131,13 +131,13 @@ def softmax_cross_entropy_with_logits(_sentinel=None,
'with named arguments (labels=..., logits=..., ...)') 'with named arguments (labels=..., logits=..., ...)')
if dim == -1: dim = 1 if dim == -1: dim = 1
return ops.SoftmaxCrossEntropyLoss([logits, labels], axis=dim, normalization='UNIT', name=name) return ops.SoftmaxCrossEntropy([logits, labels], axis=dim, normalization='UNIT', name=name)
def sparse_softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None): def sparse_softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
if dim == -1: dim = 1 if dim == -1: dim = 1
return ops.SoftmaxLoss([logits, labels], axis=dim, normalization='UNIT', name=name) return ops.SparseSoftmaxCrossEntropy([logits, labels], axis=dim, normalization='UNIT', name=name)
def l2_loss(t, name=None): def l2_loss(t, name=None):
......
...@@ -77,10 +77,11 @@ def GraphDef_Update(graph_def, updater): ...@@ -77,10 +77,11 @@ def GraphDef_Update(graph_def, updater):
u_target.arg.add().CopyFrom(MakeArgument(k, v)) u_target.arg.add().CopyFrom(MakeArgument(k, v))
graph_def.u_target.extend([u_target]) graph_def.u_target.extend([u_target])
def GraphDef_Debug(graph_def): def GraphDef_Opt(graph_def):
""" generate debug mode for CC Graph """ """ generate opt options for CC Graph """
from dragon.config import option from dragon.config import option
graph_def.debug_mode = option['debug_mode'] graph_def.debug_mode = option['debug_mode']
graph_def.share_grads = option['share_grads']
def GraphDef_Device(graph_def): def GraphDef_Device(graph_def):
""" generate deivce info for CC Graph """ """ generate deivce info for CC Graph """
...@@ -155,13 +156,13 @@ def function(inputs=[], outputs=[], swaps=None, updater=None): ...@@ -155,13 +156,13 @@ def function(inputs=[], outputs=[], swaps=None, updater=None):
if len(outputs) > 0: if len(outputs) > 0:
GraphDef_Device(graph_def) GraphDef_Device(graph_def)
GraphDef_Debug(graph_def) GraphDef_Opt(graph_def)
GraphDef_Grad(graph_def, outputs) GraphDef_Grad(graph_def, outputs)
GraphDef_Phase(graph_def, outputs) GraphDef_Phase(graph_def, outputs)
elif updater is not None: elif updater is not None:
GraphDef_Device(graph_def) GraphDef_Device(graph_def)
GraphDef_Debug(graph_def) GraphDef_Opt(graph_def)
GraphDef_Update(graph_def, updater) GraphDef_Update(graph_def, updater)
# call c api to create graph # call c api to create graph
......
...@@ -257,6 +257,7 @@ GraphDef Graph::MakeUpdate(const GraphDef& graph_def) { ...@@ -257,6 +257,7 @@ GraphDef Graph::MakeUpdate(const GraphDef& graph_def) {
bool Graph::Create(const GraphDef& graph_def, Workspace* ws) { bool Graph::Create(const GraphDef& graph_def, Workspace* ws) {
bool has_device_option = graph_def.has_device_option(); bool has_device_option = graph_def.has_device_option();
bool has_debug_mode = graph_def.has_debug_mode(); bool has_debug_mode = graph_def.has_debug_mode();
bool has_share_grads = graph_def.has_share_grads();
for (const OperatorDef& plain_op_def: graph_def.op()) { for (const OperatorDef& plain_op_def: graph_def.op()) {
OperatorDef op_def(plain_op_def); OperatorDef op_def(plain_op_def);
LOG(DEBUG) << "Create Operator " << plain_op_def.name() LOG(DEBUG) << "Create Operator " << plain_op_def.name()
...@@ -270,12 +271,83 @@ bool Graph::Create(const GraphDef& graph_def, Workspace* ws) { ...@@ -270,12 +271,83 @@ bool Graph::Create(const GraphDef& graph_def, Workspace* ws) {
if (!op_def.has_debug_mode() && has_debug_mode) if (!op_def.has_debug_mode() && has_debug_mode)
op_def.set_debug_mode(graph_def.debug_mode()); op_def.set_debug_mode(graph_def.debug_mode());
// inherit share_grads if necessary
if (!op_def.has_share_grads() && has_share_grads)
op_def.set_share_grads(graph_def.share_grads());
OperatorBase* op = CreateOperator(op_def, ws); OperatorBase* op = CreateOperator(op_def, ws);
ops_.push_back(op); ops_.push_back(op);
} }
return true; return true;
} }
void Graph::RecomputingAware(const GraphDef& graph_def, Workspace* ws) {
GraphDef fake_graph(graph_def);
Map<string, vector<OperatorBase*> > fake_recompute_map;
Map<string, string> rename_map;
Map<string, Set<string> > hash_map;
Map<string, int> multi_use_count;
// check mirrow stage
for (int i = 0; i < ops_.size(); i++) {
if (ops_[i]->type().find("Gradient") != string::npos) continue;
bool mirrow_stage = ops_[i]->GetSingleArg<bool>("mirrow_stage", false);
for (auto& u : graph_def.op(i).input()) {
bool inplace_flag = false;
for (auto& v : graph_def.op(i).output()) if (u == v) inplace_flag = true;
mirrow_stage &= (!inplace_flag);
if (!inplace_flag) multi_use_count[u]++;
}
if (mirrow_stage) {
// TODO(PhyscalX): we assume that input(0)-output(0) as a force in-place currently
OperatorDef* op = fake_graph.mutable_op(i);
if (rename_map.count(op->input(0)))
*op->mutable_input(0) = rename_map[op->input(0)];
rename_map[op->output(0)] = op->input(0);
*op->mutable_output(0) = op->input(0);
ops_[i]->input(0).Corrupt(); // mark a flag
}
}
// sub-graph aware
for (int i = 0; i < ops_.size(); i++) {
if (ops_[i]->type().find("Gradient") != string::npos) continue;
OperatorDef fake_op = fake_graph.op(i);
OperatorDef op = graph_def.op(i);
for (int j = 0; j < op.output_size(); j++) {
string v = op.output(j);
string fake_v = fake_op.output(j);
if (!fake_recompute_map.count(fake_v))
fake_recompute_map[fake_v] = vector<OperatorBase*>();
if (v != fake_v) {
if (multi_use_count[fake_v] >= 2)
fake_recompute_map[fake_v] = ws->GetRecompute(fake_v);
}
fake_recompute_map[fake_v].push_back(ops_[i]);
for (int k = 0; k < fake_recompute_map[fake_v].size(); k++) {
if (!hash_map.count(v)) hash_map[v] = Set<string>();
string op_name = fake_recompute_map[fake_v][k]->name();
if (!hash_map[v].count(op_name)) {
ws->AddRecompute(v, fake_recompute_map[fake_v][k]);
hash_map[v].insert(op_name);
}
}
}
}
// prepare resources
Tensor* head = ws->CreateTensor("_t_mirrow_stage_head");
head->Reshape(vector<TIndex>(1, WORKSPACE_MAX_CORRUPTED_SIZE));
Tensor* recompute_flag = ws->CreateTensor("_t_global_recompute_flag");
recompute_flag->Reshape(vector<TIndex>(1, 1));
recompute_flag->mutable_data<bool, CPUContext>()[0] = false;
for (int i = 0; i < WORKSPACE_MAX_CORRUPTED_SIZE; i++) {
string name = "_t_mirrow_stage_buffer_" + dragon_cast<string, int>(i);
Tensor* buffer = ws->CreateTensor(name);
head->mutable_data<string, CPUContext>()[i] = "";
}
}
Graph::Graph(const GraphDef& graph_def, Workspace* ws) Graph::Graph(const GraphDef& graph_def, Workspace* ws)
: GraphBase(graph_def, ws) { : GraphBase(graph_def, ws) {
GraphDef optimized_graph; GraphDef optimized_graph;
...@@ -297,6 +369,9 @@ Graph::Graph(const GraphDef& graph_def, Workspace* ws) ...@@ -297,6 +369,9 @@ Graph::Graph(const GraphDef& graph_def, Workspace* ws)
// create // create
Create(optimized_graph, ws); Create(optimized_graph, ws);
// recomputing-aware
RecomputingAware(optimized_graph, ws);
} }
bool Graph::Run(const string& include, const string& exclude) { bool Graph::Run(const string& include, const string& exclude) {
......
...@@ -59,6 +59,91 @@ Gradient MakeGradientForOp(const OperatorDef& def, const vector<string>& g_outpu ...@@ -59,6 +59,91 @@ Gradient MakeGradientForOp(const OperatorDef& def, const vector<string>& g_outpu
return grad; return grad;
} }
template <class Context>
void Operator<Context>::ElimateCorruption() {
Set<string> all_heads;
queue<int> safe_heads;
Tensor* head = ws()->GetTensor("_t_mirrow_stage_head");
string* head_data = head->mutable_data<string, CPUContext>();
for (int i = 0; i < head->count(); i++) all_heads.insert(head_data[i]);
// sub-graph run
for (int i = 0; i < InputSize(); i++) {
if (input(i).is_corrupted()) {
if (all_heads.count(input(i).name())) continue;
LOG(DEBUG) << "Tensor(" << input(i).name() << ") is corrupted, recompute... ";
Tensor* recompute_flag = ws()->GetTensor("_t_global_recompute_flag");
vector<OperatorBase*> list = ws()->GetRecompute(input(i).name());
recompute_flag->mutable_data<bool, CPUContext>()[0] = true;
for (int j = 0; j < list.size(); j++) list[j]->Run();
recompute_flag->mutable_data<bool, CPUContext>()[0] = false;
}
}
// check available head
all_heads.clear();
for (int i = 0; i < head->count(); i++) {
bool safe = true;
for (int j = 0; j < InputSize(); j++)
if (head_data[i] == input(j).name()) safe = false;
if (safe) safe_heads.push(i);
all_heads.insert(head_data[i]);
}
// pre-process
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->is_corrupted()) {
bool inplace_flag = false;
for (int j = 0; j < InputSize(); j++)
if (output(i)->name() == input(j).name()) inplace_flag = true;
if (inplace_flag || all_heads.count(output(i)->name())) continue; // skip to use new buffer
CHECK(!safe_heads.empty())
<< "\nat most (" << safe_heads.size() << " [safe] / "
<< all_heads.size() << " [total] can be used for corrupted output in "
<< "(" << name() << ", " << type() << "), "
<< "\nadd WORKSPACE_MAX_CORRUPTED_SIZE for more powerful mirrow stage ?";
int idx = safe_heads.front();
safe_heads.pop();
Tensor* buffer = ws()->GetTensor("_t_mirrow_stage_buffer_" + dragon_cast<string, int>(idx));
output(i)->Move(buffer->memory());
head_data[idx] = output(i)->name();
}
}
}
template <class Context>
void Operator<Context>::ShareGradient() {
// TODO(PhyscalX): we preset input(-1)->output(0) to share
if (output(0)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(0)->Replace(*dX);
}
}
template <class Context>
void Operator<Context>::MakeResource() {
ElimateCorruption();
if (allow_share_grads_) ShareGradient();
}
template <class Context>
void Operator<Context>::CleanResource() {
// post-process for mirrow stage
Map<string, int> head_to_idx;
Tensor* head = ws()->GetTensor("_t_mirrow_stage_head");
string* head_data = head->mutable_data<string, CPUContext>();
for (int i = 0; i < head->count(); i++) head_to_idx[head_data[i]] = i;
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->is_corrupted() && head_to_idx.count(output(i)->name())) {
string used = "_t_mirrow_stage_buffer_" + dragon_cast<string, int>(head_to_idx[output(i)->name()]);
Tensor* buffer = ws()->GetTensor(used);
if (output(i)->memory() != buffer->memory()) buffer->Move(output(i)->memory());
}
}
if (allow_share_grads_) {
// TODO(PhyscalX): we preset input(-1)->output(0) to share
Tensor* dY = &input(-1);
ws()->ReleaseBuffer(dY, "Grad");
}
}
DEFINE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*); DEFINE_REGISTRY(CPUOperatorRegistry, OperatorBase,const OperatorDef&, Workspace*);
DEFINE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*); DEFINE_REGISTRY(CUDAOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
DEFINE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*); DEFINE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Workspace*);
...@@ -94,4 +179,13 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints) ...@@ -94,4 +179,13 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints)
INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings) INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT #undef INSTANTIATE_GET_REPEATED_ARGUMENT
template void Operator<CPUContext>::ElimateCorruption();
template void Operator<CUDAContext>::ElimateCorruption();
template void Operator<CPUContext>::ShareGradient();
template void Operator<CUDAContext>::ShareGradient();
template void Operator<CPUContext>::MakeResource();
template void Operator<CUDAContext>::MakeResource();
template void Operator<CPUContext>::CleanResource();
template void Operator<CUDAContext>::CleanResource();
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -13,4 +13,14 @@ GraphBase* Workspace::CreateGraph(const GraphDef& graph_def) { ...@@ -13,4 +13,14 @@ GraphBase* Workspace::CreateGraph(const GraphDef& graph_def) {
return graph_map_[graph_def.name()].get(); return graph_map_[graph_def.name()].get();
} }
Workspace::~Workspace() {
for (int i = 0; i < WORKSPACE_MAX_CORRUPTED_SIZE; i++) {
string name = "_t_mirrow_stage_buffer_" + dragon_cast<string, int>(i);
if (HasTensor(name)) {
MixedMemory* mem = GetTensor(name)->memory();
if (mem != nullptr) delete mem;
}
}
}
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -70,8 +70,9 @@ void DropoutGradientOp<Context>::RunOnDevice() { ...@@ -70,8 +70,9 @@ void DropoutGradientOp<Context>::RunOnDevice() {
} }
template <class Context> template <class Context>
void DropoutGradientOp<Context>::ClearAfterRun() { void DropoutGradientOp<Context>::CleanResource() {
ws()->ReleaseBuffer(mask, true); Operator<Context>::CleanResource();
ws()->ReleaseBuffer(mask, "Common", true);
} }
DEPLOY_CPU(DropoutGradient); DEPLOY_CPU(DropoutGradient);
...@@ -81,7 +82,7 @@ DEPLOY_CUDA(DropoutGradient); ...@@ -81,7 +82,7 @@ DEPLOY_CUDA(DropoutGradient);
OPERATOR_SCHEMA(DropoutGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } }); OPERATOR_SCHEMA(DropoutGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } });
class GetDropoutGradient final : public GradientMakerBase { class GetDropoutGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetDropoutGradient); GRADIENT_MAKER_CTOR(GetDropoutGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
...@@ -92,4 +93,3 @@ public: ...@@ -92,4 +93,3 @@ public:
REGISTER_GRADIENT(Dropout, GetDropoutGradient); REGISTER_GRADIENT(Dropout, GetDropoutGradient);
} // namepsace dragon } // namepsace dragon
\ No newline at end of file
...@@ -48,7 +48,7 @@ DEPLOY_CUDA(ReluGradient); ...@@ -48,7 +48,7 @@ DEPLOY_CUDA(ReluGradient);
OPERATOR_SCHEMA(ReluGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 }}); OPERATOR_SCHEMA(ReluGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 }});
class GetReluGradient final : public GradientMakerBase { class GetReluGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetReluGradient); GRADIENT_MAKER_CTOR(GetReluGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
...@@ -48,7 +48,7 @@ DEPLOY_CUDA(SigmoidGradient); ...@@ -48,7 +48,7 @@ DEPLOY_CUDA(SigmoidGradient);
OPERATOR_SCHEMA(SigmoidGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } }); OPERATOR_SCHEMA(SigmoidGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } });
class GetSigmoidGradient final : public GradientMakerBase { class GetSigmoidGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetSigmoidGradient); GRADIENT_MAKER_CTOR(GetSigmoidGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
...@@ -71,7 +71,7 @@ DEPLOY_CUDA(SoftmaxGradient); ...@@ -71,7 +71,7 @@ DEPLOY_CUDA(SoftmaxGradient);
OPERATOR_SCHEMA(SoftmaxGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } }); OPERATOR_SCHEMA(SoftmaxGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } });
class GetSoftmaxGradient final : public GradientMakerBase { class GetSoftmaxGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetSoftmaxGradient); GRADIENT_MAKER_CTOR(GetSoftmaxGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
...@@ -48,7 +48,7 @@ DEPLOY_CUDA(TanhGradient); ...@@ -48,7 +48,7 @@ DEPLOY_CUDA(TanhGradient);
OPERATOR_SCHEMA(TanhGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } }); OPERATOR_SCHEMA(TanhGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } });
class GetTanhGradient final : public GradientMakerBase { class GetTanhGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetTanhGradient); GRADIENT_MAKER_CTOR(GetTanhGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
...@@ -59,4 +59,3 @@ public: ...@@ -59,4 +59,3 @@ public:
REGISTER_GRADIENT(Tanh, GetTanhGradient); REGISTER_GRADIENT(Tanh, GetTanhGradient);
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -160,22 +160,16 @@ void AddGradientOp<Context>::RunOnDevice() { ...@@ -160,22 +160,16 @@ void AddGradientOp<Context>::RunOnDevice() {
} }
template <class Context> template <class Context>
void AddGradientOp<Context>::ShareBeforeRun() { void AddGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) { for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") { if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer(); Tensor* dX = ws()->GetBuffer("Grad");
if (dX != nullptr) output(i)->Replace(*dX); output(i)->Replace(*dX);
break; break;
} }
} }
} }
template <class Context>
void AddGradientOp<Context>::ClearAfterRun() {
Tensor* dY = &input(-1);
ws()->ReleaseBuffer(dY);
}
DEPLOY_CPU(AddGradient); DEPLOY_CPU(AddGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(AddGradient); DEPLOY_CUDA(AddGradient);
...@@ -183,7 +177,7 @@ DEPLOY_CUDA(AddGradient); ...@@ -183,7 +177,7 @@ DEPLOY_CUDA(AddGradient);
OPERATOR_SCHEMA(AddGradient).NumInputs(2).NumOutputs(2); OPERATOR_SCHEMA(AddGradient).NumInputs(2).NumOutputs(2);
class GetAddGradient : public GradientMakerBase { class GetAddGradient : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetAddGradient); GRADIENT_MAKER_CTOR(GetAddGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
...@@ -95,18 +95,6 @@ void BiasAddGradientOp<Context>::RunOnDevice() { ...@@ -95,18 +95,6 @@ void BiasAddGradientOp<Context>::RunOnDevice() {
} }
} }
template <class Context>
void BiasAddGradientOp<Context>::ShareBeforeRun() {
Tensor* dX = ws()->GetBuffer();
if (dX != nullptr) output(0)->Replace(*dX);
}
template <class Context>
void BiasAddGradientOp<Context>::ClearAfterRun() {
Tensor* dY = &input(-1);
ws()->ReleaseBuffer(dY);
}
DEPLOY_CPU(BiasAddGradient); DEPLOY_CPU(BiasAddGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(BiasAddGradient); DEPLOY_CUDA(BiasAddGradient);
...@@ -114,7 +102,7 @@ DEPLOY_CUDA(BiasAddGradient); ...@@ -114,7 +102,7 @@ DEPLOY_CUDA(BiasAddGradient);
OPERATOR_SCHEMA(BiasAddGradient).NumInputs(3).NumOutputs(2); OPERATOR_SCHEMA(BiasAddGradient).NumInputs(3).NumOutputs(2);
class GetBiasAddGradient final : public GradientMakerBase { class GetBiasAddGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetBiasAddGradient); GRADIENT_MAKER_CTOR(GetBiasAddGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
...@@ -45,18 +45,6 @@ void ClipGradientOp<Context>::RunOnDevice() { ...@@ -45,18 +45,6 @@ void ClipGradientOp<Context>::RunOnDevice() {
else LOG(FATAL) << "unsupported input types."; else LOG(FATAL) << "unsupported input types.";
} }
template <class Context>
void ClipGradientOp<Context>::ShareBeforeRun() {
Tensor* dX = ws()->GetBuffer();
if (dX != nullptr) output(0)->Replace(*dX);
}
template <class Context>
void ClipGradientOp<Context>::ClearAfterRun() {
Tensor* dY = &input(-1);
ws()->ReleaseBuffer(dY);
}
DEPLOY_CPU(ClipGradient); DEPLOY_CPU(ClipGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(ClipGradient); DEPLOY_CUDA(ClipGradient);
...@@ -64,7 +52,7 @@ DEPLOY_CUDA(ClipGradient); ...@@ -64,7 +52,7 @@ DEPLOY_CUDA(ClipGradient);
OPERATOR_SCHEMA(ClipGradient).NumInputs(2).NumOutputs(1); OPERATOR_SCHEMA(ClipGradient).NumInputs(2).NumOutputs(1);
class GetClipGradient final : public GradientMakerBase { class GetClipGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetClipGradient); GRADIENT_MAKER_CTOR(GetClipGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
...@@ -191,22 +191,16 @@ void DivGradientOp<Context>::RunOnDevice() { ...@@ -191,22 +191,16 @@ void DivGradientOp<Context>::RunOnDevice() {
} }
template <class Context> template <class Context>
void DivGradientOp<Context>::ShareBeforeRun() { void DivGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) { for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") { if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer(); Tensor* dX = ws()->GetBuffer("Grad");
if (dX != nullptr) output(i)->Replace(*dX); output(i)->Replace(*dX);
break; break;
} }
} }
} }
template <class Context>
void DivGradientOp<Context>::ClearAfterRun() {
Tensor* dY = &input(-1);
ws()->ReleaseBuffer(dY);
}
DEPLOY_CPU(DivGradient); DEPLOY_CPU(DivGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(DivGradient); DEPLOY_CUDA(DivGradient);
...@@ -214,7 +208,7 @@ DEPLOY_CUDA(DivGradient); ...@@ -214,7 +208,7 @@ DEPLOY_CUDA(DivGradient);
OPERATOR_SCHEMA(DivGradient).NumInputs(3).NumOutputs(2); OPERATOR_SCHEMA(DivGradient).NumInputs(3).NumOutputs(2);
class GetDivGradient final : public GradientMakerBase { class GetDivGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetDivGradient); GRADIENT_MAKER_CTOR(GetDivGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
#include "operators/arithmetic/dot_op.h" #include "operators/arithmetic/dot_op.h"
#include "core/workspace.h"
#include "utils/math_functions.h" #include "utils/math_functions.h"
namespace dragon { namespace dragon {
...@@ -169,6 +170,17 @@ void DotGradientOp<Context>::RunOnDevice() { ...@@ -169,6 +170,17 @@ void DotGradientOp<Context>::RunOnDevice() {
} }
} }
template <class Context>
void DotGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
break;
}
}
}
DEPLOY_CPU(DotGradient); DEPLOY_CPU(DotGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(DotGradient); DEPLOY_CUDA(DotGradient);
...@@ -176,7 +188,7 @@ DEPLOY_CUDA(DotGradient); ...@@ -176,7 +188,7 @@ DEPLOY_CUDA(DotGradient);
OPERATOR_SCHEMA(DotGradient).NumInputs(3).NumOutputs(2); OPERATOR_SCHEMA(DotGradient).NumInputs(3).NumOutputs(2);
class GetDotGradient final : public GradientMakerBase { class GetDotGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetDotGradient); GRADIENT_MAKER_CTOR(GetDotGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
...@@ -118,22 +118,16 @@ void EltwiseGradientOp<Context>::RunOnDevice() { ...@@ -118,22 +118,16 @@ void EltwiseGradientOp<Context>::RunOnDevice() {
} }
template <class Context> template <class Context>
void EltwiseGradientOp<Context>::ShareBeforeRun() { void EltwiseGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) { for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") { if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer(); Tensor* dX = ws()->GetBuffer("Grad");
if (dX != nullptr) output(i)->Replace(*dX); output(i)->Replace(*dX);
break; break;
} }
} }
} }
template <class Context>
void EltwiseGradientOp<Context>::ClearAfterRun() {
Tensor* dY = &input(-1);
ws()->ReleaseBuffer(dY);
}
DEPLOY_CPU(EltwiseGradient); DEPLOY_CPU(EltwiseGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(EltwiseGradient); DEPLOY_CUDA(EltwiseGradient);
...@@ -141,7 +135,7 @@ DEPLOY_CUDA(EltwiseGradient); ...@@ -141,7 +135,7 @@ DEPLOY_CUDA(EltwiseGradient);
OPERATOR_SCHEMA(EltwiseGradient).NumInputs(3, INT_MAX).NumOutputs(2, INT_MAX); OPERATOR_SCHEMA(EltwiseGradient).NumInputs(3, INT_MAX).NumOutputs(2, INT_MAX);
class GetEltwiseGradient final : public GradientMakerBase { class GetEltwiseGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetEltwiseGradient); GRADIENT_MAKER_CTOR(GetEltwiseGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
vector<string> inputs, outputs; vector<string> inputs, outputs;
......
...@@ -41,19 +41,6 @@ void ExpGradientOp<Context>::RunOnDevice() { ...@@ -41,19 +41,6 @@ void ExpGradientOp<Context>::RunOnDevice() {
else LOG(FATAL) << "unsupported input types."; else LOG(FATAL) << "unsupported input types.";
} }
template <class Context>
void ExpGradientOp<Context>::ShareBeforeRun() {
Tensor* dX = ws()->GetBuffer();
if (dX != nullptr) output(0)->Replace(*dX);
}
template <class Context>
void ExpGradientOp<Context>::ClearAfterRun() {
Tensor* dY = &input(-1);
ws()->ReleaseBuffer(dY);
}
DEPLOY_CPU(ExpGradient); DEPLOY_CPU(ExpGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(ExpGradient); DEPLOY_CUDA(ExpGradient);
...@@ -61,7 +48,7 @@ DEPLOY_CUDA(ExpGradient); ...@@ -61,7 +48,7 @@ DEPLOY_CUDA(ExpGradient);
OPERATOR_SCHEMA(ExpGradient).NumInputs(2).NumOutputs(1); OPERATOR_SCHEMA(ExpGradient).NumInputs(2).NumOutputs(1);
class GetExpGradient final : public GradientMakerBase { class GetExpGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetExpGradient); GRADIENT_MAKER_CTOR(GetExpGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
...@@ -61,18 +61,6 @@ void GramMatrixGradientOp<Context>::RunOnDevice() { ...@@ -61,18 +61,6 @@ void GramMatrixGradientOp<Context>::RunOnDevice() {
else LOG(FATAL) << "unsupported input types."; else LOG(FATAL) << "unsupported input types.";
} }
template <class Context>
void GramMatrixGradientOp<Context>::ShareBeforeRun() {
Tensor* dX = ws()->GetBuffer();
if (dX != nullptr) output(0)->Replace(*dX);
}
template <class Context>
void GramMatrixGradientOp<Context>::ClearAfterRun() {
Tensor* dY = &input(-1);
ws()->ReleaseBuffer(dY);
}
DEPLOY_CPU(GramMatrixGradient); DEPLOY_CPU(GramMatrixGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(GramMatrixGradient); DEPLOY_CUDA(GramMatrixGradient);
...@@ -80,7 +68,7 @@ DEPLOY_CUDA(GramMatrixGradient); ...@@ -80,7 +68,7 @@ DEPLOY_CUDA(GramMatrixGradient);
OPERATOR_SCHEMA(GramMatrixGradient).NumInputs(2).NumOutputs(1); OPERATOR_SCHEMA(GramMatrixGradient).NumInputs(2).NumOutputs(1);
class GetGramMatrixGradient final : public GradientMakerBase { class GetGramMatrixGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetGramMatrixGradient); GRADIENT_MAKER_CTOR(GetGramMatrixGradient);
vector<OperatorDef> MakeDefs() override{ vector<OperatorDef> MakeDefs() override{
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
...@@ -120,20 +120,6 @@ void InnerProductGradientOp<Context>::RunOnDevice() { ...@@ -120,20 +120,6 @@ void InnerProductGradientOp<Context>::RunOnDevice() {
else LOG(FATAL) << "unsupported input types."; else LOG(FATAL) << "unsupported input types.";
} }
template <class Context>
void InnerProductGradientOp<Context>::ShareBeforeRun() {
if (output(0)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer();
if (dX != nullptr) output(0)->Replace(*dX);
}
}
template <class Context>
void InnerProductGradientOp<Context>::ClearAfterRun() {
Tensor* dY = &input(-1);
ws()->ReleaseBuffer(dY);
}
DEPLOY_CPU(InnerProductGradient); DEPLOY_CPU(InnerProductGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(InnerProductGradient); DEPLOY_CUDA(InnerProductGradient);
...@@ -141,7 +127,7 @@ DEPLOY_CUDA(InnerProductGradient); ...@@ -141,7 +127,7 @@ DEPLOY_CUDA(InnerProductGradient);
OPERATOR_SCHEMA(InnerProductGradient).NumInputs(3).NumOutputs(3); OPERATOR_SCHEMA(InnerProductGradient).NumInputs(3).NumOutputs(3);
class GetInnerProductGradient : public GradientMakerBase { class GetInnerProductGradient : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetInnerProductGradient); GRADIENT_MAKER_CTOR(GetInnerProductGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!