Commit 6f2751b1 by Ting PAN

Refactor Norm Module

1 parent 5bd1f6b5
Showing with 640 additions and 491 deletions
......@@ -97,6 +97,7 @@ link_directories(${UINX_CUDNN_LIBS})
# ---[ Install
set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR} CACHE STRING "set install prefix" FORCE)
set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_RPATH} ${3RDPARTY_LIBS})
# ---[ defines
if (WITH_PYTHON3)
......
......@@ -26,19 +26,8 @@ class OperatorBase {
public:
OperatorBase(const OperatorDef& op_def, Workspace* ws);
inline Tensor& input(int idx) {
CHECK_LT(idx, (int)inputs_.size());
CHECK_GE(idx, -(int)inputs_.size());
if (idx >= 0) return *inputs_[idx];
else return *inputs_[idx + inputs_.size()];
}
inline Tensor* output(int idx) {
CHECK_LT(idx, (int)outputs_.size());
CHECK_GE(idx, -(int)outputs_.size());
if (idx >= 0) return outputs_[idx];
else return outputs_[idx + outputs_.size()];
}
Tensor& input(int idx);
Tensor* output(int idx);
inline size_t InputSize() { return inputs_.size(); }
inline size_t OutputSize() { return outputs_.size(); }
......@@ -46,7 +35,6 @@ class OperatorBase {
inline void SwitchToPhase(const string& phase) { this->phase_ = phase; }
virtual void Run() { NOT_IMPLEMENTED; }
inline const string& name() const { return op_def_.name(); }
inline const string& type() const { return op_def_.type(); }
inline const string& phase() const { return phase_; }
......@@ -171,7 +159,7 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, const OperatorDef&, Worksp
}
#define INIT_MULTIPLIER(ptr_tensor, size) { \
ptr_tensor = ws()->CreateTensor("_t_multiplier"); \
ptr_tensor = ws()->CreateTensor("/share/multiplier"); \
if (size > ptr_tensor->count()) { \
ptr_tensor->Reshape(vector<TIndex>(1, size)); \
math::Set<T, Context>(size, dragon_cast<T, float>(1.0f), \
......
......@@ -20,8 +20,8 @@ typedef size_t TSize;
class Tensor {
public:
Tensor() {}
Tensor(const string& name) : name_(name) {}
Tensor() {}
Tensor(const string& name) : name_(name) {}
void Reshape(const vector<TIndex>& dims) {
dims_ = dims;
......@@ -153,6 +153,7 @@ class Tensor {
void* data_ptr;
mutable_data_ptr<Context>(&data_ptr);
if (meta_ == meta && data_ptr) return data_ptr;
if (meta_ != meta && data_ptr && !own_mem_) delete ex_memory_;
meta_ = meta;
CHECK_GT(size_, 0);
if (own_mem_) memory_.reset(new MixedMemory(meta, size_* meta_.itemsize()));
......@@ -196,14 +197,6 @@ class Tensor {
capacity_ = other.capacity_;
}
inline void Replace(const Tensor& other) {
memory_ = other.memory_;
meta_ = other.meta_;
capacity_ = other.capacity_;
size_ = other.size_;
dims_ = other.dims_;
}
inline void Move(MixedMemory* mem) {
if (mem != nullptr) ex_memory_ = mem;
else ex_memory_ = new MixedMemory(TypeMeta::Make<float>(), 4);
......
......@@ -26,17 +26,18 @@ class Workspace {
typedef Map<string, unique_ptr<GraphBase> > GraphMap;
typedef Map<string, TensorFiller> FillerMap;
typedef Map<string, string> RenameMap;
typedef Map<string, string> AvatarMap;
Workspace(const string& name) : name_(name) { init(); }
Workspace(const string& name) : name_(name) { Init(); }
~Workspace();
void init() {
void Init() {
CreateTensor("ignore");
CreateBuffer("Common", WORKSPACE_COMMON_BUFFER_SIZE);
CreateBuffer("Grad", WORKSPACE_GRAD_BUFFER_SIZE);
}
const string& name() { return name_; }
inline const string& name() { return name_; }
/******************** Workspace ********************/
......@@ -55,7 +56,7 @@ class Workspace {
} else { return name; }
}
inline bool HasTensor(const string& name, bool use_remote=true) {
bool HasTensor(const string& name, bool use_remote=true) {
// search local workspace
string query = GetTensorName(name);
bool result = tensor_map_.count(query) > 0;
......@@ -74,7 +75,7 @@ class Workspace {
return tensor_map_[query].get();
}
inline Tensor* GetTensor(const string& name, bool use_remote=true) {
Tensor* GetTensor(const string& name, bool use_remote=true) {
string query = GetTensorName(name);
// search local workspace
if (tensor_map_.count(query) > 0)
......@@ -113,7 +114,7 @@ class Workspace {
tensor_map_[query]->Reset();
}
inline vector<string> GetTensors() {
vector<string> GetTensors() {
vector<string> names;
// search local workspace
for (auto& it : tensor_map_)
......@@ -140,13 +141,28 @@ class Workspace {
else return nullptr;
}
/******************** Avatar ********************/
inline void CreateAvatar(Tensor* orig, Tensor* avatar) {
CHECK(tensor_map_.count(orig->name()) > 0)
<< "\nFailed to create avatar for Tensor(" << orig->name() << ")."
<< "\nAs it has not been registered in the current workspace.";
avatar_map_[orig->name()] = avatar->name();
}
inline Tensor* SearchAvatar(Tensor* orig) {
if (avatar_map_.count(orig->name()) > 0)
return tensor_map_[avatar_map_[orig->name()]].get();
return orig;
}
/******************** Buffer ********************/
inline void CreateBuffer(string category, int num) {
void CreateBuffer(string category, int num) {
CHECK(!buffer_map_.count(category));
buffer_map_[category] = stack<string>();
for (int i = 1; i <= num; i++) {
string name = "_t_" + category + "_buffer_" + dragon_cast<string, int>(i);
string name = "/share/buffer/" + category + "_" + dragon_cast<string, int>(i);
buffer_map_[category].push(name);
CreateTensor(name);
}
......@@ -163,17 +179,18 @@ class Workspace {
return nullptr;
}
inline void ReleaseBuffer(Tensor* tensor,
string category = "Common",
bool enforce = false) {
void ReleaseBuffer(Tensor* tensor,
string category = "Common",
bool enforce = false) {
static Map<string, int> limits = {
{ "Common", WORKSPACE_COMMON_BUFFER_SIZE },
{ "Grad", WORKSPACE_GRAD_BUFFER_SIZE }};
{ "Common", WORKSPACE_COMMON_BUFFER_SIZE },
{ "Grad", WORKSPACE_GRAD_BUFFER_SIZE }
};
if (buffer_map_[category].size() >= limits[category] || enforce) {
// release directly
ReleaseTensor(tensor->name());
} else {
// recover as a available buffer
if (buffer_map_[category].empty())
buffer_map_[category].push(tensor->name());
} else {
buffer_map_[category].push(tensor->name());
}
}
......@@ -182,7 +199,7 @@ class Workspace {
GraphBase* CreateGraph(const GraphDef& meta_graph);
inline bool RunGraph(const string& graph_name,
bool RunGraph(const string& graph_name,
const string& include,
const string& exclude) {
if (!graph_map_.count(graph_name)) {
......@@ -192,7 +209,7 @@ class Workspace {
return graph_map_[graph_name]->Run(include, exclude);
}
inline vector<string> GetGraphs() {
vector<string> GetGraphs() {
vector<string> names;
for (auto& it : graph_map_) names.push_back(it.first);
return names;
......@@ -214,6 +231,7 @@ class Workspace {
GraphMap graph_map_;
FillerMap filler_map_;
RenameMap rename_map_;
AvatarMap avatar_map_;
};
} // namespace dragon
......
......@@ -47,7 +47,6 @@ class DropoutGradientOp final : public Operator<Context> {
}
void RunOnDevice() override;
void CleanResource() override;
template <typename T> void RunWithType();
protected:
......
......@@ -32,7 +32,9 @@ class BiasAddGradientOp final : public Operator<Context> {
public:
BiasAddGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -45,4 +47,4 @@ class BiasAddGradientOp final : public Operator<Context> {
} // namespace dragon
#endif // DRAGON_OPERATORS_ARITHMETIC_BIAS_OP_H_
\ No newline at end of file
#endif // DRAGON_OPERATORS_ARITHMETIC_BIAS_ADD_OP_H_
\ No newline at end of file
......@@ -17,13 +17,14 @@ class ScaleOp : public Operator<Context> {
ScaleOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
num_axes(OperatorBase::GetSingleArg<int>("num_axes", 1)) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
TIndex axis, num_axes, inner_dim;
TIndex axis, start_axis, num_axes;
TIndex inner_dim;
Tensor* bias_multiplier;
};
......@@ -41,7 +42,7 @@ class ScaleGradientOp final : public Operator<Context> {
template <typename T> void RunWithType();
protected:
TIndex axis, num_axes;
TIndex axis, start_axis, num_axes;
TIndex outer_dim, inner_dim, scale_dim, sum_dim, dim;
Tensor* bias_multiplier, *sum_multiplier;
Tensor sum_result;
......
......@@ -16,7 +16,8 @@ class SmoothL1LossOp final : public Operator<Context> {
public:
SmoothL1LossOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
sigma2(OperatorBase::GetSingleArg<float>("sigma", 1.0)) {
sigma2(OperatorBase::GetSingleArg<float>("sigma", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {
sigma2 *= sigma2;
}
......@@ -26,6 +27,7 @@ class SmoothL1LossOp final : public Operator<Context> {
protected:
float sigma2;
Tensor* diff, *error;
string normalization;
};
template <class Context>
......@@ -33,7 +35,8 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
public:
SmoothL1LossGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
sigma2(OperatorBase::GetSingleArg<float>("sigma", 1.0)) {
sigma2(OperatorBase::GetSingleArg<float>("sigma", 1.0)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "BATCH_SIZE")) {
sigma2 *= sigma2;
}
......@@ -43,6 +46,7 @@ class SmoothL1LossGradientOp final : public Operator<Context> {
protected:
float sigma2;
Tensor* diff;
string normalization;
};
} // namespace dragon
......
......@@ -20,7 +20,7 @@ class SoftmaxCrossEntropyOp final : public Operator<Context> {
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {
OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ input(0).name() }),
vector<string>({ "_t_" + anchor() + "_softmax_prob" }));
vector<string>({ "/mnt/" + anchor() + "/softmax_prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (op_def.has_device_option())
softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
......
......@@ -26,7 +26,7 @@ class SparseSoftmaxCrossEntropyOp : public Operator<Context> {
}
OperatorDef softmax_def = MakeOperatorDef("Softmax", "",
vector<string>({ input(0).name() }),
vector<string>({ "_t_" + anchor() + "_softmax_prob" }));
vector<string>({ "/mnt/" + anchor() + "/softmax_prob" }));
softmax_def.add_arg()->CopyFrom(this->arg("axis"));
if (op_def.has_device_option())
softmax_def.mutable_device_option()->CopyFrom(op_def.device_option());
......
......@@ -19,7 +19,7 @@ class SparseSoftmaxFocalLossOp final : public SparseSoftmaxCrossEntropyOp<Contex
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
alpha(OperatorBase::GetSingleArg<float>("alpha", 0.5)),
gamma(OperatorBase::GetSingleArg<float>("gamma", 2.0)),
gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {
pos_alpha = alpha * 2.0;
neg_alpha = (1 - alpha) * 2.0;
......@@ -44,7 +44,7 @@ class SparseSoftmaxFocalLossGradientOp final : public SparseSoftmaxCrossEntropyG
: SparseSoftmaxCrossEntropyGradientOp<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)),
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")),
gamma(OperatorBase::GetSingleArg<float>("gamma", 2.0)),
gamma(OperatorBase::GetSingleArg<float>("gamma", 0.0)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-10))),
neg_id(OperatorBase::GetSingleArg<int>("neg_id", -1)) {}
......
......@@ -27,13 +27,12 @@ class ExpandDimsOp final : public Operator<Context> {
template <class Context>
class ExpandDimsGradientOp final : public Operator<Context> {
public:
ExpandDimsGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
ExpandDimsGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
void RunOnDevice() override;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_
#endif // DRAGON_OPERATORS_NDARRAY_EXPAND_DIMS_OP_H_
\ No newline at end of file
......@@ -32,9 +32,8 @@ template <class Context>
class FlattenGradientOp final : public Operator<Context> {
public:
FlattenGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
: Operator<Context>(op_def, ws) {}
void RunOnDevice() override;
};
......
......@@ -31,9 +31,8 @@ template <class Context>
class ReshapeGradientOp final : public Operator<Context> {
public:
ReshapeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
: Operator<Context>(op_def, ws) {}
void RunOnDevice() override;
};
......
......@@ -16,21 +16,31 @@ class BatchNormOp : public Operator<Context> {
public:
BatchNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
inplace(OperatorBase::GetSingleArg<bool>("inplace", false)) {}
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void RunWithType();
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float momentum, eps;
Tensor mean, num_by_chans;
Tensor* num_multiplier, *spatial_multiplier, *stddev, *var;
TIndex num, channels, spatial_dim, nbychans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* stddev, *var;
TIndex axis, N, C, S, NC, NS;
string data_format, mode;
int use_stats;
bool use_global_stats, inplace, is_recomputing;
bool use_global_stats, is_recomputing;
};
template <class Context>
......@@ -38,51 +48,72 @@ class BatchNormGradientOp final : public Operator<Context> {
public:
BatchNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void RunWithType();
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
Tensor num_by_chans;
Tensor* num_multiplier, *spatial_multiplier, *stddev, *var;
TIndex num, channels, spatial_dim, nbychans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* stddev, *var;
TIndex axis, N, C, S, NC, NS;
string data_format;
int use_stats;
bool use_global_stats;
};
template <class Context>
class BNOp : public Operator<Context> {
class FusedBatchNormOp : public Operator<Context> {
public:
BNOp(const OperatorDef& op_def, Workspace* ws)
FusedBatchNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { }
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void Setup() { NOT_IMPLEMENTED; }
void RunOnDevice() override { NOT_IMPLEMENTED; }
template <typename T> void RunWithType() { NOT_IMPLEMENTED; }
protected:
float momentum, eps;
TIndex axis, N, C, S, NC, NS;
string data_format;
int use_stats;
bool use_global_stats, is_recomputing;
};
template <class Context>
class BNGradientOp : public Operator<Context> {
class FusedBatchNormGradientOp : public Operator<Context> {
public:
BNGradientOp(const OperatorDef& op_def, Workspace* ws)
FusedBatchNormGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { }
void Setup() { NOT_IMPLEMENTED; }
void ShareGradient() override;
void RunOnDevice() override { NOT_IMPLEMENTED; }
template <typename T> void RunWithType() { NOT_IMPLEMENTED; }
protected:
float eps;
TIndex axis, N, C, S, NC, NS;
string data_format;
int use_stats;
bool use_global_stats;
};
......@@ -94,49 +125,54 @@ class BNGradientOp : public Operator<Context> {
#include "utils/cudnn_device.h"
template <class Context>
class CuDNNBNOp final : public BNOp<Context> {
class CuDNNBatchNormOp final : public FusedBatchNormOp<Context> {
public:
CuDNNBNOp(const OperatorDef& op_def, Workspace* ws)
: BNOp<Context>(op_def, ws) {
CuDNNBatchNormOp(const OperatorDef& op_def, Workspace* ws)
: FusedBatchNormOp<Context>(op_def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
this->eps = std::max(this->eps, float(CUDNN_BN_MIN_EPSILON));
}
void Setup();
void RunOnDevice() override;
template <typename T> void SpatialRunWithType();
template <typename T> void PerActivationRunWithType();
template <typename T> void RunWithType();
protected:
cudnnTensorDescriptor_t input_desc, output_desc, bn_desc;
TIndex num, channels, spatial_dim;
cudnnBatchNormMode_t bn_mode;
TIndex N, C;
string data_format;
Tensor* mean, *var;
bool use_global_stats, is_recomputing;
};
template <class Context>
class CuDNNBNGradientOp final : public BNGradientOp<Context> {
class CuDNNBatchNormGradientOp final : public FusedBatchNormGradientOp<Context> {
public:
CuDNNBNGradientOp(const OperatorDef& op_def, Workspace* ws)
: BNGradientOp<Context>(op_def, ws) {
CuDNNBatchNormGradientOp(const OperatorDef& op_def, Workspace* ws)
: FusedBatchNormGradientOp<Context>(op_def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_desc));
this->eps = std::max(this->eps, float(CUDNN_BN_MIN_EPSILON));
}
void Setup();
void RunOnDevice() override;
template <typename T> void SpatialRunWithType();
template <typename T> void PerActivationRunWithType();
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
cudnnTensorDescriptor_t input_desc, output_desc, bn_desc;
cudnnBatchNormMode_t bn_mode;
TIndex N, C, S, NC, NS;
string data_format;
Tensor num_by_chans;
Tensor* num_multiplier, *spatial_multiplier;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* mean, *var, *stddev;
TIndex num, channels, spatial_dim, nbychans;
bool use_global_stats;
};
#endif
......
......@@ -16,27 +16,36 @@ class BatchRenormOp : public Operator<Context> {
public:
BatchRenormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
r_max(OperatorBase::GetSingleArg<float>("r_max", float(3.0))),
d_max(OperatorBase::GetSingleArg<float>("d_max", float(5.0))),
t_delta(OperatorBase::GetSingleArg<float>("t_delta", float(1.0))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
inplace(OperatorBase::GetSingleArg<bool>("inplace", false)),
t_r_max(float(1.0)), t_d_max(float(0.0)), t_val(float(0.0)) {}
t_r_max(float(1.0)), t_d_max(float(0.0)), t_val(float(0.0)),
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void RunWithType();
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float momentum, eps, r_max, d_max, t_delta;
float t_r_max, t_d_max, t_val;
Tensor mean, d, t_h_mean, t_h_var, num_by_chans;
Tensor* num_multiplier, *spatial_multiplier;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* stddev, *r, *var, *x_norm;
TIndex num, channels, spatial_dim, nbychans;
TIndex axis, N, C, S, NC, NS;
string data_format, mode;
int use_stats;
bool use_global_stats, inplace, is_recomputing;
bool use_global_stats, is_recomputing;
};
template <class Context>
......@@ -44,16 +53,27 @@ class BatchRenormGradientOp final : public Operator<Context> {
public:
BatchRenormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
template <typename T> void RunWithType();
protected:
Tensor mean, num_by_chans;
Tensor* num_multiplier, *spatial_multiplier;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* stddev, *r, *var, *x_norm;
TIndex num, channels, spatial_dim, nbychans;
TIndex axis, N, C, S, NC, NS;
string data_format;
int use_stats;
bool use_global_stats;
};
......
......@@ -14,10 +14,16 @@ namespace dragon {
template <class Context>
class InstanceNormOp : public Operator<Context> {
public:
InstanceNormOp(const OperatorDef& op_def, Workspace* ws)
InstanceNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
inplace(OperatorBase::GetSingleArg<bool>("inplace", false)) {}
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -26,22 +32,30 @@ class InstanceNormOp : public Operator<Context> {
float eps;
Tensor mean;
Tensor* spatial_multiplier, *stddev, *var;
TIndex num, channels, spatial_dim, nbychans;
bool inplace;
TIndex axis, N, C, S, NC, CS;
string data_format;
};
template <class Context>
class InstanceNormGradientOp final : public Operator<Context> {
public:
InstanceNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws) {}
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
Tensor* spatial_multiplier, *stddev, *var;
TIndex num, channels, spatial_dim, nbychans;
TIndex axis, N, C, S, NC, CS;
string data_format;
};
} // namespace dragon
......
......@@ -45,7 +45,6 @@ class ROIAlignGradientOp : public Operator<Context> {
}
void RunOnDevice() override;
void CleanResource() override;
template <typename T> void RunWithType();
protected:
......
......@@ -42,7 +42,6 @@ class ROIPoolingGradientOp final : public Operator<Context> {
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {}
void RunOnDevice() override;
void CleanResource() override;
template <typename T> void RunWithType();
protected:
......
......@@ -109,7 +109,7 @@ class DeviceGuard {
#else
#define CUDA_NOT_COMPILED \
LOG(FATAL) << "CUDA was not compiled.";
LOG(FATAL) << "CUDA was not compiled."
#endif // WITH_CUDA
......
......@@ -127,14 +127,14 @@ void Softmax(const int count,
Context* context);
template <typename T, class Context>
void SoftmaxGrad(const int count,
const int classes,
const int outer_dim,
void SoftmaxGrad(const int count,
const int classes,
const int outer_dim,
const int inner_dim,
const T* sum_multiplier,
const T* dy,
const T* y,
T* scale,
const T* sum_multiplier,
const T* dy,
const T* y,
T* scale,
T* dx);
/******************** activation.tanh ********************/
......
......@@ -13,8 +13,8 @@ _ENGINE_SCOPE = ''
SEPARATOR = '/'
_CURRENT_OP_IDX = 0
_SCOPE_TENSOR_IDX = defaultdict(int)
_CURRENT_OP_UID = 0
_CURRENT_TENSOR_UID = 0
__all__ = [
'GetTensorIdx',
......@@ -35,9 +35,9 @@ def GetOperatorIdx():
The operator index.
"""
global _CURRENT_OP_IDX
_CURRENT_OP_IDX = _CURRENT_OP_IDX + 1
return _CURRENT_OP_IDX - 1
global _CURRENT_OP_UID
_CURRENT_OP_UID += 1
return _CURRENT_OP_UID - 1
def GetTensorIdx():
......@@ -49,9 +49,9 @@ def GetTensorIdx():
The tensor index.
"""
global _SCOPE_TENSOR_IDX
_SCOPE_TENSOR_IDX[_TENSOR_SCOPE] += 1
return _SCOPE_TENSOR_IDX[_TENSOR_SCOPE] - 1
global _CURRENT_TENSOR_UID
_CURRENT_TENSOR_UID += 1
return _CURRENT_TENSOR_UID - 1
def GetOperatorName(name=None):
......@@ -104,7 +104,11 @@ class TensorScope(object):
def __init__(self, prefix):
assert isinstance(prefix, type('str')), \
"TensorScope takes in a string as its argument."
self.prefix = prefix + SEPARATOR
if prefix != '':
self.prefix = prefix + SEPARATOR
else:
# avoid duplicated separators
self.prefix = ''
def __enter__(self):
global _TENSOR_SCOPE
......@@ -114,7 +118,13 @@ class TensorScope(object):
def __exit__(self, type, value, traceback):
global _TENSOR_SCOPE
assert _TENSOR_SCOPE.endswith(self.prefix)
_TENSOR_SCOPE = _TENSOR_SCOPE[:-len(self.prefix)]
if self.prefix != '':
_TENSOR_SCOPE = _TENSOR_SCOPE[:-len(self.prefix)]
def get_tensor_scope():
global _TENSOR_SCOPE
return _TENSOR_SCOPE
def set_tensor_scope(name_scope):
......
......@@ -12,6 +12,7 @@ from dragon.core.utils import MakeOperatorDef
from dragon.core.scope import GetOperatorName, GetTensorName
from six.moves import range as xrange
class Tensor(object):
"""
Tensor is generally used to represent a n-dim array,
......@@ -228,8 +229,11 @@ class Tensor(object):
@name.setter
def name(self, value):
from .scope import _TENSOR_SCOPE
if value is None: self._name = _TENSOR_SCOPE + GetTensorName()
else: self._name = _TENSOR_SCOPE + value
if value is None:
# ignore the scope for the name generated by uid
self._name = GetTensorName()
else:
self._name = _TENSOR_SCOPE + value
@property
def grad_wrts(self):
......@@ -837,27 +841,7 @@ class Tensor(object):
>>> [1, 2, 3, 4]
"""
class TensorShape(object):
class Dimension(object):
def __init__(self, dim):
self.dim = dim
def __str__(self):
return 'Dimension({})'.format(self.dim)
def __init__(self, shape):
self.dims = [self.Dimension(dim) for dim in shape]
self.shape = shape
def __str__(self):
dims = [str(dim) for dim in self.dims]
return 'TensorShape([{}])'.format(', '.join(dims))
def as_list(self):
return self.shape
return TensorShape(self.shape) if self.shape is not None else None
raise NotImplementedError('Implemented in <vm.tensorflow.framework.tensor_shape>')
############################################
# #
......
......@@ -20,7 +20,9 @@
\sigma_{B}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{B})^{2} \\
\hat{x}_{i} = \frac{x_{i} - \mu_{B}}{\sqrt{\sigma_{B}^{2} + \epsilon}} \cdot r + d \\ \,
.. |moving_average_function| mathmacro:: \\ \, \\ x_{moving} = Momentum * x_{moving} + x_{stat}
.. |default_moving_average_function| mathmacro:: \\ \, \\ x_{moving} \leftarrow Momentum * x_{moving} + (1 - Momentum) * x_{stat} \\ \,
.. |caffe_moving_average_function| mathmacro:: \\ \, \\ x_{moving} \leftarrow Momentum * x_{moving} + x_{stat} \\ \,
.. _ops.Scale(*args, **kwargs): arithmetic.html#dragon.operators.arithmetic.Scale
......
......@@ -107,15 +107,15 @@ List Brief
Normalization
-------------
=============== ======================================================================
List Brief
=============== ======================================================================
`BatchNorm`_ Batch Normalization, introduced by `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`BatchRenorm`_ Batch Renormalization, introduced by `[Ioffe, 2017] <https://arxiv.org/abs/1702.03275>`_.
`BN`_ Batch Normalization, with scale procedure after normalization.
`InstanceNorm`_ Instance Normalization, introduced by `[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_.
`L2Norm`_ L2 Normalization, introduced by `[Liu et.al, 2015] <https://arxiv.org/abs/1506.04579>`_.
=============== ======================================================================
================== ======================================================================
List Brief
================== ======================================================================
`BatchNorm`_ Batch Normalization, introduced by `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`BatchRenorm`_ Batch Renormalization, introduced by `[Ioffe, 2017] <https://arxiv.org/abs/1702.03275>`_.
`FusedBatchNorm`_ Batch Normalization, with scale procedure after normalization.
`InstanceNorm`_ Instance Normalization, introduced by `[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_.
`L2Norm`_ L2 Normalization, introduced by `[Liu et.al, 2015] <https://arxiv.org/abs/1506.04579>`_.
================== ======================================================================
NDArray
-------
......@@ -244,7 +244,7 @@ List Brief
.. _BatchNorm: operators/norm.html#dragon.operators.norm.BatchNorm
.. _BatchRenorm: operators/norm.html#dragon.operators.norm.BatchRenorm
.. _BN: operators/norm.html#dragon.operators.norm.BN
.. _FusedBatchNorm: operators/norm.html#dragon.operators.norm.FusedBatchNorm
.. _InstanceNorm: operators/norm.html#dragon.operators.norm.InstanceNorm
.. _L2Norm: operators/norm.html#dragon.operators.norm.L2Norm
......
......@@ -170,34 +170,15 @@ Installation - Linux (Distributed, CPU)
**$** Set ``PYTHON_INCLUDE_DIR`` / ``ANACONDA_ROOT_DIR`` and ``NUMPY_ROOT_DIR``
**Step 5:** Set Environment Variables
**$** Create dragon.conf
.. code-block:: shell
sudo vim /etc/ld.so.conf.d/dragon.conf
**$** Append 1 line for ``REPO_ROOT/3rdparty/lib``
.. code-block:: shell
/xyz/Dragon/3rdparty/lib
**$** Rebuild the scanning cache
.. code-block:: shell
sudo ldconfig
**Step 6:** Setup MPI
**Step 5:** Setup MPI
.. code-block:: shell
cd $REPO_ROOT/3rdparty
bash ./setup_mpi.sh
sudo cp openmpi/install/bin/mpirun /usr/bin
**Step 7:** Compile Dragon
**Step 6:** Compile Dragon
**$** Install CMake
......@@ -215,7 +196,7 @@ Installation - Linux (Distributed, CPU)
cmake ..
make install -j16
**Step 8:** Install Dragon
**Step 7:** Install Dragon
.. code-block:: shell
......@@ -275,34 +256,15 @@ Installation - Linux (Distributed, GPU)
**$** OpenMPI can take ``NCCL`` and our ``CUDA-AWARE`` communications at the same time.
**Step 6:** Set Environment Variables
**$** Create dragon.conf
.. code-block:: shell
sudo vim /etc/ld.so.conf.d/dragon.conf
**$** Append 1 line for ``REPO_ROOT/3rdparty/lib``
.. code-block:: shell
/xyz/Dragon/3rdparty/lib
**$** Rebuild the scanning cache
.. code-block:: shell
sudo ldconfig
**Step 7:** Setup MPI
**Step 6:** Setup MPI
.. code-block:: shell
cd $REPO_ROOT/3rdparty
bash ./setup_mpi.sh
sudo cp openmpi/install/bin/mpirun /usr/bin
**Step 8:** Compile Dragon
**Step 7:** Compile Dragon
**$** Install CMake
......@@ -320,7 +282,7 @@ Installation - Linux (Distributed, GPU)
cmake ..
make install -j16
**Step 9:** Install Dragon
**Step 8:** Install Dragon
.. code-block:: shell
......@@ -379,7 +341,7 @@ Add ``REPO_ROOT/3rdparty/bin`` to system environment variables
**$** Open ``DRAGON_ROOT/build/Dragon.sln``
**$** Compile and generate for ``INSTAL`` solution
**$** Compile and generate for ``INSTALL`` solution
**Step 6:** Install Dragon
......
......@@ -165,7 +165,20 @@ def Matmul(inputs, TransA=False, TransB=False, **kwargs):
if inputs[0].shape is not None and \
inputs[1].shape is not None:
pass
if len(inputs[0].shape) < 2 or \
len(inputs[1].shape) < 2:
raise ValueError('The rank of A and B should be at least 2.')
if len(inputs[0].shape) != len(inputs[1].shape):
raise ValueError('Both A and B should have the same number of dimensions.')
M = inputs[0].shape[-1] if TransA else inputs[0].shape[-2]
K1 = inputs[0].shape[-2] if TransA else inputs[0].shape[-1]
K2 = inputs[1].shape[-1] if TransB else inputs[1].shape[-2]
N = inputs[1].shape[-2] if TransB else inputs[1].shape[-1]
if K1 != K2:
raise ValueError('Can not multiply A: ({}, {}} with B: ({}, {})'.format(M, K1, K2, N))
output.shape = inputs[0].shape[:]
output.shape[-2] = M
output.shape[-1] = N
return output
......@@ -412,6 +425,10 @@ def Scale(inputs, axis=1, num_axes=1, **kwargs):
The scale ranges are: |scale_function|
Set ``axis`` to specific the start axis(can be negative).
Set ``num_axes`` to -1 will scale all remained axes.
Parameters
----------
inputs : list of Tensor
......
......@@ -118,7 +118,7 @@ def SoftmaxCrossEntropy(inputs, axis=1, normalization='FULL', **kwargs):
return output
def SmoothL1Loss(inputs, sigma=1.0, **kwargs):
def SmoothL1Loss(inputs, sigma=1.0, normalization='BATCH_SIZE', **kwargs):
"""SmoothL1Loss, introduced by `[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
Parameters
......@@ -127,6 +127,8 @@ def SmoothL1Loss(inputs, sigma=1.0, **kwargs):
The inputs, represent [input, targets, inside_w, outside_w].
sigma : float
The sigma of L1 bound.
normalization : str
The normalization, ``FULL``, ``BATCH_SIZE``, or ``NONE``.
Returns
-------
......@@ -203,7 +205,7 @@ def L2Loss(inputs, normalization='BATCH_SIZE', **kwargs):
def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=(),
alpha=0.5, gamma=2.0, eps=1e-10, neg_id=-1, **kwargs):
alpha=0.5, gamma=0.0, eps=1e-10, neg_id=-1, **kwargs):
"""SoftmaxFocalLoss with sparse labels, introduced by `[Lin et.al, 2017] <https://arxiv.org/abs/1708.02002>`_.
Parameters
......@@ -219,7 +221,7 @@ def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=
alpha : float
The scale factor on the rare class. Default is ``0.5``.
gamma : float
The exponential decay factor on the easy examples. Default is ``2.0``.
The exponential decay factor on the easy examples. Default is ``0.0``.
eps : float
The eps.
neg_id : int
......
......@@ -612,22 +612,24 @@ def Flatten(inputs, axis=0, num_axes=-1, keep_axes=None, **kwargs):
output = Tensor.CreateOperator(nout=1, op_type='Flatten', **arguments)
if inputs.shape is not None:
fake_shape = inputs.shape[:]
fake_shape = [1 if dim is None else dim for dim in fake_shape]
if keep_axes is not None:
if keep_axes > len(inputs.shape):
raise ValueError('The total number of axes is {}, can not keep {}.'
.format(len(inputs.shape), keep_axes))
total_count = np.prod(inputs.shape)
total_count = np.prod(fake_shape)
output.shape = []
for i in xrange(keep_axes - 1):
output.shape.append(inputs.shape[i])
total_count *= inputs.shape[i]
total_count *= fake_shape[i]
if total_count != 1:
output.shape.append(np.long(total_count))
output.shape.append(total_count)
else:
if num_axes == -1: num_axes = len(inputs.shape) - axis
elif num_axes == 0:
raise ValueError('num_axes must > 0 or be -1.')
num_flatten = np.prod(inputs.shape[axis : axis + num_axes])
num_flatten = np.prod(fake_shape[axis : axis + num_axes])
output.shape = inputs.shape[: axis] + [num_flatten] + inputs.shape[axis + num_axes :]
return output
......
......@@ -6,21 +6,28 @@
from . import *
def BatchNorm(inputs, momentum=0.9, eps=1e-3, use_stats=-1, inplace=False, **kwargs):
"""Batch Normalization, introduced by `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_
def BatchNorm(inputs, axis=-1, momentum=0.9, eps=1e-3,
use_stats=-1, mode='DEFAULT', **kwargs):
"""Batch Normalization, introduced by `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
It follows the implementation of `Caffe`_, that scale procedure is moved to `ops.Scale(*args, **kwargs)`_.
The number of inputs vary from ``3`` to ``4`` (``DEFAULT`` or ``CAFFE`` mode).
Parameters
----------
inputs : list of Tensor
The inputs, represent [input, mean, var, factor].
The inputs, represent [input, mean, var] or [input, mean, var, factor].
axis : int
The channel axis.
momentum : float
The momentum of moving average.
eps : float
The eps.
use_stats : int
Whether to use global stats. Default is ``-1`` (Auto).
inplace : boolean
Whether to share input for the output.
mode : str
The moving average mode. ``DEFAULT`` or ``CAFFE``.
Returns
-------
......@@ -29,20 +36,22 @@ def BatchNorm(inputs, momentum=0.9, eps=1e-3, use_stats=-1, inplace=False, **kwa
|batchnorm_function|
The moving average of mean/var, calculated as:
The ``DEFAULT`` moving average of mean/var, calculated as:
|moving_average_function|
|default_moving_average_function|
Notes
-----
This operator follows the implementation of `Caffe`_, without scale after normalization.
The ``CAFFE`` moving average of mean/var, calculated as:
The scale procedure is moved to `ops.Scale(*args, **kwargs)`_.
|caffe_moving_average_function|
"""
CheckInputs(inputs, 4)
CheckInputs(inputs, 3, 4)
arguments = ParseArguments(locals())
if len(inputs) > 3:
if mode != 'CAFFE':
raise ValueError('Only the CAFFE mode will take 4 inputs.')
output = Tensor.CreateOperator(nout=1, op_type='BatchNorm', **arguments)
if inputs[0].shape is not None:
......@@ -51,14 +60,21 @@ def BatchNorm(inputs, momentum=0.9, eps=1e-3, use_stats=-1, inplace=False, **kwa
return output
def BatchRenorm(inputs, momentum=0.9, eps=1e-3, r_max=3.0, d_max=5.0,
t_delta=1.0, use_stats=-1, inplace=False, **kwargs):
"""Batch Renormalization, introduced by `[Ioffe, 2017] <https://arxiv.org/abs/1702.03275>`_
def BatchRenorm(inputs, axis=-1, momentum=0.9, eps=1e-3,
r_max=3.0, d_max=5.0, t_delta=0.001,
use_stats=-1, mode='DEFAULT', **kwargs):
"""Batch Renormalization, introduced by `[Ioffe, 2017] <https://arxiv.org/abs/1702.03275>`_.
It follows the implementation of `Caffe`_, that scale procedure is moved to `ops.Scale(*args, **kwargs)`_.
The number of inputs vary from ``3`` to ``4`` (``DEFAULT`` or ``CAFFE`` mode).
Parameters
----------
inputs : list of Tensor
The inputs, represent [input, mean, var, factor].
axis : int
The channel axis.
momentum : float
The momentum of moving average.
eps : float
......@@ -71,8 +87,8 @@ def BatchRenorm(inputs, momentum=0.9, eps=1e-3, r_max=3.0, d_max=5.0,
The magnitude of incrementing after each iteration.
use_stats : int
Whether to use global stats. Default is ``-1`` (Auto).
inplace : boolean
Whether to share input for the output.
mode : str
The moving average mode. ``DEFAULT`` or ``CAFFE``.
Returns
-------
......@@ -81,20 +97,22 @@ def BatchRenorm(inputs, momentum=0.9, eps=1e-3, r_max=3.0, d_max=5.0,
|batchrenorm_function|
The moving average of mean/var, calculated as:
The ``DEFAULT`` moving average of mean/var, calculated as:
|moving_average_function|
|default_moving_average_function|
Notes
-----
This operator follows the implementation of `Caffe`_, without scale after normalization.
The ``CAFFE`` moving average of mean/var, calculated as:
The scale procedure is moved to `ops.Scale(*args, **kwargs)`_.
|caffe_moving_average_function|
"""
CheckInputs(inputs, 4)
CheckInputs(inputs, 3, 4)
arguments = ParseArguments(locals())
if len(inputs) > 3:
if mode != 'CAFFE':
raise ValueError('Only the CAFFE mode will take 4 inputs.')
output = Tensor.CreateOperator(nout=1, op_type='BatchRenorm', **arguments)
if inputs[0].shape is not None:
......@@ -103,13 +121,15 @@ def BatchRenorm(inputs, momentum=0.9, eps=1e-3, r_max=3.0, d_max=5.0,
return output
def BN(inputs, momentum=0.9, eps=1e-3, use_stats=-1, **kwargs):
def FusedBatchNorm(inputs, axis=-1, momentum=0.9, eps=1e-3, use_stats=-1, **kwargs):
"""Batch Normalization, with scale procedure after normalization.
Parameters
----------
inputs : list of Tensor
The inputs, represent [input, mean, var, scale, bias].
axis : int
The channel axis.
momentum : float
The momentum of moving average.
eps : float
......@@ -126,13 +146,13 @@ def BN(inputs, momentum=0.9, eps=1e-3, use_stats=-1, **kwargs):
The moving average of mean/var, calculated as:
|moving_average_function|
|default_moving_average_function|
"""
CheckInputs(inputs, 5)
arguments = ParseArguments(locals())
output = Tensor.CreateOperator(nout=1, op_type='BN', **arguments)
output = Tensor.CreateOperator(nout=1, op_type='FusedBatchNorm', **arguments)
if inputs[0].shape is not None:
output.shape = inputs[0].shape[:]
......@@ -140,17 +160,17 @@ def BN(inputs, momentum=0.9, eps=1e-3, use_stats=-1, **kwargs):
return output
def InstanceNorm(inputs, eps=1e-3, inplace=False, **kwargs):
def InstanceNorm(inputs, axis=-1, eps=1e-3, **kwargs):
"""Instance Normalization, introduced by `[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_
Parameters
----------
inputs : Tensor
The input tensor.
axis : int
The channel axis.
eps : float
The eps.
inplace : boolean
Whether to share input for the output.
Returns
-------
......
......@@ -74,7 +74,9 @@ def Conv2d(inputs, num_output, kernel_size,
if inputs[0].shape is not None:
output.shape = inputs[0].shape[:]
output.shape[1] = num_output
channel_axis = 1 if data_format == 'NCHW' else -1
spatial_axis = 2 if data_format == 'NCHW' else 1
output.shape[channel_axis] = num_output
for i in xrange(2):
k = arguments['kernel_size'][i] if i < len(arguments['kernel_size']) \
else arguments['kernel_size'][-1]
......@@ -85,7 +87,12 @@ def Conv2d(inputs, num_output, kernel_size,
d = arguments['dilation'][i] if i < len(arguments['dilation']) \
else arguments['dilation'][-1]
dk = d * (k - 1) + 1
output.shape[i + 2] = (output.shape[i + 2] + 2 * p - dk) / s + 1
dp = 2 * p
if padding == 'SAME':
input_size = output.shape[i + spatial_axis]
output_size = (input_size + s - 1) / float(s)
dp = int(max(0, (output_size - 1) * s + k - input_size))
output.shape[i + spatial_axis] = (output.shape[i + spatial_axis] + dp - dk) / s + 1
return output
......@@ -226,7 +233,7 @@ def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID',
if inputs.shape is not None:
output.shape = inputs.shape[:]
axis = 2 if data_format == 'NCHW' else 1
spatial_axis = 2 if data_format == 'NCHW' else 1
for i in xrange(2):
k = arguments['kernel_size'][i] if i < len(arguments['kernel_size']) \
else arguments['kernel_size'][-1]
......@@ -234,17 +241,18 @@ def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID',
else arguments['stride'][-1]
p = arguments['pad'][i] if i < len(arguments['pad']) \
else arguments['pad'][-1]
if padding == 'SAME':
input_size = output.shape[i + axis]
output_size = (input_size + s - 1) / float(s)
padding_needed = max(0, (output_size - 1) * s + k - input_size)
p_l = padding_needed / 2
p_r = padding_needed - p_l
p = min(p_l, p_r)
if not global_pooling:
output.shape[i + axis] = int(math.ceil(float(output.shape[i + axis] + 2 * p - k) / s) + 1)
if padding != 'SAME':
input_size = output.shape[i + spatial_axis]
output_size = int(math.ceil(float(output.shape[i + spatial_axis] + 2 * p - k) / s) + 1)
if ((output_size - 1) * s >= input_size + p):
output_size = output_size - 1
output.shape[i + spatial_axis] = output_size
else:
output.shape[i + spatial_axis] = \
int((output.shape[i + spatial_axis] + s - 1) / float(s))
else:
output.shape[i + axis] = 1
output.shape[i + spatial_axis] = 1
return output
......
......@@ -87,7 +87,7 @@ GramMatrix = math.GramMatrix
# normalization
BatchNorm = norm.BatchNorm
BatchRenorm = norm.BatchRenorm
BN = norm.BN
FusedBatchNorm = norm.FusedBatchNorm
InstanceNorm = norm.InstanceNorm
L2Norm = norm.L2Norm
......
......@@ -329,7 +329,9 @@ class BatchNormLayer(Layer):
self._param = {'use_stats': int(param.use_global_stats)
if param.HasField('use_global_stats') else -1,
'momentum': param.moving_average_fraction,
'eps': param.eps}
'eps': param.eps,
'axis': 1,
'mode': 'CAFFE'}
# mean, var, factor are set to 0 in order to do statistics
mean = Tensor(LayerParameter.name + '@param0').Constant(value=0.0)
var = Tensor(LayerParameter.name + '@param1').Constant(value=0.0)
......@@ -373,7 +375,9 @@ class BatchRenormLayer(Layer):
'eps': param.eps,
'r_max': float(param.r_max),
'd_max': float(param.d_max),
't_delta': float(param.t_delta)}
't_delta': float(param.t_delta),
'axis': 1,
'mode': 'CAFFE'}
mean = Tensor(LayerParameter.name + '@param0').Constant(value=0.0)
var = Tensor(LayerParameter.name + '@param1').Constant(value=0.0)
factor = Tensor(LayerParameter.name + '@param2').Constant(value=0.0)
......@@ -394,6 +398,7 @@ class InstanceNormLayer(Layer):
"""
def __init__(self, LayerParameter):
super(InstanceNormLayer, self).__init__(LayerParameter)
self._param = {'axis': 1}
def Setup(self, bottom):
super(InstanceNormLayer, self).Setup(bottom)
......@@ -464,7 +469,8 @@ class BNLayer(Layer):
self._param = {'use_stats': int(bn_param.use_global_stats)
if bn_param.HasField('use_global_stats') else -1,
'momentum': bn_param.moving_average_fraction,
'eps': bn_param.eps}
'eps': bn_param.eps,
'axis': 1}
mean = Tensor(LayerParameter.name + '@param0').Constant(value=0.0)
var = Tensor(LayerParameter.name + '@param1').Constant(value=0.0)
scale = Tensor(LayerParameter.name + '@param2')
......@@ -485,7 +491,7 @@ class BNLayer(Layer):
def Setup(self, bottom):
super(BNLayer, self).Setup(bottom)
return ops.BN(bottom + [blob['data'] for blob in self._blobs], **self._param)
return ops.FusedBatchNorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
class NormalizeLayer(Layer):
......
......@@ -20,7 +20,7 @@ class SoftmaxWithLossLayer(Layer):
normalization : NormalizationMode
The normalization. Refer `LossParameter.normalization`_.
normalize : boolean
Wheter to normalize. Refer `LossParameter.normalize`_.
Whether to normalize. Refer `LossParameter.normalize`_.
"""
def __init__(self, LayerParameter):
......@@ -51,16 +51,16 @@ class SigmoidCrossEntropyLossLayer(Layer):
normalization : NormalizationMode
The normalization. Refer `LossParameter.normalization`_.
normalize : boolean
Wheter to normalize. Refer `LossParameter.normalize`_.
Whether to normalize. Refer `LossParameter.normalize`_.
"""
def __init__(self, LayerParameter):
super(SigmoidCrossEntropyLossLayer, self).__init__(LayerParameter)
param = LayerParameter.loss_param
norm_mode = {0: 'FULL', 1: 'FULL', 2: 'BATCH_SIZE', 3: 'NONE'}
normalization = 'FULL'
norm_mode = {0: 'FULL', 1: 'BATCH_SIZE', 2: 'BATCH_SIZE', 3: 'NONE'}
normalization = 'BATCH_SIZE'
if param.HasField('normalize'):
if not param.normalize: normalization = 'BATCH_SIZE'
if param.normalize: normalization = 'FULL'
else: normalization = norm_mode[param.normalization]
self._param = { 'normalization': normalization }
......@@ -78,14 +78,18 @@ class L2LossLayer(Layer):
normalization : NormalizationMode
The normalization. Refer `LossParameter.normalization`_.
normalize : boolean
Wheter to normalize. Refer `LossParameter.normalize`_.
Whether to normalize. Refer `LossParameter.normalize`_.
"""
def __init__(self, LayerParameter):
super(L2LossLayer, self).__init__(LayerParameter)
param = LayerParameter.loss_param
self._param = {'normalize': param.normalize
if param.HasField('normalize') else True}
norm_mode = {0: 'FULL', 1: 'BATCH_SIZE', 2: 'BATCH_SIZE', 3: 'NONE'}
normalization = 'BATCH_SIZE'
if param.HasField('normalize'):
if param.normalize: normalization = 'FULL'
else: normalization = norm_mode[param.normalization]
self._param = {'normalization': normalization}
def Setup(self, bottom):
super(L2LossLayer, self).Setup(bottom)
......@@ -104,13 +108,20 @@ class SmoothL1LossLayer(Layer):
normalization : NormalizationMode
The normalization. Refer `LossParameter.normalization`_.
normalize : boolean
Wheter to normalize. Refer `LossParameter.normalize`_.
Whether to normalize. Refer `LossParameter.normalize`_.
"""
def __init__(self, LayerParameter):
super(SmoothL1LossLayer, self).__init__(LayerParameter)
param = LayerParameter.smooth_l1_loss_param
self._param = {'sigma': float(param.sigma)}
param = LayerParameter.loss_param
smooth_l1_param = LayerParameter.smooth_l1_loss_param
norm_mode = {0: 'FULL', 1: 'BATCH_SIZE', 2: 'BATCH_SIZE', 3: 'NONE'}
normalization = 'BATCH_SIZE'
if param.HasField('normalize'):
if param.normalize: normalization = 'FULL'
else: normalization = norm_mode[param.normalization]
self._param = {'sigma': float(smooth_l1_param.sigma),
'normalization': normalization}
def Setup(self, bottom):
super(SmoothL1LossLayer, self).Setup(bottom)
......@@ -129,11 +140,15 @@ class SoftmaxWithFocalLossLayer(Layer):
alpha : float
The scale on the rare class. Refer `FocalLossParameter.alpha`_.
gamma : float
The exponetial decay. Refer `FocalLossParameter.gamma`_.
The exponential decay. Refer `FocalLossParameter.gamma`_.
eps : float
The eps. Refer `FocalLossParameter.eps`_.
neg_id : int
The negative id. Refer `FocalLossParameter.neg_id`_.
normalization : NormalizationMode
The normalization. Refer `LossParameter.normalization`_.
normalize : boolean
Whether to normalize. Refer `LossParameter.normalize`_.
"""
def __init__(self, LayerParameter):
......@@ -144,7 +159,7 @@ class SoftmaxWithFocalLossLayer(Layer):
norm_mode = {0: 'FULL', 1: 'VALID', 2: 'BATCH_SIZE', 3: 'NONE'}
normalization = 'VALID'
if param.HasField('normalize'):
if not param.normalize: normalization='BATCH_SIZE'
if not param.normalize: normalization = 'BATCH_SIZE'
else: normalization = norm_mode[param.normalization]
self._param = {'axis': softmax_param.axis,
'normalization': normalization,
......
......@@ -1487,7 +1487,7 @@ message BatchRenormParameter {
optional float eps = 3 [default = 1e-3];
optional float r_max = 4 [default = 3.0];
optional float d_max = 5 [default = 5.0];
optional float t_delta = 6 [default = 1.0];
optional float t_delta = 6 [default = 0.001];
}
message DenseConcatParameter {
......@@ -1497,7 +1497,7 @@ message DenseConcatParameter {
message FocalLossParameter {
optional float alpha = 1 [default = 0.5];
optional float gamma = 2 [default = 2.0];
optional float gamma = 2 [default = 0.0];
optional float eps = 3 [default = 1e-10];
optional int32 neg_id = 4 [default = -1];
}
......
......@@ -362,13 +362,13 @@ void Graph::RecomputingAware(const GraphDef& optimized_graph, Workspace* ws) {
// prepare resources
for (auto& ops : ops_) ops->set_recompute_map(recompute_map);
Tensor* head = ws->CreateTensor("_t_mirror_stage_head");
Tensor* head = ws->CreateTensor("/opt/mirror_stage/head");
head->Reshape(vector<TIndex>(1, WORKSPACE_MAX_CORRUPTED_SIZE));
Tensor* recompute_flag = ws->CreateTensor("_t_global_recompute_flag");
Tensor* recompute_flag = ws->CreateTensor("/opt/mirror_stage/recompute_flag");
recompute_flag->Reshape(vector<TIndex>(1, 1));
recompute_flag->mutable_data<bool, CPUContext>()[0] = false;
for (int i = 0; i < WORKSPACE_MAX_CORRUPTED_SIZE; i++) {
string name = "_t_mirror_stage_buffer_" + dragon_cast<string, int>(i);
string name = "/opt/mirror_stage/buffer_" + dragon_cast<string, int>(i);
Tensor* buffer = ws->CreateTensor(name);
head->mutable_data<string, CPUContext>()[i] = "";
}
......
......@@ -88,7 +88,7 @@ void MixedMemory::async_cuda_data(const cudaStream_t& stream) {
MixedMemory::~MixedMemory() {
bool use_cudahost_mem = false;
#ifdef WITH_CUDA_HOST_MEN
#ifdef WITH_CUDA_HOST_MEM
use_cudahost_mem = true;
#endif
if (cpu_ptr_ && !use_cudahost_mem) {
......@@ -112,4 +112,4 @@ void MixedMemory::SwitchToDevice() {
}
}
} // namespace dragon
} // namespace dragon
\ No newline at end of file
......@@ -20,6 +20,19 @@ OperatorBase::OperatorBase(const OperatorDef& op_def, Workspace* ws)
outputs_.push_back(tensor);
}
}
inline Tensor& OperatorBase::input(int idx) {
CHECK_LT(idx, (int)inputs_.size());
CHECK_GE(idx, -(int)inputs_.size());
if (idx >= 0) return *(ws()->SearchAvatar(inputs_[idx]));
else return *(ws()->SearchAvatar(inputs_[idx + inputs_.size()]));
}
inline Tensor* OperatorBase::output(int idx) {
CHECK_LT(idx, (int)outputs_.size());
CHECK_GE(idx, -(int)outputs_.size());
if (idx >= 0) return ws()->SearchAvatar(outputs_[idx]);
else return ws()->SearchAvatar(outputs_[idx + outputs_.size()]);
}
OperatorBase* TryCreateOperator(const string& key, const OperatorDef& op_def, Workspace* ws) {
switch (op_def.device_option().device_type()) {
......@@ -49,11 +62,11 @@ Gradient MakeGradientForOp(const OperatorDef& def, const vector<string>& g_outpu
if (maker.get() == nullptr)
LOG(FATAL) << "Gradient maker for operator " << def.type() << "not implemented.";
Gradient grad = maker->Make();
// copy device option, engine, and arguments if needed.
if (maker->CopyDeviceOption() && def.has_device_option())
// copy device option, engine, and arguments if needed
if (maker->CopyDeviceOption() && def.has_device_option())
for (auto& grad_def : grad.ops)
grad_def.mutable_device_option()->CopyFrom(def.device_option());
// copy arguments if needed.
// copy arguments if needed
if (maker->CopyArguments() && def.arg_size())
for (auto& grad_def : grad.ops) grad_def.mutable_arg()->MergeFrom(def.arg());
return grad;
......@@ -63,7 +76,7 @@ template <class Context>
void Operator<Context>::ElimateCorruption() {
Set<string> all_heads;
queue<int> safe_heads;
Tensor* head = ws()->GetTensor("_t_mirror_stage_head");
Tensor* head = ws()->GetTensor("/opt/mirror_stage/head");
string* head_data = head->mutable_data<string, CPUContext>();
for (int i = 0; i < head->count(); i++) all_heads.insert(head_data[i]);
// sub-graph run
......@@ -71,7 +84,7 @@ void Operator<Context>::ElimateCorruption() {
if (input(i).is_corrupted()) {
if (all_heads.count(input(i).name())) continue;
LOG(DEBUG) << "Tensor(" << input(i).name() << ") is corrupted, recompute... ";
Tensor* recompute_flag = ws()->GetTensor("_t_global_recompute_flag");
Tensor* recompute_flag = ws()->GetTensor("/opt/mirror_stage/recompute_flag");
vector<OperatorBase*>& list = recompute_map()[input(i).name()];
recompute_flag->mutable_data<bool, CPUContext>()[0] = true;
for (int j = 0; j < list.size(); j++) list[j]->Run();
......@@ -101,7 +114,7 @@ void Operator<Context>::ElimateCorruption() {
<< "\nadd WORKSPACE_MAX_CORRUPTED_SIZE for more powerful mirror stage ?";
int idx = safe_heads.front();
safe_heads.pop();
Tensor* buffer = ws()->GetTensor("_t_mirror_stage_buffer_" + dragon_cast<string, int>(idx));
Tensor* buffer = ws()->GetTensor("/opt/mirror_stage/buffer_" + dragon_cast<string, int>(idx));
output(i)->Move(buffer->memory());
head_data[idx] = output(i)->name();
}
......@@ -113,7 +126,7 @@ void Operator<Context>::ShareGradient() {
// TODO(PhyscalX): we preset input(-1)->output(0) to share
if (output(0)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(0)->Replace(*dX);
ws()->CreateAvatar(output(0), dX);
}
}
......@@ -127,12 +140,12 @@ template <class Context>
void Operator<Context>::CleanResource() {
// post-process for mirror stage
Map<string, int> head_to_idx;
Tensor* head = ws()->GetTensor("_t_mirror_stage_head");
Tensor* head = ws()->GetTensor("/opt/mirror_stage/head");
string* head_data = head->mutable_data<string, CPUContext>();
for (int i = 0; i < head->count(); i++) head_to_idx[head_data[i]] = i;
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->is_corrupted() && head_to_idx.count(output(i)->name())) {
string used = "_t_mirror_stage_buffer_" + dragon_cast<string, int>(head_to_idx[output(i)->name()]);
string used = "/opt/mirror_stage/buffer_" + dragon_cast<string, int>(head_to_idx[output(i)->name()]);
Tensor* buffer = ws()->GetTensor(used);
if (output(i)->memory() != buffer->memory()) buffer->Move(output(i)->memory());
}
......
......@@ -16,7 +16,7 @@ GraphBase* Workspace::CreateGraph(const GraphDef& meta_graph) {
Workspace::~Workspace() {
for (int i = 0; i < WORKSPACE_MAX_CORRUPTED_SIZE; i++) {
string name = "_t_mirror_stage_buffer_" + dragon_cast<string, int>(i);
string name = "/opt/mirror_stage/buffer_" + dragon_cast<string, int>(i);
if (tensor_map_.count(name) > 0) {
MixedMemory* mem = tensor_map_[name]->memory();
if (mem != nullptr) delete mem;
......
......@@ -27,7 +27,7 @@ void DropoutOp<Context>::RunWithType() {
template <class Context>
void DropoutOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
mask = ws()->CreateTensor("_t_" + anchor() + "_dropout_mask");
mask = ws()->CreateTensor("/mnt/" + anchor() + "/dropout_mask");
mask->ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>();
......@@ -42,23 +42,21 @@ OPERATOR_SCHEMA(Dropout).NumInputs(1).NumOutputs(1).Inplace({ { 0, 0 } });
template <class Context> template <typename T>
void DropoutGradientOp<Context>::RunWithType() {
mask = ws()->GetTensor("_t_" + anchor() + "_dropout_mask");
mask = ws()->GetTensor("/mnt/" + anchor() + "/dropout_mask");
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template data<uint32_t, Context>();
if (this->phase() == "TRAIN") {
kernel::DropoutGrad<T, Context>(output(0)->count(),
prob,
kernel::DropoutGrad<T, Context>(output(0)->count(),
prob,
scale,
dYdata,
dYdata,
Mdata,
dXdata);
} else if (this->phase() == "TEST") {
NOT_IMPLEMENTED;
}
} else if (this->phase() == "TEST") { NOT_IMPLEMENTED; }
mask->Reset();
}
template <class Context>
......@@ -69,12 +67,6 @@ void DropoutGradientOp<Context>::RunOnDevice() {
else LOG(FATAL) << "Unsupported input types.";
}
template <class Context>
void DropoutGradientOp<Context>::CleanResource() {
Operator<Context>::CleanResource();
ws()->ReleaseBuffer(mask, "Common", true);
}
DEPLOY_CPU(DropoutGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(DropoutGradient);
......
......@@ -26,7 +26,7 @@ void SoftmaxOp<Context>::RunWithType() {
template <class Context>
void SoftmaxOp<Context>::RunOnDevice() {
if (axis == -1) axis = (int)input(0).ndim() - 1;
scale = ws()->CreateTensor("_t_softmax_scale");
scale = ws()->CreateTensor("/share/softmax_scale");
scale->ReshapeLike(input(0));
outer_dim = input(0).count(0, axis);
inner_dim = input(0).count(axis + 1);
......@@ -64,7 +64,7 @@ void SoftmaxGradientOp<Context>::RunWithType() {
template <class Context>
void SoftmaxGradientOp<Context>::RunOnDevice() {
if (axis == -1) axis = (int)input(0).ndim() - 1;
scale = ws()->CreateTensor("_t_softmax_scale");
scale = ws()->CreateTensor("/share/softmax_scale");
scale->ReshapeLike(input(0));
outer_dim = input(0).count(0, axis);
inner_dim = input(0).count(axis + 1);
......
......@@ -176,7 +176,7 @@ void AddGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -30,8 +30,7 @@ void BiasAddOp<Context>::RunOnDevice() {
dim = input(0).dim(-1);
inner_dim = input(0).count(1) / dim;
} else LOG(FATAL) << "Unknown data format: " << data_format;
output(0)->ReshapeLike(input(0));
output(0)->Share(input(0));
ws()->CreateAvatar(output(0), &input(0));
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -70,8 +69,7 @@ void BiasAddGradientOp<Context>::RunWithType() {
}
if (output(0)->name() != "ignore") {
output(0)->ReshapeLike(input(-1));
output(0)->Share(input(-1));
ws()->CreateAvatar(output(0), &input(-1));
}
}
......
......@@ -16,8 +16,7 @@ void ClipOp<Context>::RunWithType() {
template <class Context>
void ClipOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
output(0)->Share(input(0));
mask = ws()->CreateTensor("_t_" + anchor() + "_clip_mask");
mask = ws()->CreateTensor("/mnt/" + anchor() + "/clip_mask");
mask->ReshapeLike(input(0));
if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -27,7 +26,7 @@ DEPLOY_CPU(Clip);
#ifdef WITH_CUDA
DEPLOY_CUDA(Clip);
#endif
OPERATOR_SCHEMA(Clip).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Clip).NumInputs(1).NumOutputs(1).Inplace({ { 0, 0 } });
template <class Context> template <typename T>
void ClipGradientOp<Context>::RunWithType() {
......@@ -39,8 +38,7 @@ void ClipGradientOp<Context>::RunWithType() {
template <class Context>
void ClipGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
output(0)->Share(input(-1));
mask = ws()->GetTensor("_t_" + anchor() + "_clip_mask");
mask = ws()->GetTensor("/mnt/" + anchor() + "/clip_mask");
if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
......@@ -49,14 +47,14 @@ DEPLOY_CPU(ClipGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(ClipGradient);
#endif
OPERATOR_SCHEMA(ClipGradient).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(ClipGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } });
class GetClipGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetClipGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), GO(0)},
vector<string> {O(0), GO(0)},
vector<string> {GI(0)});
}
};
......
......@@ -122,7 +122,7 @@ void DivGradientOp<Context>::BroadcastRunWithType(int type) {
}
if (output(1)->name() != "ignore") {
Tensor* buffer = ws()->CreateTensor("_t_buffer_0");
Tensor* buffer = ws()->GetBuffer();
buffer->ReshapeLike(input(1));
auto* X1data = input(0).template data<T, Context>();
auto* X2data = input(1).template data<T, Context>();
......@@ -147,6 +147,7 @@ void DivGradientOp<Context>::BroadcastRunWithType(int type) {
dX1data, BMul_data, 0.0, Bdata);
}
math::Mul<T, Context>(input(1).count(), Bdata, dX2data, dX2data);
ws()->ReleaseBuffer(buffer);
}
if (output(0)->name() != "ignore") {
......@@ -207,7 +208,7 @@ void DivGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -183,7 +183,7 @@ void DotGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -130,7 +130,7 @@ void EltwiseGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -118,7 +118,7 @@ void MatmulGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -195,7 +195,7 @@ void MulGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -176,7 +176,7 @@ void RAddGradientOp<Context>::ShareGradient() {
for (int i = (int)OutputSize() - 1; i >= 0; i--) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -201,7 +201,7 @@ void RDivGradientOp<Context>::ShareGradient() {
for (int i = (int)OutputSize() - 1; i >= 0; i--) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -193,7 +193,7 @@ void RMulGradientOp<Context>::ShareGradient() {
for (int i = (int)OutputSize() - 1; i >= 0; i--) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -177,7 +177,7 @@ void RSubGradientOp<Context>::ShareGradient() {
for (int i = (int)OutputSize() - 1; i >= 0; i--) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -7,28 +7,33 @@ namespace dragon {
template <class Context> template <typename T>
void ScaleOp<Context>::RunWithType() {
CHECK_LT(axis, (int)input(0).ndim());
const vector<TIndex>::const_iterator& dim_start =
input(0).dims().begin() + axis;
if (num_axes == -1) num_axes = (int)input(0).ndim() - axis;
CHECK_LE(axis + num_axes, (int)input(0).ndim());
start_axis = axis;
if (start_axis < 0) start_axis += (int)input(0).ndim();
if (num_axes == -1) num_axes = (int)input(0).ndim() - start_axis;
else if (num_axes == 0) num_axes = 1;
CHECK_LT(start_axis, (int)input(0).ndim());
CHECK_LE(start_axis + num_axes, (int)input(0).ndim());
const vector<TIndex>::const_iterator& dim_start = input(0).dims().begin() + start_axis;
const vector<TIndex>::const_iterator& dim_end = dim_start + num_axes;
vector<TIndex> param_dims(dim_start, dim_end);
TENSOR_FILL(input(1), param_dims);
if (InputSize() > 2) {
TENSOR_FILL(input(2), param_dims);
inner_dim = input(0).count(axis + num_axes);
inner_dim = input(0).count(start_axis + num_axes);
INIT_MULTIPLIER(bias_multiplier, inner_dim);
}
if (InputSize() > 2) {
kernel::Scale<T, Context>(axis, &input(0), &input(1),
&input(2), bias_multiplier,
output(0));
kernel::Scale<T, Context>(start_axis, &input(0), &input(1),
&input(2), bias_multiplier,
output(0));
} else {
kernel::Scale<T, Context>(axis, &input(0), &input(1),
nullptr, nullptr,
output(0));
kernel::Scale<T, Context>(start_axis, &input(0), &input(1),
nullptr, nullptr,
output(0));
}
}
......@@ -95,9 +100,9 @@ void ScaleGradientOp<Context>::ScaleRunWithType() {
SRes_data = (outer_dim == 1) ? // handle scale only
dScale : sum_result.template mutable_data<T, Context>();
math::Gemv<T, Context>(CblasNoTrans, sum_result.count(), inner_dim,
1.0,
tmp_data, SMul_data,
SRes_data == dScale ? 1.0 : 0.0,
1.0,
tmp_data, SMul_data,
SRes_data == dScale ? 1.0 : 0.0,
SRes_data);
}
if (outer_dim != 1) {
......@@ -106,9 +111,9 @@ void ScaleGradientOp<Context>::ScaleRunWithType() {
*dScale += result;
} else {
math::Gemv<T, Context>(CblasTrans, outer_dim, scale_dim,
1.0,
SRes_data, SMul_data,
1.0,
1.0,
SRes_data, SMul_data,
1.0,
dScale);
}
}
......@@ -118,14 +123,21 @@ void ScaleGradientOp<Context>::ScaleRunWithType() {
template <class Context> template <typename T>
void ScaleGradientOp<Context>::RunWithType() {
output(0)->ReshapeLike(input(0));
kernel::ScaleGrad<float, Context>(axis, &input(-1), &input(1), output(0));
kernel::ScaleGrad<float, Context>(start_axis, &input(-1), &input(1), output(0));
}
template <class Context>
void ScaleGradientOp<Context>::RunOnDevice() {
if (num_axes == -1) num_axes = (int)input(0).ndim() - axis;
outer_dim = input(0).count(0, axis);
inner_dim = input(0).count(axis + num_axes);
start_axis = axis;
if (start_axis < 0) start_axis += (int)input(0).ndim();
if (num_axes == -1) num_axes = (int)input(0).ndim() - start_axis;
else if (num_axes == 0) num_axes = 1;
CHECK_LT(start_axis, (int)input(0).ndim());
CHECK_LE(start_axis + num_axes, (int)input(0).ndim());
outer_dim = input(0).count(0, start_axis);
inner_dim = input(0).count(start_axis + num_axes);
scale_dim = input(1).count();
sum_dim = std::max(outer_dim, inner_dim);
dim = scale_dim * inner_dim;
......
......@@ -176,7 +176,7 @@ void SubGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -142,7 +142,7 @@ void ScanOp<Context>::UnrollTemplate() {
new_def.add_target(output(i)->name());
}
// upload
Tensor* string_tensor = ws()->CreateTensor("_t_" + anchor() + "_raw_ops");
Tensor* string_tensor = ws()->CreateTensor("/mnt/" + anchor() + "/raw_ops");
string_tensor->Reshape(vector<TIndex>(1, 1));
string* data = string_tensor->mutable_data <string, CPUContext>();
data[0] = new_def.SerializeAsString();
......@@ -171,7 +171,7 @@ void ScanGradientOp<Context>::MakeGradientOps() {
else if (step_type == "Default") nsteps = input(0).dim(axis);
if (graphs.count(nsteps)) return;
Tensor* ops = ws()->GetTensor("_t_" + anchor() + "_raw_ops");
Tensor* ops = ws()->GetTensor("/mnt/" + anchor() + "/raw_ops");
forward_def.ParseFromString(ops->data<string, CPUContext>()[0]);
vector<string> targets;
for (auto& target : forward_def.target()) targets.push_back(target);
......
......@@ -31,7 +31,7 @@ template <class Context>
void L1LossOp<Context>::RunOnDevice() {
CHECK_EQ(input(0).count(), input(1).count());
output(0)->Reshape(vector<TIndex>(1, 1));
diff = ws()->CreateTensor("_t_" + anchor() + "_l1_loss_diff");
diff = ws()->CreateTensor("/mnt/" + anchor() + "/l1_loss_diff");
diff->ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>();
......@@ -67,7 +67,7 @@ void L1LossGradientOp<Context>::RunWithType() {
template <class Context>
void L1LossGradientOp<Context>::RunOnDevice() {
diff = ws()->GetTensor("_t_" + anchor() + "_l1_loss_diff");
diff = ws()->GetTensor("/mnt/" + anchor() + "/l1_loss_diff");
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -78,7 +78,7 @@ void L1LossGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -29,7 +29,7 @@ template <class Context>
void L2LossOp<Context>::RunOnDevice() {
CHECK_EQ(input(0).count(), input(1).count());
output(0)->Reshape(vector<TIndex>(1, 1));
diff = ws()->CreateTensor("_t_" + anchor() + "_l2_loss_diff");
diff = ws()->CreateTensor("/mnt/" + anchor() + "/l2_loss_diff");
diff->ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>();
......@@ -64,7 +64,7 @@ void L2LossGradientOp<Context>::RunWithType() {
template <class Context>
void L2LossGradientOp<Context>::RunOnDevice() {
diff = ws()->GetTensor("_t_" + anchor() + "_l2_loss_diff");
diff = ws()->GetTensor("/mnt/" + anchor() + "/l2_loss_diff");
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -75,7 +75,7 @@ void L2LossGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -35,7 +35,7 @@ template <class Context>
void SigmoidCrossEntropyOp<Context>::RunOnDevice() {
CHECK_EQ(input(0).count(), input(1).count())
<< "\nNumber of predictions must match the number of labels.";
prob = ws()->CreateTensor("_t_" + anchor() + "_sigmoid_prob");
prob = ws()->CreateTensor("/mnt/" + anchor() + "/sigmoid_prob");
prob->ReshapeLike(input(0));
losses.ReshapeLike(input(0));
......@@ -73,7 +73,7 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() {
template <class Context>
void SigmoidCrossEntropyGradientOp<Context>::RunOnDevice() {
prob = ws()->GetTensor("_t_" + anchor() + "_sigmoid_prob");
prob = ws()->GetTensor("/mnt/" + anchor() + "/sigmoid_prob");
output(0)->ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>();
......
......@@ -23,9 +23,13 @@ void SmoothL1LossOp<Context>::RunWithType() {
auto* outside_w_data = input(3).template data<T, Context>();
math::Mul<T, Context>(diff->count(), outside_w_data, error_data, error_data);
}
Ydata[0] = math::ASum<T, Context>(error->count(), error_data);
T loss = math::ASum<T, Context>(error->count(), error_data);
Ydata[0] = loss / input(0).dim(0);
T normalizer;
if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = input(0).count();
else if (normalization == "NONE") normalizer = 1;
Ydata[0] = Ydata[0] / normalizer;
}
template <class Context>
......@@ -35,8 +39,8 @@ void SmoothL1LossOp<Context>::RunOnDevice() {
if (InputSize() > 3) CHECK(input(0).dims() == input(3).dims());
output(0)->Reshape(vector<TIndex>(1, 1));
diff = ws()->CreateTensor("_t_" + anchor() + "_smoothl1_loss_diff");
error = ws()->CreateTensor("_t_smoothl1_loss_error");
diff = ws()->CreateTensor("/mnt/" + anchor() + "/smoothl1_loss_diff");
error = ws()->CreateTensor("/share/smoothl1_loss_error");
diff->ReshapeLike(input(0));
error->ReshapeLike(input(0));
......@@ -54,16 +58,21 @@ template <class Context> template <typename T>
void SmoothL1LossGradientOp<Context>::RunWithType() {
auto* diff_data = diff->template mutable_data<T, Context>();
auto* dYdata = input(-1).template data<T, CPUContext>();
kernel::SmoothL1Grad<T, Context>(diff->count(), sigma2, diff_data, diff_data);
T alpha = dYdata[0], normalizer;
if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = input(0).count();
else if (normalization == "NONE") normalizer = 1;
alpha = alpha / normalizer;
for (int i = 0; i < 2; i++) {
if (output(i)->name() == "ignore") continue;
output(i)->ReshapeLike(input(i));
auto* dXdata = output(i)->template mutable_data<T, Context>();
const T sign = (i == 0) ? 1 : -1;
const T coeff = sign / input(i).dim(0) * dYdata[0];
math::Axpby<T, Context>(output(i)->count(), coeff, diff_data, 0, dXdata);
alpha *= sign;
math::Axpby<T, Context>(output(i)->count(), alpha, diff_data, 0, dXdata);
if (InputSize() > 3) {
auto* inside_w_data = input(2).template data<T, Context>();
math::Mul<T, Context>(output(i)->count(), inside_w_data, dXdata, dXdata);
......@@ -77,7 +86,7 @@ void SmoothL1LossGradientOp<Context>::RunWithType() {
template <class Context>
void SmoothL1LossGradientOp<Context>::RunOnDevice() {
diff = ws()->GetTensor("_t_" + anchor() + "_smoothl1_loss_diff");
diff = ws()->GetTensor("/mnt/" + anchor() + "/smoothl1_loss_diff");
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......
......@@ -43,7 +43,7 @@ void SoftmaxCrossEntropyOp<Context>::RunOnDevice() {
<< "\nNumber of predictions must match the number of labels.";
losses.ReshapeLike(input(0));
softmax_op->Run();
prob = ws()->GetTensor("_t_" + anchor() + "_softmax_prob");
prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax_prob");
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -85,7 +85,7 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
template <class Context>
void SoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() {
prob = ws()->GetTensor("_t_" + anchor() + "_softmax_prob");
prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax_prob");
outer_dim = prob->count(0, axis);
inner_dim = prob->count(axis + 1);
output(0)->ReshapeLike(input(0));
......
......@@ -51,7 +51,7 @@ void SparseSoftmaxCrossEntropyOp<Context>::RunOnDevice() {
valid.Reshape(vector<TIndex>(1, outer_dim * inner_dim));
losses.Reshape(vector<TIndex>(1, outer_dim * inner_dim));
softmax_op->Run();
prob = ws()->GetTensor("_t_" + anchor() + "_softmax_prob");
prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax_prob");
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -100,7 +100,7 @@ void SparseSoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
template <class Context>
void SparseSoftmaxCrossEntropyGradientOp<Context>::RunOnDevice() {
prob = ws()->GetTensor("_t_" + anchor() + "_softmax_prob");
prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax_prob");
outer_dim = prob->count(0, axis);
inner_dim = prob->count(axis + 1);
output(0)->ReshapeLike(input(0));
......
......@@ -57,8 +57,8 @@ void SparseSoftmaxFocalLossOp<Context>::RunOnDevice() {
this->valid.Reshape(vector<TIndex>(1, outer_dim * inner_dim));
this->losses.Reshape(vector<TIndex>(1, outer_dim * inner_dim));
this->softmax_op->Run();
this->prob = ws()->GetTensor("_t_" + anchor() + "_softmax_prob");
scale = ws()->CreateTensor("_t_" + anchor() + "_focal_scale");
this->prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax_prob");
scale = ws()->CreateTensor("/mnt/" + anchor() + "/focal_scale");
scale->ReshapeLike(*this->prob);
if (input(0).template IsType<float>()) RunWithType<float>();
......@@ -116,8 +116,8 @@ void SparseSoftmaxFocalLossGradientOp<Context>::RunWithType() {
template <class Context>
void SparseSoftmaxFocalLossGradientOp<Context>::RunOnDevice() {
this->prob = ws()->GetTensor("_t_" + anchor() + "_softmax_prob");
scale = ws()->GetTensor("_t_" + anchor() + "_focal_scale");
this->prob = ws()->GetTensor("/mnt/" + anchor() + "/softmax_prob");
scale = ws()->GetTensor("/mnt/" + anchor() + "/focal_scale");
outer_dim = this->prob->count(0, axis);
inner_dim = this->prob->count(axis + 1);
output(0)->ReshapeLike(input(0));
......
......@@ -10,8 +10,8 @@ void GradientGenerateOp<Context>::RunWithType() {
if (output(i)->name() == "ignore") continue;
output(i)->ReshapeLike(input(i));
auto* dXdata = output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(output(0)->count(),
dragon_cast<T, float>(defaults[i]),
math::Set<T, Context>(output(0)->count(),
dragon_cast<T, float>(defaults[i]),
dXdata);
}
}
......@@ -37,9 +37,7 @@ void GradientGatherOp<Context>::RunWithType() {
TIndex count = output(0)->count();
for (int i = 1; i < indices.size(); i++) {
CHECK(output(0)->dims() == input(indices[i]).dims());
math::Add<T, Context>(count, dXdata,
input(indices[i]).template data<T, Context>(), dXdata);
// trick: force to release memory
math::Add<T, Context>(count, dXdata, input(indices[i]).template data<T, Context>(), dXdata);
input(indices[i]).Reset();
}
}
......@@ -47,8 +45,7 @@ void GradientGatherOp<Context>::RunWithType() {
template <class Context>
void GradientGatherOp<Context>::RunOnDevice() {
if (indices.size() == 0) return;
output(0)->ReshapeLike(input(indices[0]));
output(0)->Share(input(indices[0]));
ws()->CreateAvatar(output(0), &input(indices[0]));
if (input(indices[0]).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -63,8 +60,7 @@ NO_GRADIENT(GradientGather);
template <class Context>
void StopGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
output(0)->Share(input(0));
ws()->CreateAvatar(output(0), &input(0));
}
DEPLOY_CPU(StopGradient);
......@@ -74,4 +70,4 @@ DEPLOY_CUDA(StopGradient);
OPERATOR_SCHEMA(StopGradient).NumInputs(1).NumOutputs(1);
NO_GRADIENT(StopGradient);
} // namespace dragon
\ No newline at end of file
} // namespace dragon
......@@ -109,7 +109,7 @@ void ConcatGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -38,19 +38,9 @@ void CropOp<Context>::Setup() {
}
// make ends
if (ends.size() > 0) {
// static crop
CHECK_EQ(ends.size(), input(0).ndim())
<< "\nThe cropping is performed on " << ends.size() << " dimensions, "
<< "but the num of dimensions of input is " << input(0).ndim() << "."; \
// fix end if necessary
for (int i = 0; i < ends.size(); i++)
if (ends[i] == 0) ends[i] = input(0).dim(i);
} else {
if (shape.size() + shape_like.size() != 0) {
CHECK(shape.size() * shape_like.size() == 0)
<< "\nCan not set shape and shape_like both.";
CHECK(shape.size() + shape_like.size() != 0)
<< "\nMust set shape and shape_like either.";
ends.resize(input(0).ndim(), 0);
for (int i = 0; i < ends.size(); i++) {
// dynamic crop 1: keep unchanged
......@@ -73,6 +63,14 @@ void CropOp<Context>::Setup() {
ends[i] = starts[i] + like->dim(i);
}
}
} else {
// static crop
CHECK_EQ(ends.size(), input(0).ndim())
<< "\nThe cropping is performed on " << ends.size() << " dimensions, "
<< "but the num of dimensions of input is " << input(0).ndim() << ".";
// fix end if necessary
for (int i = 0; i < ends.size(); i++)
if (ends[i] == 0) ends[i] = input(0).dim(i);
}
// check starts and ends
......@@ -157,19 +155,9 @@ void CropGradientOp<Context>::Setup() {
}
// make ends
if (ends.size() > 0) {
// static crop
CHECK_EQ(ends.size(), input(0).ndim())
<< "\nThe cropping is performed on " << ends.size() << " dimensions, "
<< "but the num of dimensions of input is " << input(0).ndim() << "."; \
// fix end if necessary
for (int i = 0; i < ends.size(); i++)
if (ends[i] == 0) ends[i] = input(0).dim(i);
} else {
if (shape.size() + shape_like.size() != 0) {
CHECK(shape.size() * shape_like.size() == 0)
<< "\nCan not set shape and shape_like both.";
CHECK(shape.size() + shape_like.size() != 0)
<< "\nMust set shape and shape_like either.";
ends.resize(input(0).ndim(), 0);
for (int i = 0; i < ends.size(); i++) {
// dynamic crop 1: keep unchanged
......@@ -192,6 +180,14 @@ void CropGradientOp<Context>::Setup() {
ends[i] = starts[i] + like->dim(i);
}
}
} else {
// static crop
CHECK_EQ(ends.size(), input(0).ndim())
<< "\nThe cropping is performed on " << ends.size() << " dimensions, "
<< "but the num of dimensions of input is " << input(0).ndim() << "."; \
// fix end if necessary
for (int i = 0; i < ends.size(); i++)
if (ends[i] == 0) ends[i] = input(0).dim(i);
}
// check starts and ends
......
......@@ -33,7 +33,7 @@ void RandomPickOp<Context>::RunOnDevice() {
inner_dim = input(0).count(axis + 1);
output(0)->Reshape(output_dims);
pick_indices = ws()->CreateTensor("_t_" + anchor() + "_pick_indices");
pick_indices = ws()->CreateTensor("/mnt/" + anchor() + "/pick_indices");
pick_indices->Reshape(vector<TIndex>(1, max_samples));
if (input(0).template IsType<float>()) RunWithType<float>();
......@@ -68,7 +68,7 @@ void RandomPickGradientOp<Context>::RunWithType() {
template <class Context>
void RandomPickGradientOp<Context>::RunOnDevice() {
pick_indices = ws()->GetTensor("_t_" + anchor() + "_pick_indices");
pick_indices = ws()->GetTensor("/mnt/" + anchor() + "/pick_indices");
x_slice_dim = input(0).dim(axis);
y_slice_dim = pick_indices->count();
......
......@@ -108,7 +108,7 @@ void StackGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
output(i)->Replace(*dX);
ws()->CreateAvatar(output(i), dX);
break;
}
}
......
......@@ -27,9 +27,9 @@ void TransposeOp<Context>::RunOnDevice() {
<< "\nbut Tensor(" << input(0).name() << ")'s dims are "
<< input(0).dim_string();
vector<TIndex> output_dims;
order = ws()->CreateTensor("_t_" + anchor() + "_order");
old_steps = ws()->CreateTensor("_t_" + anchor() + "_old_steps");
new_steps = ws()->CreateTensor("_t_" + anchor() + "_new_steps");
order = ws()->CreateTensor("/mnt/" + anchor() + "/transpose_order");
old_steps = ws()->CreateTensor("/mnt/" + anchor() + "/transpose_old_steps");
new_steps = ws()->CreateTensor("/mnt/" + anchor() + "/transpose_new_steps");
order->Reshape(vector<TIndex>(1, perms.size()));
old_steps->Reshape(vector<TIndex>(1, perms.size()));
new_steps->Reshape(vector<TIndex>(1, perms.size()));
......@@ -76,9 +76,9 @@ void TransposeGradientOp<Context>::RunWithType() {
template <class Context>
void TransposeGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
order = ws()->GetTensor("_t_" + anchor() + "_order");
old_steps = ws()->GetTensor("_t_" + anchor() + "_old_steps");
new_steps = ws()->GetTensor("_t_" + anchor() + "_new_steps");
order = ws()->GetTensor("/mnt/" + anchor() + "/transpose_order");
old_steps = ws()->GetTensor("/mnt/" + anchor() + "/transpose_old_steps");
new_steps = ws()->GetTensor("/mnt/" + anchor() + "/transpose_new_steps");
if (input(0).template IsType<float>()) RunWithType<float>();
#ifdef WITH_CUDA_FP16
......
#include "operators/norm/batch_norm_op.h"
#include "core/workspace.h"
namespace dragon {
DEPLOY_CPU(FusedBatchNorm);
#ifdef WITH_CUDA
DEPLOY_CUDA(FusedBatchNorm);
#endif
OPERATOR_SCHEMA(FusedBatchNorm).NumInputs(5).NumOutputs(1);
template <class Context>
void FusedBatchNormGradientOp<Context>::ShareGradient() {
if (use_global_stats) {
if (output(0)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
ws()->CreateAvatar(output(0), dX);
}
} else {
if (output(0)->name() != "ignore" ||
output(1)->name() != "ignore" ||
output(2)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
ws()->CreateAvatar(output(0), dX);
}
}
}
DEPLOY_CPU(FusedBatchNormGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(FusedBatchNormGradient);
#endif
OPERATOR_SCHEMA(FusedBatchNormGradient).NumInputs(5).NumOutputs(3);
class GetFusedBatchNormGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetFusedBatchNormGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), I(1), I(2), I(3), GO(0)},
vector<string> {GI(0), GI(3), GI(4)});
}
};
REGISTER_GRADIENT(FusedBatchNorm, GetFusedBatchNormGradient);
} // namespace dragon
\ No newline at end of file
......@@ -15,7 +15,7 @@ void L2NormOp<Context>::RunWithType() {
buffer->Reshape(dims);
// normalize by inner_dim independently if not across it
norm = ws()->CreateTensor("_t_" + anchor() + "_l2norm_normalizer");
norm = ws()->CreateTensor("/mnt/" + anchor() + "/l2norm_normalizer");
dims = input(0).dims();
for (int i = axis; i < end_axis; i++) dims[i] = 1;
norm->Reshape(dims);
......@@ -95,7 +95,7 @@ void L2NormGradientOp<Context>::RunWithType() {
INIT_MULTIPLIER(multiplier, dim);
// normalize by inner_dim independently if not across it
norm = ws()->GetTensor("_t_" + anchor() + "_l2norm_normalizer");
norm = ws()->GetTensor("/mnt/" + anchor() + "/l2norm_normalizer");
buffer = ws()->GetBuffer();
vector<TIndex> dims = input(0).dims();
for (int i = 0; i < axis; i++) dims[i] = 1;
......@@ -121,32 +121,32 @@ void L2NormGradientOp<Context>::RunWithType() {
} else {
// compute \sum_{i} x_{i, j}dy_{i, j}
math::Mul<T, Context>(buffer->count(), Xdata, dYdata, Bdata);
math::Gemv<T, Context>(CblasTrans, dim, inner_dim,
1.0,
Bdata, DMuldata,
0.0,
math::Gemv<T, Context>(CblasTrans, dim, inner_dim,
1.0,
Bdata, DMuldata,
0.0,
BInnerdata);
// compute T1 = x[(\sum_{i} x_{i, j}dy_{i, j})]_{dim}
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, dim, inner_dim, 1,
1.0,
DMuldata, BInnerdata,
0.0,
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, dim, inner_dim, 1,
1.0,
DMuldata, BInnerdata,
0.0,
Bdata);
math::Mul<T, Context>(buffer->count(), Xdata, Bdata, dXdata);
// compute T2 = T1 / Normalizer^{2}
math::Pow<T, Context>(inner_dim, 2.0, Ndata, BInnerdata);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, dim, inner_dim, 1,
1.0,
DMuldata, BInnerdata,
0.0,
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, dim, inner_dim, 1,
1.0,
DMuldata, BInnerdata,
0.0,
Bdata);
math::Div<T, Context>(buffer->count(), dXdata, Bdata, dXdata);
// compute T3 = (dy - T2) / Normalizer
math::Sub<T, Context>(buffer->count(), dYdata, dXdata, dXdata);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, dim, inner_dim, 1,
1.0,
DMuldata, Ndata,
0.0,
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, dim, inner_dim, 1,
1.0,
DMuldata, Ndata,
0.0,
Bdata);
math::Div<T, Context>(buffer->count(), dXdata, Bdata, dXdata);
Ndata += inner_dim;
......
......@@ -69,7 +69,7 @@ void LSTMUnitGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
output(1)->ReshapeLike(input(1));
if (InputSize() != 5) {
zeros = ws()->CreateTensor("_t_zeros");
zeros = ws()->CreateTensor("/share/zeros");
if (zeros->count() < input(0).count())
zeros->ReshapeLike(input(0));
}
......
......@@ -9,7 +9,7 @@ void RMSPropUpdateOp<Context>::ComputeRunWithFloat() {
if (!history.get()) {
string slot = OperatorBase::GetSingleArg<string>("slot", "");
if (slot.empty()) history.reset(new Tensor());
else history.reset(ws()->CreateTensor("_t_" + output(0)->name() + "_" + slot));
else history.reset(ws()->CreateTensor("/mnt/" + name() + "/history"));
history->ReshapeLike(input(0));
}
lr = param("base_lr") * this->lr_mult;
......
......@@ -38,7 +38,7 @@ template <class Context>
void DenseConcatGradientOp<Context>::ElimateCorruption() {
Set<string> all_heads;
queue<int> safe_heads;
Tensor* head = ws()->GetTensor("_t_mirror_stage_head");
Tensor* head = ws()->GetTensor("/opt/mirror_stage/head");
string* head_data = head->mutable_data<string, CPUContext>();
for (int i = 0; i < head->count(); i++) all_heads.insert(head_data[i]);
......@@ -54,7 +54,7 @@ void DenseConcatGradientOp<Context>::ElimateCorruption() {
}
int idx = safe_heads.front();
safe_heads.pop();
Tensor* buffer = ws()->GetTensor("_t_mirror_stage_buffer_" + dragon_cast<string, int>(idx));
Tensor* buffer = ws()->GetTensor("/opt/mirror_stage/buffer_" + dragon_cast<string, int>(idx));
input(0).Move(buffer->memory());
head_data[idx] = input(0).name();
if (input(-2).template IsType<float>()) RestoreX1<float>();
......@@ -91,7 +91,7 @@ void DenseConcatGradientOp<Context>::ElimateCorruption() {
<< "\nadd WORKSPACE_MAX_CORRUPTED_SIZE for more powerful mirror stage ?";
int idx = safe_heads.front();
safe_heads.pop();
Tensor* buffer = ws()->GetTensor("_t_mirror_stage_buffer_" + dragon_cast<string, int>(idx));
Tensor* buffer = ws()->GetTensor("/opt/mirror_stage/buffer_" + dragon_cast<string, int>(idx));
output(i)->Move(buffer->memory());
head_data[idx] = output(i)->name();
}
......
......@@ -15,18 +15,18 @@ void LRNOp<Context>::AcrossRunWithType() {
template <class Context> template <typename T>
void LRNOp<Context>::SplitRunWithType() {
sqr_in = ws()->CreateTensor("_t_" + anchor() + "_sqr_in");
sqr_in = ws()->CreateTensor("/mnt/" + anchor() + "/sqr_in");
sqr_in->ReshapeLike(input(0));
sqr_in->Share(input(0));
prod_in = ws()->CreateTensor("_t_" + anchor() + "_prod_in");
prod_in = ws()->CreateTensor("/mnt/" + anchor() + "/prod_in");
prod_in->ReshapeLike(input(0));
prod_in->Share(input(0));
}
template <class Context> template <typename T>
void LRNOp<Context>::SquareRunWithType() {
sqr_out = ws()->CreateTensor("_t_" + anchor() + "_sqr_out");
sqr_out = ws()->CreateTensor("/mnt/" + anchor() + "/sqr_out");
if (!sqr_op) {
Argument power;
power.set_name("power"); power.set_f(2.0);
......@@ -43,7 +43,7 @@ void LRNOp<Context>::SquareRunWithType() {
template <class Context> template <typename T>
void LRNOp<Context>::PoolRunWithType() {
pool_out = ws()->CreateTensor("_t_" + anchor() + "_pool_out");
pool_out = ws()->CreateTensor("/mnt/" + anchor() + "/pool_out");
if (!pool_op) {
Argument ks, s, p, mode;
ks.set_name("kernel_size"); ks.add_ints(local_size);
......@@ -63,7 +63,7 @@ void LRNOp<Context>::PoolRunWithType() {
template <class Context> template <typename T>
void LRNOp<Context>::PowRunWithType() {
pow_out = ws()->CreateTensor("_t_" + anchor() + "_pow_out");
pow_out = ws()->CreateTensor("/mnt/" + anchor() + "/pow_out");
if (!pow_op) {
Argument scale, shift, power;
scale.set_name("scale"); scale.set_f(alpha);
......@@ -129,8 +129,8 @@ void LRNGradientOp<Context>::AcrossRunWithType() {
template <class Context> template <typename T>
void LRNGradientOp<Context>::ProdRunWithType() {
prod_in = ws()->GetTensor("_t_" + anchor() + "_prod_in");
pow_out = ws()->GetTensor("_t_" + anchor() + "_pow_out");
prod_in = ws()->GetTensor("/mnt/" + anchor() + "/prod_in");
pow_out = ws()->GetTensor("/mnt/" + anchor() + "/pow_out");
if (!prod_op) {
Argument operation;
operation.set_name("operation"); operation.set_s("PROD");
......@@ -150,7 +150,7 @@ void LRNGradientOp<Context>::ProdRunWithType() {
template <class Context> template <typename T>
void LRNGradientOp<Context>::PowRunWithType() {
pool_out = ws()->GetTensor("_t_" + anchor() + "_pool_out");
pool_out = ws()->GetTensor("/mnt/" + anchor() + "/pool_out");
if (!pow_op) {
Argument scale, shift, power;
scale.set_name("scale"); scale.set_f(alpha);
......@@ -171,7 +171,7 @@ void LRNGradientOp<Context>::PowRunWithType() {
template <class Context> template <typename T>
void LRNGradientOp<Context>::PoolRunWithType() {
sqr_out = ws()->GetTensor("_t_" + anchor() + "_sqr_out");
sqr_out = ws()->GetTensor("/mnt/" + anchor() + "/sqr_out");
if (!pool_op) {
Argument ks, s, p, mode;
ks.set_name("kernel_size"); ks.add_ints(local_size);
......@@ -179,7 +179,7 @@ void LRNGradientOp<Context>::PoolRunWithType() {
p.set_name("pad"); p.add_ints((local_size - 1) / 2);
mode.set_name("mode"); mode.set_s("AVG");
OperatorDef pool_op_def = MakeOperatorDef("PoolingGradient", "",
vector<string>({ sqr_out->name(),
vector<string>({ sqr_out->name(),
pool_out->name(),
pool_out->name() + "_grad" }),
vector<string>({ sqr_out->name() + "_grad" }),
......@@ -193,7 +193,7 @@ void LRNGradientOp<Context>::PoolRunWithType() {
template <class Context> template <typename T>
void LRNGradientOp<Context>::SquareRunWithType() {
sqr_in = ws()->GetTensor("_t_" + anchor() + "_sqr_in");
sqr_in = ws()->GetTensor("/mnt/" + anchor() + "/sqr_in");
if (!sqr_op) {
Argument power;
power.set_name("power"); power.set_f(2.0);
......
......@@ -7,7 +7,7 @@ namespace dragon {
template <class Context> template <typename T>
void Pooling2dOp<Context>::MAXRunWithType() {
mask = ws()->CreateTensor("_t_" + anchor() + "_pool_mask");
mask = ws()->CreateTensor("/mnt/" + anchor() + "/max_pool_mask");
mask->ReshapeLike(*output(0));
auto* Xdata = input(0).template data<T, Context>();
......@@ -122,7 +122,7 @@ OPERATOR_SCHEMA(Pooling2d).NumInputs(1).NumOutputs(1);
template <class Context> template <typename T>
void Pooling2dGradientOp<Context>::MAXRunWithType() {
mask = ws()->GetTensor("_t_" + anchor() + "_pool_mask");
mask = ws()->GetTensor("/mnt/" + anchor() + "/max_pool_mask");
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
......
......@@ -17,10 +17,10 @@ void ROIAlignOp<Context>::RunWithType() {
template <class Context>
void ROIAlignOp<Context>::RunOnDevice() {
mask = ws()->CreateTensor("/mnt/" + anchor() + "/roi_align_mask");
vector<TIndex> dims({input(1).dim(0), input(0).dim(1), pool_h, pool_w});
output(0)->Reshape(dims);
mask = ws()->CreateTensor("_t_" + anchor() + "_roi_align_mask");
output(0)->Reshape(dims);
mask->Reshape(dims);
if (input(0).template IsType<float>()) return RunWithType<float>();
......@@ -35,7 +35,7 @@ OPERATOR_SCHEMA(ROIAlign).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T>
void ROIAlignGradientOp<Context>::RunWithType() {
kernel::ROIAlignGrad<T, Context>(spatial_scale,
kernel::ROIAlignGrad<T, Context>(spatial_scale,
pool_h, pool_w,
&input(-1),
&input(1),
......@@ -45,20 +45,14 @@ void ROIAlignGradientOp<Context>::RunWithType() {
template <class Context>
void ROIAlignGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
mask = ws()->GetTensor("/mnt/" + anchor() + "/roi_align_mask");
mask = ws()->GetTensor("_t_" + anchor() + "_roi_align_mask");
output(0)->ReshapeLike(input(0));
if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
template <class Context>
void ROIAlignGradientOp<Context>::CleanResource() {
Operator<Context>::CleanResource();
ws()->ReleaseBuffer(mask, "Common", true);
}
DEPLOY_CPU(ROIAlignGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(ROIAlignGradient);
......
......@@ -17,10 +17,10 @@ void ROIPoolingOp<Context>::RunWithType() {
template <class Context>
void ROIPoolingOp<Context>::RunOnDevice() {
mask = ws()->CreateTensor("/mnt/" + anchor() + "/roi_pool_mask");
vector<TIndex> dims({input(1).dim(0), input(0).dim(1), pool_h, pool_w});
output(0)->Reshape(dims);
mask = ws()->CreateTensor("_t_" + anchor() + "_roi_pool_mask");
mask->Reshape(dims);
if (input(0).template IsType<float>()) return RunWithType<float>();
......@@ -33,10 +33,9 @@ DEPLOY_CUDA(ROIPooling);
#endif
OPERATOR_SCHEMA(ROIPooling).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T>
void ROIPoolingGradientOp<Context>::RunWithType() {
kernel::ROIPoolingGrad<T, Context>(spatial_scale,
kernel::ROIPoolingGrad<T, Context>(spatial_scale,
pool_h, pool_w,
&input(-1),
&input(1),
......@@ -46,20 +45,14 @@ void ROIPoolingGradientOp<Context>::RunWithType() {
template <class Context>
void ROIPoolingGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
mask = ws()->GetTensor("/mnt/" + anchor() + "/roi_pool_mask");
mask = ws()->GetTensor("_t_" + anchor() + "_roi_pool_mask");
output(0)->ReshapeLike(input(0));
if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
template <class Context>
void ROIPoolingGradientOp<Context>::CleanResource() {
Operator<Context>::CleanResource();
ws()->ReleaseBuffer(mask, "Common", true);
}
DEPLOY_CPU(ROIPoolingGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(ROIPoolingGradient);
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!