Commit a7a7e4fc by Ting PAN

Fix skipped algorithm finding in cached cuDNN convolutions

Summary:
This commit forces the algorithm finding even if the backward pass for the filter or data
will not be executed. Otherwise, an empty algorithm is cached and then reused by a later
cached operation with the same arguments and input shape.
1 parent 218796ed
Showing with 478 additions and 498 deletions
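
The cuDNN convolution operator itself is not among the hunks shown below, so the following is only a minimal standalone sketch (hypothetical names, no real cuDNN calls) of the failure mode the summary describes: when the algorithm search is skipped, an empty algorithm lands in the per-shape cache and is reused by the next cached operation with the same key.

// Minimal sketch, assuming a hypothetical per-shape algorithm cache; it models
// the caching behavior only and does not call the real cuDNN search API.
#include <iostream>
#include <map>
#include <string>

struct ConvAlgo {
  int id = -1;  // -1 marks "never found"
  bool empty() const { return id < 0; }
};

std::map<std::string, ConvAlgo> algo_cache;  // keyed by arguments + input shape

ConvAlgo FindBackwardFilterAlgo(const std::string& key, bool run_backward,
                                bool force_finding) {
  auto it = algo_cache.find(key);
  if (it != algo_cache.end()) return it->second;  // a second cached op lands here
  ConvAlgo algo;
  // Before this commit: the search was skipped when the filter/data backward
  // would not be executed, so an empty algorithm was cached for this key.
  if (run_backward || force_finding) {
    algo.id = 0;  // stands in for the algorithm chosen by the search
  }
  algo_cache[key] = algo;
  return algo;
}

int main() {
  const std::string key = "NCHW/64x3x7x7/pad1/stride1";  // hypothetical key
  FindBackwardFilterAlgo(key, /*run_backward=*/false, /*force_finding=*/true);
  // A later cached op with the same arguments and input shape now reuses a
  // valid algorithm instead of the empty one.
  std::cout << algo_cache[key].empty() << std::endl;  // prints 0
}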
...@@ -18,6 +18,9 @@ dragon/core ...@@ -18,6 +18,9 @@ dragon/core
`class Operator <core/Operator.html>`_ `class Operator <core/Operator.html>`_
: The base operator class with context. : The base operator class with context.
`class OpSchema <core/OpSchema.html>`_
: Class to record the schema of an operator.
`class Tensor <core/Tensor.html>`_ `class Tensor <core/Tensor.html>`_
: The base tensor class, manage memory or not. : The base tensor class, manage memory or not.
...@@ -37,6 +40,7 @@ dragon/core ...@@ -37,6 +40,7 @@ dragon/core
core/CUDAContext core/CUDAContext
core/Graph core/Graph
core/Operator core/Operator
core/OpSchema
core/Tensor core/Tensor
core/TypeMeta core/TypeMeta
core/UnifiedMemory core/UnifiedMemory
......
OpSchema
========
.. doxygenclass:: dragon::OpSchema
Constructors
------------
.. doxygenfunction:: dragon::OpSchema::OpSchema()
.. doxygenfunction:: dragon::OpSchema::OpSchema(const string &op_type, const string &file, const int line)
Public Functions
----------------
AllowInplace
############
.. doxygenfunction:: dragon::OpSchema::AllowInplace(set<pair<int, int>> inplace)
AllowInplace
############
.. doxygenfunction:: dragon::OpSchema::AllowInplace(std::function<bool(int, int)> inplace)
NumInputs
#########
.. doxygenfunction:: dragon::OpSchema::NumInputs(int n)
NumInputs
#########
.. doxygenfunction:: dragon::OpSchema::NumInputs(int min_num, int max_num)
NumOutputs
##########
.. doxygenfunction:: dragon::OpSchema::NumOutputs(int n)
NumOutputs
##########
.. doxygenfunction:: dragon::OpSchema::NumOutputs(int min_num, int max_num)
Verify
######
.. doxygenfunction:: dragon::OpSchema::Verify
.. raw:: html
<style>
h1:before {
content: "dragon::";
color: #103d3e;
}
</style>
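
As a quick illustration of the AllowInplace overloads documented above, here is a standalone sketch (outside the Dragon codebase) of the predicate that the set-based overload installs; it mirrors the op_schema.cc hunk later in this commit, and operator registrations then rely on it as in OPERATOR_SCHEMA(Dropout)...AllowInplace({{0, 0}}).

// Standalone sketch: a set of (input, output) index pairs becomes the
// CheckInplace predicate answering "may output j alias input i?".
#include <functional>
#include <iostream>
#include <set>
#include <utility>

int main() {
  std::set<std::pair<int, int>> inplace = {{0, 0}};  // X => Y
  std::function<bool(int, int)> CheckInplace = [inplace](int in, int out) -> bool {
    return inplace.count(std::make_pair(in, out)) > 0;
  };
  std::cout << CheckInplace(0, 0) << " " << CheckInplace(1, 0) << std::endl;  // 1 0
}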
...@@ -470,7 +470,6 @@ zero\_ ...@@ -470,7 +470,6 @@ zero\_
.. _torch.div(...): div.html .. _torch.div(...): div.html
.. _torch.eq(...): eq.html .. _torch.eq(...): eq.html
.. _torch.exp(...): exp.html .. _torch.exp(...): exp.html
.. _torch.expand(...): expand.html
.. _torch.floor(...): floor.html .. _torch.floor(...): floor.html
.. _torch.ge(...): ge.html .. _torch.ge(...): ge.html
.. _torch.gt(...): gt.html .. _torch.gt(...): gt.html
......
...@@ -79,12 +79,12 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) { ...@@ -79,12 +79,12 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
Map<string, vec32_t> subgraph_indices; Map<string, vec32_t> subgraph_indices;
int opt = 3; // default: O3 int opt = 3; // default: O3
if (args().count("optimization")) opt = arg("optimization").i(); if (args().count("optimization")) opt = arg("optimization").i();
if (opt >= 1) def_v2 = graph_optimizer.PruneNodes(def); if (opt >= 1) def_v2 = graph_optimizer.EliminateUnused(def);
if (opt >= 2) graph_optimizer.AddInplace(def_v2, output_aliases_); if (opt >= 2) graph_optimizer.PlanInplace(def_v2, output_aliases_);
if (opt >= 3) { if (opt >= 3) {
if (phase() == "TRAIN") { if (phase() == "TRAIN") {
def_v2 = graph_optimizer.MirrorStage(def_v2, subgraph_indices); def_v2 = graph_optimizer.PlanCheckpoint(def_v2, subgraph_indices);
def_v2 = gradient_maker.Share(def_v2); def_v2 = gradient_maker.Optimize(def_v2);
} else { } else {
def_v2 = graph_optimizer.SimulateGC(def_v2); def_v2 = graph_optimizer.SimulateGC(def_v2);
} }
...@@ -98,8 +98,8 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) { ...@@ -98,8 +98,8 @@ Graph::Graph(const GraphDef& def, Workspace* ws) : GraphBase(def, ws) {
Map<string, vector<OperatorBase*>> subgraph; Map<string, vector<OperatorBase*>> subgraph;
for (const auto& it : subgraph_indices) { for (const auto& it : subgraph_indices) {
subgraph[it.first] = vector<OperatorBase*>(); subgraph[it.first] = vector<OperatorBase*>();
for (const auto& idx : subgraph_indices[it.first]) for (auto op_idx : subgraph_indices[it.first])
subgraph[it.first].push_back(cached_ops_[idx]); subgraph[it.first].push_back(cached_ops_[op_idx]);
} }
for (auto* op : cached_ops_) { for (auto* op : cached_ops_) {
op->set_subgraph(subgraph); op->set_subgraph(subgraph);
......
...@@ -4,40 +4,40 @@ ...@@ -4,40 +4,40 @@
namespace dragon { namespace dragon {
bool GraphGradientMaker::CheckGrad( bool GraphGradientMaker::CheckGrad(
const OperatorDef& op_def, const OperatorDef& op,
const Set<string>& targets, const Set<string>& targets,
vector<pair<string, int>>& gen_grads) { vector<pair<string, int>>& gen_grads) {
if (NoGradientRegistry()->Has(op_def.type())) { if (NoGradientRegistry()->Has(op.type())) {
return true; return true;
} }
bool maybe_skip = false; bool maybe_skip = false;
for (int i = 0; i < op_def.output_size(); ++i) { for (int i = 0; i < op.output_size(); ++i) {
const auto& output = op_def.output(i); const auto& out = op.output(i);
if (!inputs_to_grads_.count(output)) { if (!inputs_to_grads_.count(out)) {
maybe_skip = true; maybe_skip = true;
if (targets.count(output)) { if (targets.count(out)) {
gen_grads.push_back({output, i}); gen_grads.push_back({out, i});
inputs_to_grads_[output] = output + "_grad"; inputs_to_grads_[out] = out + "_grad";
} }
} }
} }
return maybe_skip && gen_grads.empty() && op_def.output_size() == 1; return maybe_skip && gen_grads.empty() && op.output_size() == 1;
} }
void GraphGradientMaker::Make( void GraphGradientMaker::Make(
const vector<OperatorDef*>& op_defs, const vector<OperatorDef*>& ops,
const vector<string>& targets, const vector<string>& targets,
const vector<string>& input_grads, const vector<string>& input_grads,
GraphDef& graph_def) { GraphDef& graph) {
Set<string> split_grads, targets_v2; Set<string> split_grads, targets_v2;
Map<string, int> inputs_count, grads_count; Map<string, int> inputs_count, grads_count;
// PLAY for the forward // PLAY for the forward
for (auto* op_def : op_defs) { for (auto* op : ops) {
if (NoGradientRegistry()->Has(op_def->type())) continue; if (NoGradientRegistry()->Has(op->type())) continue;
for (const auto& input : op_def->input()) { for (const auto& input : op->input()) {
bool input_in_outputs = false; bool input_in_outputs = false;
for (auto& output : op_def->output()) for (auto& output : op->output())
if (output == input) { if (output == input) {
input_in_outputs = true; input_in_outputs = true;
break; break;
...@@ -56,21 +56,21 @@ void GraphGradientMaker::Make( ...@@ -56,21 +56,21 @@ void GraphGradientMaker::Make(
} }
// PLAY for the backward // PLAY for the backward
for (int op_idx = (int)op_defs.size() - 1; op_idx >= 0; --op_idx) { for (int op_idx = (int)ops.size() - 1; op_idx >= 0; --op_idx) {
const OperatorDef& op_def = *op_defs[op_idx]; const auto& op = *ops[op_idx];
// Generate def by registered gradient maker // Generate def by registered gradient maker
vector<pair<string, int>> gen_grads; vector<pair<string, int>> gen_grads;
vector<string> grad_outputs; vector<string> grad_outputs;
bool is_skip = CheckGrad(op_def, targets_v2, gen_grads); bool is_skip = CheckGrad(op, targets_v2, gen_grads);
for (const auto& output : op_def.output()) { for (const auto& out : op.output()) {
string grad_output = ""; string grad_out = "";
const auto& it = inputs_to_grads_.find(output); const auto& it = inputs_to_grads_.find(out);
if (it != inputs_to_grads_.end()) grad_output = it->second; if (it != inputs_to_grads_.end()) grad_out = it->second;
grad_outputs.push_back(grad_output); grad_outputs.push_back(grad_out);
} }
auto pack = MakeGradientForOp(op_def, grad_outputs); auto pack = MakeGradientForOp(op, grad_outputs);
// Split and gather gradient for multi-used inputs // Split and gather gradient for multi-used inputs
vector<OperatorDef> gather_defs; vector<OperatorDef> gather_ops;
for (auto& grad_def : pack.grad_defs) { for (auto& grad_def : pack.grad_defs) {
if (!grad_def.has_name()) { if (!grad_def.has_name()) {
grad_def.set_name(GetOperatorName()); grad_def.set_name(GetOperatorName());
...@@ -93,38 +93,38 @@ void GraphGradientMaker::Make( ...@@ -93,38 +93,38 @@ void GraphGradientMaker::Make(
} }
if (output_in_inputs) continue; if (output_in_inputs) continue;
// Detect a split branch // Detect a split branch
const auto& original_name = op_def.input(original_index); const auto& original_name = op.input(original_index);
if (inputs_count[original_name] > 1) { if (inputs_count[original_name] > 1) {
auto grad_name_v2 = auto grad_name_v2 =
grad_name + "_autosplit_" + str::to(grads_count[grad_name]++); grad_name + "_autosplit_" + str::to(grads_count[grad_name]++);
if (!is_skip) split_grads.insert(grad_name_v2); if (!is_skip) split_grads.insert(grad_name_v2);
if (grads_count[grad_name] == inputs_count[original_name]) { if (grads_count[grad_name] == inputs_count[original_name]) {
auto gather_def = MakeOperatorDef( auto gather_op = MakeOperatorDef(
"GradientGather", "GradientGather",
GetOperatorName(), GetOperatorName(),
vector<string>({}), vector<string>({}),
vector<string>({grad_name})); vector<string>({grad_name}));
if (grad_def.has_device_option()) { if (grad_def.has_device_option()) {
gather_def.mutable_device_option()->CopyFrom( gather_op.mutable_device_option()->CopyFrom(
grad_def.device_option()); grad_def.device_option());
} }
for (int j = 0; j < grads_count[grad_name]; j++) { for (int j = 0; j < grads_count[grad_name]; j++) {
auto name = grad_name + "_autosplit_" + str::to(j); auto name = grad_name + "_autosplit_" + str::to(j);
if (split_grads.count(name)) gather_def.add_input(name); if (split_grads.count(name)) gather_op.add_input(name);
} }
gather_defs.push_back(gather_def); gather_ops.push_back(gather_op);
} }
*grad_def.mutable_output(i) = grad_name_v2; *grad_def.mutable_output(i) = grad_name_v2;
} }
} }
} }
// Add defs // Add gradient ops
if (!is_skip) { if (!is_skip) {
for (int i = 0; i < op_def.input_size(); ++i) { for (int i = 0; i < op.input_size(); ++i) {
inputs_to_grads_[op_def.input(i)] = pack.grad_inputs[i]; inputs_to_grads_[op.input(i)] = pack.grad_inputs[i];
} }
// Add def for ``GradientGenerateOp`` // Add ``GradientGenerateOp``
if (gen_grads.size() > 0) { if (gen_grads.size() > 0) {
vector<string> inputs, outputs; vector<string> inputs, outputs;
Argument arg_defaults; Argument arg_defaults;
...@@ -134,37 +134,36 @@ void GraphGradientMaker::Make( ...@@ -134,37 +134,36 @@ void GraphGradientMaker::Make(
outputs.emplace_back(gen_grad.first + "_grad"); outputs.emplace_back(gen_grad.first + "_grad");
arg_defaults.add_floats(pack.defaults[gen_grad.second]); arg_defaults.add_floats(pack.defaults[gen_grad.second]);
} }
auto generate_def = MakeOperatorDef( auto gen_op = MakeOperatorDef(
"GradientGenerate", "GradientGenerate",
GetOperatorName(), GetOperatorName(),
inputs, inputs,
outputs, outputs,
vector<Argument>({arg_defaults})); vector<Argument>({arg_defaults}));
if (op_def.has_device_option()) { if (op.has_device_option()) {
generate_def.mutable_device_option()->CopyFrom( gen_op.mutable_device_option()->CopyFrom(op.device_option());
op_def.device_option());
} }
graph_def.add_op()->CopyFrom(generate_def); graph.add_op()->CopyFrom(gen_op);
} }
// Add def for ``GenerateOp`` // Add ``GradientOp``
for (const auto& grad_def : pack.grad_defs) { for (const auto& grad_def : pack.grad_defs) {
graph_def.add_op()->CopyFrom(grad_def); graph.add_op()->CopyFrom(grad_def);
} }
} }
// Add def for ``GradientGatherOp`` // Add ``GradientGatherOp``
for (const auto& gather_def : gather_defs) { for (const auto& gather_op : gather_ops) {
graph_def.add_op()->CopyFrom(gather_def); graph.add_op()->CopyFrom(gather_op);
} }
} }
} }
GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { GraphDef GraphGradientMaker::Optimize(const GraphDef& graph) {
Set<int> invalid_ops; Set<int> invalid_ops;
Map<string, int> ref_count; Map<string, int> ref_count;
Map<string, pair<int, string>> gather_map; Map<string, pair<int, string>> gather_map;
for (int op_idx = 0; op_idx < input_def.op_size(); ++op_idx) { for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = input_def.op(op_idx); const auto& op = graph.op(op_idx);
if (!str::find(op.type(), "Gradient")) continue; if (!str::find(op.type(), "Gradient")) continue;
// Flag the gathering gradients // Flag the gathering gradients
if (op.type() == "GradientGather") { if (op.type() == "GradientGather") {
...@@ -195,17 +194,17 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -195,17 +194,17 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
// Decompose the <GradientGather> into <GradientAdd> // Decompose the <GradientGather> into <GradientAdd>
// This trick accumulates the split to target right after computing, // This trick accumulates the split to target right after computing,
// which helps to reduce the total number of buffers. // which helps to reduce the total number of buffers.
GraphDef output_def(input_def); auto graph_v2(graph);
output_def.clear_op(); graph_v2.clear_op();
for (int op_idx = 0; op_idx < input_def.op_size(); ++op_idx) { for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
if (invalid_ops.count(op_idx)) continue; if (invalid_ops.count(op_idx)) continue;
const auto& op = input_def.op(op_idx); const auto& op = graph.op(op_idx);
output_def.add_op()->CopyFrom(op); graph_v2.add_op()->CopyFrom(op);
if (!str::find(op.type(), "Gradient")) continue; if (!str::find(op.type(), "Gradient")) continue;
for (const auto& output : op.output()) { for (const auto& output : op.output()) {
const auto& find_iter = gather_map.find(output); const auto& find_iter = gather_map.find(output);
if (find_iter != gather_map.end()) { if (find_iter != gather_map.end()) {
const auto& gather_op = input_def.op(find_iter->second.first); const auto& gather_op = graph.op(find_iter->second.first);
auto add_op(gather_op); auto add_op(gather_op);
add_op.clear_input(); add_op.clear_input();
if (output != find_iter->second.second) { if (output != find_iter->second.second) {
...@@ -216,7 +215,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -216,7 +215,7 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
if (ref_iter != ref_count.end()) ref_iter->second++; if (ref_iter != ref_count.end()) ref_iter->second++;
} }
add_op.add_input(output); add_op.add_input(output);
output_def.add_op()->CopyFrom(add_op); graph_v2.add_op()->CopyFrom(add_op);
} }
} }
} }
...@@ -242,8 +241,8 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -242,8 +241,8 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
} }
}; };
for (int op_idx = 0; op_idx < output_def.op_size(); ++op_idx) { for (int op_idx = 0; op_idx < graph_v2.op_size(); ++op_idx) {
auto* op = output_def.mutable_op(op_idx); auto* op = graph_v2.mutable_op(op_idx);
// Ignore the non-gradient ops // Ignore the non-gradient ops
if (!str::find(op->type(), "Gradient")) continue; if (!str::find(op->type(), "Gradient")) continue;
// Check if output is an alias of input // Check if output is an alias of input
...@@ -262,45 +261,44 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) { ...@@ -262,45 +261,44 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
vector<string> dead_buffers; vector<string> dead_buffers;
// Rewrite input gradients // Rewrite input gradients
for (int i = 0; i < op->input_size(); ++i) { for (int i = 0; i < op->input_size(); ++i) {
const string& input = op->input(i); const string& in = op->input(i);
if (ref_count.count(input) > 0) { if (ref_count.count(in) > 0) {
ref_count[input] -= 1; // Decref ref_count[in] -= 1; // Decref
if (grad_to_buffer.count(input) == 0) continue; if (grad_to_buffer.count(in) == 0) continue;
string new_input = grad_to_buffer[input]; string in_v2 = grad_to_buffer[in];
if (ref_count[input] == 0) { if (ref_count[in] == 0) {
dead_buffers.emplace_back(new_input); dead_buffers.emplace_back(in_v2);
} }
*op->mutable_input(i) = new_input; *op->mutable_input(i) = in_v2;
} }
} }
// Rewrite output gradients // Rewrite output gradients
for (int i = 0; i < op->output_size(); ++i) { for (int i = 0; i < op->output_size(); ++i) {
if (str::startswith(op->type(), "Python")) continue; if (str::startswith(op->type(), "Python")) continue;
const string& output = op->output(i); const string& out = op->output(i);
if (output.empty() || str::startswith(output, "/share/buffer")) continue; if (out.empty() || str::startswith(out, "/share/buffer")) continue;
if (empty_grads_.count(output) > 0) { if (empty_grads_.count(out) > 0) {
*op->mutable_output(i) = ""; *op->mutable_output(i) = "";
continue; continue;
} }
// Protection for leafs // Protection for leafs
if (ref_count.count(output) == 0) continue; if (ref_count.count(out) == 0) continue;
// Protection for sources and leafs // Protection for sources and leafs
if (retained_grads_.count(output) > 0) continue; if (retained_grads_.count(out) > 0) continue;
string new_output = output; string out_v2 = out;
if (inplace_flags[i] >= 0) { if (inplace_flags[i] >= 0) {
new_output = op->input(inplace_flags[i]); out_v2 = op->input(inplace_flags[i]);
} else { } else {
grad_to_buffer[output] = new_output = get_buffer(); grad_to_buffer[out] = out_v2 = get_buffer();
} }
*op->mutable_output(i) = new_output; *op->mutable_output(i) = out_v2;
} }
// Update the pool // Update the pool
for (auto& buffer : dead_buffers) { for (auto& buffer : dead_buffers) {
pool.emplace_back(buffer); pool.emplace_back(buffer);
} }
} }
return graph_v2;
return output_def;
} }
} // namespace dragon } // namespace dragon
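
The "Decompose the <GradientGather> into <GradientAdd>" comment in the hunk above is the memory trick behind Optimize(); the toy standalone sketch below (not framework code) shows why accumulating each autosplit gradient right after it is computed keeps at most one split buffer alive, instead of holding all of them until a final gather.

// Standalone sketch: eager accumulation versus gathering at the end.
#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> splits = {0.5f, 1.5f, 2.0f};  // stand-ins for *_autosplit_i

  // GradientGather: every split stays alive until the final sum.
  int gather_peak = (int)splits.size();

  // GradientAdd decomposition: accumulate each split as soon as it exists.
  float gathered = 0.f;
  int live = 0, add_peak = 0;
  for (float g : splits) {
    live = 1;                      // only the freshly computed split is live
    add_peak = std::max(add_peak, live);
    gathered += g;                 // accumulate into the target right away
    live = 0;                      // the split buffer returns to the pool
  }
  std::cout << gathered << " gather_peak=" << gather_peak
            << " add_peak=" << add_peak << std::endl;  // 4 gather_peak=3 add_peak=1
}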
...@@ -19,15 +19,15 @@ namespace dragon { ...@@ -19,15 +19,15 @@ namespace dragon {
class DRAGON_API GraphGradientMaker { class DRAGON_API GraphGradientMaker {
public: public:
/*! \brief Generate graph def from the op defs */ /*! \brief Generate graph from the executed ops */
void Make( void Make(
const vector<OperatorDef*>& op_defs, const vector<OperatorDef*>& ops,
const vector<string>& targets, const vector<string>& targets,
const vector<string>& input_grads, const vector<string>& input_grads,
GraphDef& graph_def); GraphDef& graph);
/*! \brief Rewrite graph def to share the intermediate grads */ /*! \brief Eliminate the unused and make sharing of outputs */
GraphDef Share(const GraphDef& input_def); GraphDef Optimize(const GraphDef& graph);
/*! \brief Add an empty gradient */ /*! \brief Add an empty gradient */
void add_empty_grad(const string& name) { void add_empty_grad(const string& name) {
...@@ -47,14 +47,14 @@ class DRAGON_API GraphGradientMaker { ...@@ -47,14 +47,14 @@ class DRAGON_API GraphGradientMaker {
private: private:
/*! \brief Check the missing grads */ /*! \brief Check the missing grads */
bool CheckGrad( bool CheckGrad(
const OperatorDef& op_def, const OperatorDef& op,
const Set<string>& targets, const Set<string>& targets,
vector<pair<string, int>>& gen_grads); vector<pair<string, int>>& gen_grads);
/*! \brief Return a dummy operator name */ /*! \brief Return a dummy operator name */
string GetOperatorName() { string GetOperatorName() {
if (op_prefix_.empty()) return "GradientOp"; if (op_prefix_.empty()) return "GradientOp";
return op_prefix_ + str::to(op_index_++); return op_prefix_ + str::to(op_idx_++);
} }
/*! \brief The mapping from input to grad */ /*! \brief The mapping from input to grad */
...@@ -70,7 +70,7 @@ class DRAGON_API GraphGradientMaker { ...@@ -70,7 +70,7 @@ class DRAGON_API GraphGradientMaker {
string op_prefix_; string op_prefix_;
/*! \brief The counter of op name */ /*! \brief The counter of op name */
int64_t op_index_ = 0; int64_t op_idx_ = 0;
}; };
} // namespace dragon } // namespace dragon
......
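
For context on how the renamed Make/Optimize entry points are driven, the pybind11 hunk later in this commit builds the backward graph roughly as follows (excerpt-style sketch; `maker`, `op_defs`, `targets`, `input_grads` and `retain_grads` come from that hunk and are not defined here).

// Emit the gradient ops, then optionally rewrite them to share buffers.
GraphDef graph_def;
maker.Make(op_defs, targets, input_grads, graph_def);  // add gradient ops
if (!retain_grads) {
  // Share intermediate grads through buffers and drop the unused ones.
  graph_def = maker.Optimize(graph_def);
}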
...@@ -7,140 +7,141 @@ ...@@ -7,140 +7,141 @@
namespace dragon { namespace dragon {
void GraphOptimizer::BuildDAG(const GraphDef& input_def) { void GraphOptimizer::BuildDAG(const GraphDef& graph) {
dag_.clear(); nodes_.clear();
colored_.clear();
reference_count_.clear(); reference_count_.clear();
for (int i = 0; i < input_def.op_size(); ++i) { for (int i = 0; i < graph.op_size(); ++i) {
const auto& op = input_def.op(i); const auto& op = graph.op(i);
for (const auto& u : op.input()) { for (const auto& in : op.input()) {
reference_count_[u] += 1; reference_count_[in] += 1;
}
for (const auto& out : op.output()) {
if (op.input().empty()) {
nodes_[""].childs.push_back(out);
nodes_[out].parents.push_back("");
} else {
for (const auto& in : op.input()) {
nodes_[in].childs.push_back(out);
nodes_[out].parents.push_back(in);
} }
for (const auto& v : op.output()) {
vector<string> u_set(op.input().begin(), op.input().end());
if (u_set.empty()) u_set.resize(op.output_size());
for (const auto& u : u_set) {
dag_[v].parents.push_back(u);
dag_[u].childs.push_back(v);
dag_[v].op_idx = i;
} }
dag_[v].op_def = op; nodes_[out].op_idx = i;
nodes_[out].op_def = op;
} }
} }
} }
GraphDef GraphOptimizer::PruneNodes(const GraphDef& input_def) { GraphDef GraphOptimizer::EliminateUnused(const GraphDef& graph) {
// Initialization // Initialization
BuildDAG(input_def); BuildDAG(graph);
used_.clear();
// Backward pass from targets // Eliminate the unused nodes
for (const auto& target : input_def.output()) { for (const auto& out : graph.output()) {
if (colored_[target]) continue; EliminateUnusedNode(out);
BackwardPrunePass(target);
} }
for (const auto& grad_info : graph.grad_info()) {
for (const auto& grad_info : input_def.grad_info()) { const auto grad_y = grad_info.y() + "_grad";
const auto u = grad_info.y() + "_grad";
for (const auto& x : grad_info.xs()) { for (const auto& x : grad_info.xs()) {
visited_.clear(); visited_.clear();
ForwardPrunePass(u, x + "_grad", std::deque<string>({u})); EliminateUnusedNode(grad_y, x + "_grad");
} }
} }
// Select all colored operators // Select the used operators
set<int> selected_op_indices; set<int> selected_op_indices;
for (auto it : colored_) { for (auto it : used_) {
if (dag_[it.first].op_idx == -1) continue; if (nodes_[it.first].op_idx == -1) continue;
selected_op_indices.insert(dag_[it.first].op_idx); selected_op_indices.insert(nodes_[it.first].op_idx);
} }
// Remove the tensors that can not be produced // Prepare the registered placeholders
Set<string> outputs; Set<string> outputs;
for (const auto& name : ws_->tensors()) { for (const auto& name : ws_->tensors()) {
outputs.insert(name); outputs.insert(name);
} }
// Generate the final op sequence // Rewrite graph
map<int, OperatorDef> final_sequence; GraphDef graph_v2(graph);
graph_v2.clear_op();
for (auto op_idx : selected_op_indices) { for (auto op_idx : selected_op_indices) {
const auto& op = input_def.op(op_idx); const auto& op = graph.op(op_idx);
auto new_op(input_def.op(op_idx)); auto* op_v2 = graph_v2.add_op();
op_v2->CopyFrom(op);
// Rewrite inputs // Rewrite inputs
for (int i = 0; i < op.input_size(); ++i) { for (int i = 0; i < op.input_size(); ++i) {
const auto& input = op.input(i); const auto& in = op.input(i);
if (!colored_[input] || outputs.count(input) == 0) if (!used_[in] || outputs.count(in) == 0) {
*new_op.mutable_input(i) = ""; *op_v2->mutable_input(i) = "";
}
} }
// Rewrite outputs // Rewrite outputs
for (int i = 0; i < op.output_size(); ++i) { for (int i = 0; i < op.output_size(); ++i) {
const auto& output = op.output(i); const auto& out = op.output(i);
if (!colored_[output]) { if (!used_[out]) {
*new_op.mutable_output(i) = ""; *op_v2->mutable_output(i) = "";
} else { } else {
outputs.insert(output); outputs.insert(out);
} }
} }
// Rewrite hand-craft cases // Rewrite hand-craft cases
if (op.type() == "AffineGradient") { if (op.type() == "AffineGradient") {
if (new_op.output(1).empty()) *new_op.mutable_input(0) = ""; if (op_v2->output(1).empty()) *op_v2->mutable_input(0) = "";
} else if (op.type() == "MulGradient") { } else if (op.type() == "MulGradient") {
if (new_op.output(0).empty()) *new_op.mutable_input(1) = ""; if (op_v2->output(0).empty()) *op_v2->mutable_input(1) = "";
if (new_op.output(1).empty()) *new_op.mutable_input(0) = ""; if (op_v2->output(1).empty()) *op_v2->mutable_input(0) = "";
} else if (op.type() == "DivGradient") { } else if (op.type() == "DivGradient") {
if (new_op.output(1).empty()) { if (op_v2->output(1).empty()) {
*new_op.mutable_input(0) = ""; *op_v2->mutable_input(0) = "";
if (new_op.output(0).empty()) *new_op.mutable_input(1) = ""; if (op_v2->output(0).empty()) *op_v2->mutable_input(1) = "";
} }
} }
// Push into the final sequence
final_sequence[op_idx].CopyFrom(new_op);
} }
return graph_v2;
// Done!
GraphDef output_def(input_def);
output_def.clear_op();
for (auto it : final_sequence)
output_def.add_op()->CopyFrom(it.second);
return output_def;
} }
void GraphOptimizer::AddInplace( void GraphOptimizer::PlanInplace(
const GraphDef& input_def, const GraphDef& graph,
Map<string, Set<string>>& output_aliases) { Map<string, Set<string>>& output_aliases) {
// Initialization // Initialization
BuildDAG(input_def); BuildDAG(graph);
// Generate runtime aliases map // Generate aliases map to apply in-place
for (auto& u_iter : reference_count_) { for (const auto& iter : reference_count_) {
if (u_iter.second == 1 && !u_iter.first.empty() && const auto& in = iter.first;
dag_[u_iter.first].childs.size() > 0) { if (iter.second == 1 && !in.empty() && nodes_[in].childs.size() > 0) {
const auto& u = u_iter.first; const auto& op = nodes_[nodes_[in].childs[0]].op_def;
const auto& v0 = dag_[u].childs[0]; const auto* schema = OpSchemaRegistry::Schema(op.type());
const auto& op_def = dag_[v0].op_def; for (int i = 0; i < op.input_size(); ++i) {
const auto* op_schema = OpSchemaRegistry::Schema(op_def.type()); if (op.input(i) == in) {
for (int i = 0; i < op_def.input_size(); ++i) for (int j = 0; j < op.output_size(); ++j) {
for (int j = 0; j < op_def.output_size(); ++j) if (schema->CheckInplace(i, j)) {
if (op_schema->CheckInplace != nullptr && op_def.input(i) == u && output_aliases[op.output(j)].insert(in);
op_schema->CheckInplace(i, j)) }
output_aliases[op_def.output(j)].insert(u); }
}
}
} }
} }
} }
GraphDef GraphOptimizer::MirrorStage( GraphDef GraphOptimizer::PlanCheckpoint(
const GraphDef& input_def, const GraphDef& graph,
Map<string, vec32_t>& op_indices) { Map<string, vec32_t>& subgraph_indices) {
GraphDef output_def(input_def); GraphDef graph_v2(graph);
Map<string, set<int>> fake_op_indices; Map<string, set<int>> op_indices;
Map<string, string> rename_map; Map<string, string> rename_map;
Map<string, int> versions; Map<string, int> versions;
// Check mirror stage // Check the mirror stage setting
for (const auto& op : input_def.op()) { for (const auto& op : graph.op()) {
if (str::find(op.type(), "Gradient")) continue; if (str::find(op.type(), "Gradient")) continue;
bool mirror_stage = false; bool mirror_stage = false;
for (auto& arg : op.arg()) for (auto& arg : op.arg()) {
if (arg.name() == "mirror_stage") mirror_stage |= (bool)arg.i(); if (arg.name() == "mirror_stage") {
mirror_stage |= (bool)arg.i();
}
}
if (mirror_stage) { if (mirror_stage) {
// We only assume X(0) can be recomputed // We only assume X(0) can be recomputed
rename_map[op.input(0)] = "placeholder"; rename_map[op.input(0)] = "placeholder";
...@@ -149,24 +150,25 @@ GraphDef GraphOptimizer::MirrorStage( ...@@ -149,24 +150,25 @@ GraphDef GraphOptimizer::MirrorStage(
// Allocate the temporal buffers // Allocate the temporal buffers
string v2_name, version_name; string v2_name, version_name;
for (int op_idx = 0; op_idx < input_def.op_size(); ++op_idx) { for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = input_def.op(op_idx); const auto& op = graph.op(op_idx);
auto* new_op = output_def.mutable_op(op_idx); auto* op_v2 = graph_v2.mutable_op(op_idx);
vector<string> used_buffers; vector<string> used_buffers;
for (int i = 0; i < op.input_size(); ++i) { for (int i = 0; i < op.input_size(); ++i) {
const auto& it = rename_map.find(op.input(i)); const auto& it = rename_map.find(op.input(i));
if (it != rename_map.end() && it->second != "placeholder") { if (it != rename_map.end() && it->second != "placeholder") {
*new_op->mutable_input(i) = it->second; *op_v2->mutable_input(i) = it->second;
used_buffers.emplace_back(it->second); used_buffers.emplace_back(it->second);
} }
} }
for (int i = 0; i < op.output_size(); ++i) { for (int i = 0; i < op.output_size(); ++i) {
bool inplace_flag = false; bool inplace_flag = false;
for (const auto& u : op.input()) for (const auto& in : op.input()) {
if (u == op.output(i)) inplace_flag = true; if (in == op.output(i)) inplace_flag = true;
}
if (rename_map.count(op.output(i))) { if (rename_map.count(op.output(i))) {
if (inplace_flag && rename_map[op.output(i)] != "placeholder") { if (inplace_flag && rename_map[op.output(i)] != "placeholder") {
*new_op->mutable_output(i) = rename_map[op.output(i)]; *op_v2->mutable_output(i) = rename_map[op.output(i)];
continue; continue;
} }
for (int j = 0; j < GRAPH_TEMPORAL_OUTPUT_MAX_SIZE; ++j) { for (int j = 0; j < GRAPH_TEMPORAL_OUTPUT_MAX_SIZE; ++j) {
...@@ -183,45 +185,42 @@ GraphDef GraphOptimizer::MirrorStage( ...@@ -183,45 +185,42 @@ GraphDef GraphOptimizer::MirrorStage(
CHECK(!v2_name.empty()) << "\nNo enough buffers for outputs."; CHECK(!v2_name.empty()) << "\nNo enough buffers for outputs.";
ws_->CreateTensor(v2_name)->set_version(0); ws_->CreateTensor(v2_name)->set_version(0);
version_name = "/ver:" + str::to(versions[v2_name]++); version_name = "/ver:" + str::to(versions[v2_name]++);
*new_op->mutable_output(i) = rename_map[op.output(i)] = *op_v2->mutable_output(i) = rename_map[op.output(i)] =
v2_name + version_name; v2_name + version_name;
} }
} }
} }
// Plan the minimum recomputing ops for temporal buffers // Determine the recomputing ops for temporal buffers
for (int i = 0; i < input_def.op_size(); ++i) { for (int i = 0; i < graph.op_size(); ++i) {
const auto& input_op = input_def.op(i); const auto &op = graph.op(i), &op_v2 = graph_v2.op(i);
const auto& output_op = output_def.op(i); set<int> recomputing_ops = {i};
for (int j = 0; j < op.input_size(); ++j) {
/* if (op.input(j) != op_v2.input(j)) {
* DP(v) = {DP(u) if input(u) != output(u) else {}} + {i} for (auto op_idx : op_indices[op.input(j)]) {
*/ recomputing_ops.insert(op_idx);
}
set<int> minimum_ops = {i};
for (int j = 0; j < input_op.input_size(); ++j) {
if (input_op.input(j) != output_op.input(j)) {
for (auto idx : fake_op_indices[input_op.input(j)])
minimum_ops.insert(idx);
} }
} }
for (const auto& output : input_op.output()) { for (const auto& out : op.output()) {
for (auto idx : minimum_ops) for (auto op_idx : recomputing_ops) {
fake_op_indices[output].insert(idx); op_indices[out].insert(op_idx);
}
} }
} }
// Bind to the renamed tensors // Bind to the renamed tensors
for (const auto& it : rename_map) { for (const auto& it : rename_map) {
for (auto op_idx : fake_op_indices[it.first]) for (auto op_idx : op_indices[it.first]) {
op_indices[it.second].push_back(op_idx); subgraph_indices[it.second].push_back(op_idx);
}
} }
// Done! // Done
return output_def; return graph_v2;
} }
GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) { GraphDef GraphOptimizer::SimulateGC(const GraphDef& graph) {
Set<string> blacklist = {""}; Set<string> blacklist = {""};
Map<string, int> ref_count; Map<string, int> ref_count;
Map<string, string> rename_map; Map<string, string> rename_map;
...@@ -241,42 +240,39 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) { ...@@ -241,42 +240,39 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) {
}; };
// Count the references // Count the references
for (const auto& op : input_def.op()) { for (const auto& op : graph.op()) {
for (const auto& input : op.input()) for (const auto& in : op.input()) {
ref_count[input] += 1; ref_count[in] += 1;
} }
// We should preserve the targets
for (auto& e : input_def.output()) {
blacklist.insert(e);
} }
// Rewritten the inputs and outputs // Preserve the graph outputs
auto output_def(input_def); for (auto& out : graph.output()) {
for (int op_idx = 0; op_idx < input_def.op_size(); ++op_idx) { blacklist.insert(out);
const auto& op = input_def.op(op_idx); }
auto* new_op = output_def.mutable_op(op_idx);
// Rewrite the inputs and outputs
auto graph_v2(graph);
for (int op_idx = 0; op_idx < graph.op_size(); ++op_idx) {
const auto& op = graph.op(op_idx);
auto* op_v2 = graph_v2.mutable_op(op_idx);
// Ignore the init ops // Ignore the init ops
if (op.input_size() == 0) continue; if (op.input_size() == 0) continue;
// We need to collect the dead buffers.
// We need to collect the dead buffers // Reuse them when current operator is done.
// Reuse them when current operator is done
vector<string> dead_buffers; vector<string> dead_buffers;
// Rewrite inputs // Rewrite inputs
for (int i = 0; i < op.input_size(); ++i) { for (int i = 0; i < op.input_size(); ++i) {
const auto& name = op.input(i); const auto& name = op.input(i);
if (rename_map.count(name)) { if (rename_map.count(name)) {
*new_op->mutable_input(i) = rename_map[name]; *op_v2->mutable_input(i) = rename_map[name];
} }
ref_count[name]--; ref_count[name]--;
if (ref_count[name] == 0 && if (ref_count[name] == 0 &&
str::startswith(new_op->input(i), "/share/buffer/output:")) { str::startswith(op_v2->input(i), "/share/buffer/output:")) {
dead_buffers.push_back(new_op->input(i)); dead_buffers.push_back(op_v2->input(i));
} }
} }
// Rewrite outputs // Rewrite outputs
if (!star_ops.count(op.type())) { if (!star_ops.count(op.type())) {
for (int i = 0; i < op.output_size(); ++i) { for (int i = 0; i < op.output_size(); ++i) {
...@@ -286,55 +282,49 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) { ...@@ -286,55 +282,49 @@ GraphDef GraphOptimizer::SimulateGC(const GraphDef& input_def) {
for (const auto& input : op.input()) for (const auto& input : op.input())
if (name == input) inplace_flag = true; if (name == input) inplace_flag = true;
if (inplace_flag) { if (inplace_flag) {
*new_op->mutable_output(i) = new_op->input(i); *op_v2->mutable_output(i) = op_v2->input(i);
} else { } else {
rename_map[name] = *new_op->mutable_output(i) = get_buffer(); rename_map[name] = *op_v2->mutable_output(i) = get_buffer();
} }
} }
} }
// Update the pool // Update the pool
for (auto& buffer : dead_buffers) { for (auto& buffer : dead_buffers) {
pool.emplace_back(buffer); pool.emplace_back(buffer);
} }
} }
return graph_v2;
return output_def;
} }
void GraphOptimizer::ForwardPrunePass( void GraphOptimizer::EliminateUnusedNode(
const string& u, const string& source,
const string& leaf, const string& sink) {
const std::deque<string>& path) { if (visited_.count(source)) return;
if (visited_.count(u)) { visited_[source] = false;
if (visited_[u]) { for (const auto& next : nodes_[source].childs) {
for (const auto& node : path) { if (next == sink) {
visited_[node] = colored_[node] = true; visited_[next] = used_[next] = true;
} visited_[source] = used_[source] = true;
}
return; return;
} }
visited_[u] = false; EliminateUnusedNode(next, sink);
for (int i = 0; i < dag_[u].childs.size(); ++i) { if (visited_[next]) {
auto v = dag_[u].childs[i]; visited_[source] = used_[source] = true;
auto new_path(path);
new_path.push_back(v);
if (v == leaf) {
for (const auto& node : new_path) {
visited_[node] = colored_[node] = true;
} }
return;
}
ForwardPrunePass(v, leaf, new_path);
} }
} }
void GraphOptimizer::BackwardPrunePass(const string& v) { void GraphOptimizer::EliminateUnusedNode(const string& sink) {
colored_[v] = true; std::queue<const string*> q;
for (int i = 0; i < dag_[v].parents.size(); ++i) { q.push(&sink);
auto u = dag_[v].parents[i]; while (!q.empty()) {
if (colored_.count(u)) continue; const auto& source = *q.front();
BackwardPrunePass(u); q.pop();
used_[source] = true;
for (const auto& last : nodes_[source].parents) {
if (used_.count(last)) continue;
q.push(&last);
}
} }
} }
......
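
The recompute planning in PlanCheckpoint above replaces the old DP comment with explicit code; as a standalone toy sketch (plain structs instead of GraphDef), an op whose input was rewritten to a temporal buffer pulls in the recomputing sets of the ops that produced that input, plus itself.

// Standalone sketch of the checkpoint recompute planning.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct Op {
  std::vector<std::string> inputs, outputs;
  bool input_rewritten;  // an input was renamed to a temporal buffer
};

int main() {
  std::vector<Op> ops = {
      {{"x"}, {"a"}, false},  // op 0: inputs untouched
      {{"a"}, {"b"}, true},   // op 1: reads "a" from a shared buffer
      {{"b"}, {"y"}, true},   // op 2: reads "b" from a shared buffer
  };
  std::map<std::string, std::set<int>> op_indices;  // tensor -> recomputing ops
  for (int i = 0; i < (int)ops.size(); ++i) {
    std::set<int> recomputing_ops = {i};
    if (ops[i].input_rewritten) {
      for (const auto& in : ops[i].inputs) {
        recomputing_ops.insert(op_indices[in].begin(), op_indices[in].end());
      }
    }
    for (const auto& out : ops[i].outputs) {
      op_indices[out].insert(recomputing_ops.begin(), recomputing_ops.end());
    }
  }
  for (int idx : op_indices["y"]) std::cout << idx << " ";  // prints: 0 1 2
  std::cout << std::endl;
}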
...@@ -32,45 +32,42 @@ class GraphOptimizer { ...@@ -32,45 +32,42 @@ class GraphOptimizer {
/*! \brief Default constructor */ /*! \brief Default constructor */
GraphOptimizer(Workspace* ws) : ws_(ws) {} GraphOptimizer(Workspace* ws) : ws_(ws) {}
/*! \brief Build the DAG resources for given def */ /*! \brief Build the DAG */
void BuildDAG(const GraphDef& input_def); void BuildDAG(const GraphDef& graph);
/*! \brief Prune the redundant nodes (-O1) */ /*! \brief Eliminate the unused outputs and operators */
GraphDef PruneNodes(const GraphDef& input_def); GraphDef EliminateUnused(const GraphDef& graph);
/*! \brief Add the inplace for outputs (-O2) */ /*! \brief Plan the inplace for inputs */
void AddInplace( void PlanInplace(
const GraphDef& input_def, const GraphDef& graph,
Map<string, Set<string>>& output_aliases); Map<string, Set<string>>& output_aliases);
/*! \brief Plan the recomputing for inputs (-O3) */ /*! \brief Plan the checkpoint for inputs */
GraphDef MirrorStage( GraphDef PlanCheckpoint(
const GraphDef& input_def, const GraphDef& graph,
Map<string, vec32_t>& op_indices); Map<string, vec32_t>& subgraph_indices);
/*! \brief Allocate the buffer for outputs (-O3) */ /*! \brief Allocate the shared buffer for outputs */
GraphDef SimulateGC(const GraphDef& input_def); GraphDef SimulateGC(const GraphDef& graph);
protected: protected:
/*! \brief Pass from gradients to remove unused nodes */ /*! \brief Remove the unused nodes from a sink to all sources */
void ForwardPrunePass( void EliminateUnusedNode(const string& sink);
const string& u,
const string& leaf,
const std::deque<string>& path);
/*! \brief Pass from targets to remove unused nodes */ /*! \brief Remove the unused nodes from a source to a sink */
void BackwardPrunePass(const string& v); void EliminateUnusedNode(const string& source, const string& sink);
/* \brief Store the workspace of parent graph */ /* \brief The graph workspace */
Workspace* ws_; Workspace* ws_;
/* \brief Store the DAG */ /* \brief The graph nodes */
Map<string, Node> dag_; Map<string, Node> nodes_;
/* \brief Store the traversal flags */ /* \brief The traversal flags */
Map<string, bool> visited_, colored_; Map<string, bool> visited_, used_;
/* \brief Store the count of references */ /* \brief The reference count */
Map<string, int> reference_count_; Map<string, int> reference_count_;
private: private:
......
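
The single-argument EliminateUnusedNode declared above performs the breadth-first marking shown in graph_optimizer.cc; here is a standalone toy version (plain strings instead of GraphDef tensors) of walking parents from a graph output and marking everything reachable as used, so that operators whose outputs are never marked can be dropped.

// Standalone sketch of the used-node marking pass.
#include <iostream>
#include <map>
#include <queue>
#include <string>
#include <vector>

struct Node {
  std::vector<std::string> parents, childs;
};

int main() {
  std::map<std::string, Node> nodes;
  std::map<std::string, bool> used;
  // x -> y -> loss, plus a dangling branch x -> aux that nothing consumes.
  nodes["y"].parents = {"x"};
  nodes["loss"].parents = {"y"};
  nodes["aux"].parents = {"x"};

  std::queue<std::string> q;
  q.push("loss");  // the sink registered as a graph output
  while (!q.empty()) {
    const auto source = q.front();
    q.pop();
    used[source] = true;
    for (const auto& last : nodes[source].parents) {
      if (used.count(last)) continue;
      q.push(last);
    }
  }
  std::cout << used.count("aux") << std::endl;  // 0: "aux" would be eliminated
}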
...@@ -173,10 +173,7 @@ TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) { ...@@ -173,10 +173,7 @@ TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) {
OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws) { OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws) {
auto* schema = OpSchemaRegistry::Schema(def.type()); auto* schema = OpSchemaRegistry::Schema(def.type());
if (schema != nullptr) { if (schema != nullptr) CHECK(schema->Verify(def));
CHECK(schema->Verify(def))
<< "\nOperator failed to pass the schema checking.";
}
OperatorDef mutable_def(def); OperatorDef mutable_def(def);
// Heuristically make each random seed slightly different // Heuristically make each random seed slightly different
static unsigned int seed_offset = 0; static unsigned int seed_offset = 0;
......
...@@ -14,7 +14,6 @@ bool OpSchema::Verify(const OperatorDef& def) const { ...@@ -14,7 +14,6 @@ bool OpSchema::Verify(const OperatorDef& def) const {
<< " is not in range [min=" << min_output_ << " is not in range [min=" << min_output_
<< ", max=" << max_output_ << "]"; << ", max=" << max_output_ << "]";
} }
if (CheckInplace != nullptr) {
for (int i = 0; i < def.input_size(); ++i) { for (int i = 0; i < def.input_size(); ++i) {
if (def.input(i).empty()) continue; if (def.input(i).empty()) continue;
for (int j = 0; j < def.output_size(); ++j) { for (int j = 0; j < def.output_size(); ++j) {
...@@ -25,7 +24,6 @@ bool OpSchema::Verify(const OperatorDef& def) const { ...@@ -25,7 +24,6 @@ bool OpSchema::Verify(const OperatorDef& def) const {
} }
} }
} }
}
return true; return true;
} }
...@@ -49,7 +47,12 @@ OpSchema& OpSchema::NumOutputs(int min_num, int max_num) { ...@@ -49,7 +47,12 @@ OpSchema& OpSchema::NumOutputs(int min_num, int max_num) {
return *this; return *this;
} }
OpSchema& OpSchema::Inplace(set<pair<int, int>> inplace) { OpSchema& OpSchema::AllowInplace(std::function<bool(int, int)> inplace) {
CheckInplace = inplace;
return *this;
}
OpSchema& OpSchema::AllowInplace(set<pair<int, int>> inplace) {
CheckInplace = [inplace](int in, int out) -> bool { CheckInplace = [inplace](int in, int out) -> bool {
return (inplace.count(std::make_pair(in, out)) > 0); return (inplace.count(std::make_pair(in, out)) > 0);
}; };
......
...@@ -20,6 +20,9 @@ ...@@ -20,6 +20,9 @@
namespace dragon { namespace dragon {
/*!
* \brief Class to record the schema of an operator.
*/
class DRAGON_API OpSchema { class DRAGON_API OpSchema {
public: public:
/*! \brief Default constructor */ /*! \brief Default constructor */
...@@ -27,15 +30,12 @@ class DRAGON_API OpSchema { ...@@ -27,15 +30,12 @@ class DRAGON_API OpSchema {
Init(); Init();
} }
/*! \brief Constructor with defined spec */ /*! \brief Constructor with the defined spec */
OpSchema(const string& op_type, const string& file, const int line) OpSchema(const string& op_type, const string& file, const int line)
: op_type_(op_type), file_(file), line_(line) { : op_type_(op_type), file_(file), line_(line) {
Init(); Init();
} }
/*! \brief Check if the in-place setting is matched */
std::function<bool(int, int)> CheckInplace = nullptr;
/*! \brief Set a fixed number of inputs */ /*! \brief Set a fixed number of inputs */
OpSchema& NumInputs(int n); OpSchema& NumInputs(int n);
...@@ -48,12 +48,18 @@ class DRAGON_API OpSchema { ...@@ -48,12 +48,18 @@ class DRAGON_API OpSchema {
/*! \brief Set the min and max number of outputs */ /*! \brief Set the min and max number of outputs */
OpSchema& NumOutputs(int min_num, int max_num); OpSchema& NumOutputs(int min_num, int max_num);
/*! \brief Set the in-place setting */ /*! \brief Set the rule to allow inplace with a group of indices */
OpSchema& Inplace(set<pair<int, int>> inplace); OpSchema& AllowInplace(set<pair<int, int>> inplace);
/*! \brief Verify if the def matches the schema */ /*! \brief Set the rule to allow inplace with a function */
OpSchema& AllowInplace(std::function<bool(int, int)> inplace);
/*! \brief Check if the given def matches this schema */
bool Verify(const OperatorDef& def) const; bool Verify(const OperatorDef& def) const;
/*! \brief Check if the inplace is allowed */
std::function<bool(int, int)> CheckInplace = [](int, int) { return false; };
private: private:
/*! \brief Initialize the default settings */ /*! \brief Initialize the default settings */
void Init() { void Init() {
......
...@@ -242,7 +242,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -242,7 +242,7 @@ PYBIND11_MODULE(libdragon_python, m) {
maker.Make(op_defs, targets, input_grads, graph_def); maker.Make(op_defs, targets, input_grads, graph_def);
py::gil_scoped_release g; py::gil_scoped_release g;
if (!retain_grads) { if (!retain_grads) {
graph_def = maker.Share(graph_def); graph_def = maker.Optimize(graph_def);
} }
for (const auto& op_def : graph_def.op()) { for (const auto& op_def : graph_def.op()) {
if (verbose) { if (verbose) {
......
...@@ -129,7 +129,7 @@ OPERATOR_SCHEMA(DropBlock2d) ...@@ -129,7 +129,7 @@ OPERATOR_SCHEMA(DropBlock2d)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
DEPLOY_CPU(DropBlock2dGradient); DEPLOY_CPU(DropBlock2dGradient);
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -142,7 +142,7 @@ OPERATOR_SCHEMA(DropBlock2dGradient) ...@@ -142,7 +142,7 @@ OPERATOR_SCHEMA(DropBlock2dGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(DropBlock2d, SimpleGradientMaker); REGISTER_GRADIENT(DropBlock2d, SimpleGradientMaker);
......
...@@ -95,7 +95,7 @@ OPERATOR_SCHEMA(DropPath) ...@@ -95,7 +95,7 @@ OPERATOR_SCHEMA(DropPath)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(DropPathGradient) OPERATOR_SCHEMA(DropPathGradient)
/* dY */ /* dY */
...@@ -103,7 +103,7 @@ OPERATOR_SCHEMA(DropPathGradient) ...@@ -103,7 +103,7 @@ OPERATOR_SCHEMA(DropPathGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(DropPath, SimpleGradientMaker); REGISTER_GRADIENT(DropPath, SimpleGradientMaker);
......
...@@ -84,7 +84,7 @@ OPERATOR_SCHEMA(Dropout) ...@@ -84,7 +84,7 @@ OPERATOR_SCHEMA(Dropout)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(DropoutGradient) OPERATOR_SCHEMA(DropoutGradient)
/* dY */ /* dY */
...@@ -92,7 +92,7 @@ OPERATOR_SCHEMA(DropoutGradient) ...@@ -92,7 +92,7 @@ OPERATOR_SCHEMA(DropoutGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Dropout, SimpleGradientMaker); REGISTER_GRADIENT(Dropout, SimpleGradientMaker);
......
...@@ -54,7 +54,7 @@ OPERATOR_SCHEMA(Elu) ...@@ -54,7 +54,7 @@ OPERATOR_SCHEMA(Elu)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(EluGradient) OPERATOR_SCHEMA(EluGradient)
/* Y, dY */ /* Y, dY */
...@@ -62,7 +62,7 @@ OPERATOR_SCHEMA(EluGradient) ...@@ -62,7 +62,7 @@ OPERATOR_SCHEMA(EluGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Elu, InplaceGradientMaker); REGISTER_GRADIENT(Elu, InplaceGradientMaker);
......
...@@ -73,7 +73,7 @@ OPERATOR_SCHEMA(Relu) ...@@ -73,7 +73,7 @@ OPERATOR_SCHEMA(Relu)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ReluGradient) OPERATOR_SCHEMA(ReluGradient)
/* Y, dY */ /* Y, dY */
...@@ -81,7 +81,7 @@ OPERATOR_SCHEMA(ReluGradient) ...@@ -81,7 +81,7 @@ OPERATOR_SCHEMA(ReluGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Relu, InplaceGradientMaker); REGISTER_GRADIENT(Relu, InplaceGradientMaker);
......
...@@ -56,7 +56,7 @@ OPERATOR_SCHEMA(Selu) ...@@ -56,7 +56,7 @@ OPERATOR_SCHEMA(Selu)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SeluGradient) OPERATOR_SCHEMA(SeluGradient)
/* Y, dY */ /* Y, dY */
...@@ -64,7 +64,7 @@ OPERATOR_SCHEMA(SeluGradient) ...@@ -64,7 +64,7 @@ OPERATOR_SCHEMA(SeluGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Selu, InplaceGradientMaker); REGISTER_GRADIENT(Selu, InplaceGradientMaker);
......
...@@ -52,7 +52,7 @@ OPERATOR_SCHEMA(Sigmoid) ...@@ -52,7 +52,7 @@ OPERATOR_SCHEMA(Sigmoid)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SigmoidGradient) OPERATOR_SCHEMA(SigmoidGradient)
/* Y, dY */ /* Y, dY */
...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(SigmoidGradient) ...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(SigmoidGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Sigmoid, InplaceGradientMaker); REGISTER_GRADIENT(Sigmoid, InplaceGradientMaker);
......
...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(Softmax) ...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(Softmax)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SoftmaxGradient) OPERATOR_SCHEMA(SoftmaxGradient)
/* Y, dY */ /* Y, dY */
...@@ -68,7 +68,7 @@ OPERATOR_SCHEMA(SoftmaxGradient) ...@@ -68,7 +68,7 @@ OPERATOR_SCHEMA(SoftmaxGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Softmax, InplaceGradientMaker); REGISTER_GRADIENT(Softmax, InplaceGradientMaker);
......
...@@ -52,7 +52,7 @@ OPERATOR_SCHEMA(Tanh) ...@@ -52,7 +52,7 @@ OPERATOR_SCHEMA(Tanh)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(TanhGradient) OPERATOR_SCHEMA(TanhGradient)
/* Y, dY */ /* Y, dY */
...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(TanhGradient) ...@@ -60,7 +60,7 @@ OPERATOR_SCHEMA(TanhGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Tanh, InplaceGradientMaker); REGISTER_GRADIENT(Tanh, InplaceGradientMaker);
......
...@@ -45,7 +45,7 @@ OPERATOR_SCHEMA(ExpandDims) ...@@ -45,7 +45,7 @@ OPERATOR_SCHEMA(ExpandDims)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ExpandDimsGradient) OPERATOR_SCHEMA(ExpandDimsGradient)
/* dY */ /* dY */
...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(ExpandDimsGradient) ...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(ExpandDimsGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(ExpandDims, SimpleGradientMaker); REGISTER_GRADIENT(ExpandDims, SimpleGradientMaker);
......
...@@ -56,7 +56,7 @@ OPERATOR_SCHEMA(Flatten) ...@@ -56,7 +56,7 @@ OPERATOR_SCHEMA(Flatten)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(FlattenGradient) OPERATOR_SCHEMA(FlattenGradient)
/* dY */ /* dY */
...@@ -64,7 +64,7 @@ OPERATOR_SCHEMA(FlattenGradient) ...@@ -64,7 +64,7 @@ OPERATOR_SCHEMA(FlattenGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Flatten, SimpleGradientMaker); REGISTER_GRADIENT(Flatten, SimpleGradientMaker);
......
...@@ -69,7 +69,7 @@ OPERATOR_SCHEMA(Reshape) ...@@ -69,7 +69,7 @@ OPERATOR_SCHEMA(Reshape)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ReshapeGradient) OPERATOR_SCHEMA(ReshapeGradient)
/* dY */ /* dY */
...@@ -77,7 +77,7 @@ OPERATOR_SCHEMA(ReshapeGradient) ...@@ -77,7 +77,7 @@ OPERATOR_SCHEMA(ReshapeGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Reshape, SimpleGradientMaker); REGISTER_GRADIENT(Reshape, SimpleGradientMaker);
......
...@@ -45,7 +45,7 @@ OPERATOR_SCHEMA(Squeeze) ...@@ -45,7 +45,7 @@ OPERATOR_SCHEMA(Squeeze)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(SqueezeGradient) OPERATOR_SCHEMA(SqueezeGradient)
/* dY */ /* dY */
...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(SqueezeGradient) ...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(SqueezeGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Squeeze, SimpleGradientMaker); REGISTER_GRADIENT(Squeeze, SimpleGradientMaker);
......
...@@ -193,7 +193,7 @@ DEPLOY_CPU(Collective); ...@@ -193,7 +193,7 @@ DEPLOY_CPU(Collective);
DEPLOY_CUDA(Collective); DEPLOY_CUDA(Collective);
#endif #endif
OPERATOR_SCHEMA(Collective); OPERATOR_SCHEMA(Collective).AllowInplace([](int, int) -> bool { return true; });
} // namespace dragon } // namespace dragon
......
...@@ -122,7 +122,7 @@ OPERATOR_SCHEMA(GradientAdd) ...@@ -122,7 +122,7 @@ OPERATOR_SCHEMA(GradientAdd)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X1 => Y */ /* X1 => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(StopGradient) OPERATOR_SCHEMA(StopGradient)
/* X */ /* X */
...@@ -130,7 +130,7 @@ OPERATOR_SCHEMA(StopGradient) ...@@ -130,7 +130,7 @@ OPERATOR_SCHEMA(StopGradient)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
NO_GRADIENT(StopGradient); NO_GRADIENT(StopGradient);
......
...@@ -107,16 +107,16 @@ OPERATOR_SCHEMA(Add) ...@@ -107,16 +107,16 @@ OPERATOR_SCHEMA(Add)
.NumInputs(2) .NumInputs(2)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* A => Y */ /* A => Y, B => Y */
.Inplace({{0, 0}, {1, 0}}); .AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(AddGradient) OPERATOR_SCHEMA(AddGradient)
/* dY */ /* dY */
.NumInputs(1) .NumInputs(1)
/* dA, dB */ /* dA, dB */
.NumOutputs(2) .NumOutputs(2)
/* dY => dA */ /* dY => dA, dY => dB */
.Inplace({{0, 0}, {0, 1}}); .AllowInplace({{0, 0}, {0, 1}});
REGISTER_GRADIENT(Add, SimpleGradientMaker); REGISTER_GRADIENT(Add, SimpleGradientMaker);
......
...@@ -151,7 +151,7 @@ OPERATOR_SCHEMA(Affine) ...@@ -151,7 +151,7 @@ OPERATOR_SCHEMA(Affine)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(AffineGradient) OPERATOR_SCHEMA(AffineGradient)
/* X, W, dY */ /* X, W, dY */
...@@ -159,7 +159,7 @@ OPERATOR_SCHEMA(AffineGradient) ...@@ -159,7 +159,7 @@ OPERATOR_SCHEMA(AffineGradient)
/* dX, dW, dB */ /* dX, dW, dB */
.NumOutputs(3) .NumOutputs(3)
/* dY => dX */ /* dY => dX */
.Inplace({{2, 0}}); .AllowInplace({{2, 0}});
namespace { namespace {
......
...@@ -6,49 +6,28 @@ namespace dragon { ...@@ -6,49 +6,28 @@ namespace dragon {
template <class Context> template <class Context>
template <typename T> template <typename T>
void AxpbyOp<Context>::DoRunWithType(Tensor* X, Tensor* Y) { void AxpbyOp<Context>::DoRunWithType() {
CHECK_EQ(X->count(), Y->count()); auto &X = Input(0), *Y = Output(0);
auto* x = X->template data<T, Context>(); auto* x = X.template data<T, Context>();
auto* y = Y->template mutable_data<T, Context>(); auto* y = Y->ReshapeLike(X)->template mutable_data<T, Context>();
if (beta_ == 1.f) { if (beta_ == 1.f) {
if (alpha_ == 1.f) { if (alpha_ == 1.f) {
math::Add(X->count(), x, y, y, ctx()); math::Add(X.count(), x, y, y, ctx());
} else { } else {
math::Axpy(X->count(), alpha_, x, y, ctx()); math::Axpy(X.count(), alpha_, x, y, ctx());
} }
} else { } else {
if (alpha_ == 0.f) { if (alpha_ == 0.f) {
math::Scale(X->count(), beta_, y, y, ctx()); math::Scale(X.count(), beta_, y, y, ctx());
} else { } else {
math::Axpby(X->count(), alpha_, x, beta_, y, ctx()); math::Axpby(X.count(), alpha_, x, beta_, y, ctx());
} }
} }
} }
template <class Context> template <class Context>
void AxpbyOp<Context>::RunOnDevice() { void AxpbyOp<Context>::RunOnDevice() {
for (int i = 0; i < InputSize(); i++) { DispatchHelper<MathTensorTypes>::Call(this, Input(0));
auto &X = Input(i), *Y = Output(i);
Y->ReshapeLike(X);
if (XIsType(X, int8_t)) {
DoRunWithType<int8_t>(&X, Y);
} else if (XIsType(X, uint8_t)) {
DoRunWithType<uint8_t>(&X, Y);
} else if (XIsType(X, int)) {
DoRunWithType<int>(&X, Y);
} else if (XIsType(X, int64_t)) {
DoRunWithType<int64_t>(&X, Y);
} else if (XIsType(X, float16)) {
DoRunWithType<float16>(&X, Y);
} else if (XIsType(X, float)) {
DoRunWithType<float>(&X, Y);
} else if (XIsType(X, double)) {
DoRunWithType<double>(&X, Y);
} else
LOG(FATAL) << MessageForUnsupported(
types::to_string(X.meta()),
{"int8", "uint8", "int32", "int64", "float16", "float32", "float64"});
}
} }
DEPLOY_CPU(Axpby); DEPLOY_CPU(Axpby);
...@@ -57,10 +36,12 @@ DEPLOY_CUDA(Axpby); ...@@ -57,10 +36,12 @@ DEPLOY_CUDA(Axpby);
#endif #endif
OPERATOR_SCHEMA(Axpby) OPERATOR_SCHEMA(Axpby)
/* X1, ... */ /* X */
.NumInputs(1, INT_MAX) .NumInputs(1)
/* Y1, ... */ /* Y */
.NumOutputs(1, INT_MAX); .NumOutputs(1)
/* X => Y */
.AllowInplace({{0, 0}});
NO_GRADIENT(Axpby); NO_GRADIENT(Axpby);
......
...@@ -207,16 +207,16 @@ OPERATOR_SCHEMA(Div) ...@@ -207,16 +207,16 @@ OPERATOR_SCHEMA(Div)
.NumInputs(2) .NumInputs(2)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* A => Y */ /* A => Y, B => Y */
.Inplace({{0, 0}, {1, 0}}); .AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(DivGradient) OPERATOR_SCHEMA(DivGradient)
/* A, B, dY */ /* A, B, dY */
.NumInputs(3) .NumInputs(3)
/* dA, dB */ /* dA, dB */
.NumOutputs(2) .NumOutputs(2)
/* dY => dA */ /* dY => dA, dY => dB */
.Inplace({{2, 0}, {2, 1}}); .AllowInplace({{2, 0}, {2, 1}});
REGISTER_GRADIENT(Div, GenericGradientMaker); REGISTER_GRADIENT(Div, GenericGradientMaker);
......
...@@ -172,15 +172,15 @@ DEPLOY_CUDA(Greater); ...@@ -172,15 +172,15 @@ DEPLOY_CUDA(Greater);
DEPLOY_CUDA(GreaterEqual); DEPLOY_CUDA(GreaterEqual);
#endif #endif
OPERATOR_SCHEMA(Ceil).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Ceil).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Floor).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Floor).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Round).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Round).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Sign).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Sign).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Sqrt).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Sqrt).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Rsqrt).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Rsqrt).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Exp).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Exp).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Log).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Log).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Invert).NumInputs(1).NumOutputs(1).Inplace({{0, 0}}); OPERATOR_SCHEMA(Invert).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
OPERATOR_SCHEMA(Sin).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(Sin).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Cos).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(Cos).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Square).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(Square).NumInputs(1).NumOutputs(1);
......
...@@ -43,7 +43,7 @@ class AxpbyOp final : public Operator<Context> { ...@@ -43,7 +43,7 @@ class AxpbyOp final : public Operator<Context> {
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> template <typename T>
void DoRunWithType(Tensor* X, Tensor* Y); void DoRunWithType();
protected: protected:
float alpha_, beta_; float alpha_, beta_;
......
...@@ -31,7 +31,7 @@ OPERATOR_SCHEMA(ExpGradient) ...@@ -31,7 +31,7 @@ OPERATOR_SCHEMA(ExpGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Exp, InplaceGradientMaker); REGISTER_GRADIENT(Exp, InplaceGradientMaker);
......
...@@ -189,16 +189,16 @@ OPERATOR_SCHEMA(Mul) ...@@ -189,16 +189,16 @@ OPERATOR_SCHEMA(Mul)
.NumInputs(2) .NumInputs(2)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* A => Y */ /* A => Y, B => Y */
.Inplace({{0, 0}, {1, 0}}); .AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(MulGradient) OPERATOR_SCHEMA(MulGradient)
/* A, B, dY */ /* A, B, dY */
.NumInputs(3) .NumInputs(3)
/* dA, dB */ /* dA, dB */
.NumOutputs(2) .NumOutputs(2)
/* dY => dA */ /* dY => dA, dY => dB */
.Inplace({{2, 0}, {2, 1}}); .AllowInplace({{2, 0}, {2, 1}});
REGISTER_GRADIENT(Mul, GenericGradientMaker); REGISTER_GRADIENT(Mul, GenericGradientMaker);
......
...@@ -54,7 +54,7 @@ OPERATOR_SCHEMA(Neg) ...@@ -54,7 +54,7 @@ OPERATOR_SCHEMA(Neg)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(NegGradient) OPERATOR_SCHEMA(NegGradient)
/* dY */ /* dY */
...@@ -62,7 +62,7 @@ OPERATOR_SCHEMA(NegGradient) ...@@ -62,7 +62,7 @@ OPERATOR_SCHEMA(NegGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Neg, SimpleGradientMaker); REGISTER_GRADIENT(Neg, SimpleGradientMaker);
......
...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(Reciprocal) ...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(Reciprocal)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(ReciprocalGradient) OPERATOR_SCHEMA(ReciprocalGradient)
/* Y, dY */ /* Y, dY */
...@@ -61,7 +61,7 @@ OPERATOR_SCHEMA(ReciprocalGradient) ...@@ -61,7 +61,7 @@ OPERATOR_SCHEMA(ReciprocalGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Reciprocal, InplaceGradientMaker); REGISTER_GRADIENT(Reciprocal, InplaceGradientMaker);
......
...@@ -32,7 +32,7 @@ OPERATOR_SCHEMA(RsqrtGradient) ...@@ -32,7 +32,7 @@ OPERATOR_SCHEMA(RsqrtGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Rsqrt, InplaceGradientMaker); REGISTER_GRADIENT(Rsqrt, InplaceGradientMaker);
......
...@@ -30,7 +30,7 @@ OPERATOR_SCHEMA(SignGradient) ...@@ -30,7 +30,7 @@ OPERATOR_SCHEMA(SignGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(Sign, SimpleGradientMaker); REGISTER_GRADIENT(Sign, SimpleGradientMaker);
......
...@@ -37,7 +37,7 @@ OPERATOR_SCHEMA(SqrtGradient) ...@@ -37,7 +37,7 @@ OPERATOR_SCHEMA(SqrtGradient)
/* dX */ /* dX */
.NumOutputs(1) .NumOutputs(1)
/* dY => dX */ /* dY => dX */
.Inplace({{1, 0}}); .AllowInplace({{1, 0}});
REGISTER_GRADIENT(Sqrt, InplaceGradientMaker); REGISTER_GRADIENT(Sqrt, InplaceGradientMaker);
......
...@@ -112,16 +112,16 @@ OPERATOR_SCHEMA(Sub) ...@@ -112,16 +112,16 @@ OPERATOR_SCHEMA(Sub)
.NumInputs(2) .NumInputs(2)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* A => Y */ /* A => Y, B => Y */
.Inplace({{0, 0}, {1, 0}}); .AllowInplace({{0, 0}, {1, 0}});
OPERATOR_SCHEMA(SubGradient) OPERATOR_SCHEMA(SubGradient)
/* dY */ /* dY */
.NumInputs(1) .NumInputs(1)
/* dA, dB */ /* dA, dB */
.NumOutputs(2) .NumOutputs(2)
/* dY => dA */ /* dY => dA, dY => dB */
.Inplace({{0, 0}, {0, 1}}); .AllowInplace({{0, 0}, {0, 1}});
REGISTER_GRADIENT(Sub, SimpleGradientMaker); REGISTER_GRADIENT(Sub, SimpleGradientMaker);
......
...@@ -89,7 +89,7 @@ OPERATOR_SCHEMA(BiasAdd) ...@@ -89,7 +89,7 @@ OPERATOR_SCHEMA(BiasAdd)
/* Y */ /* Y */
.NumOutputs(1) .NumOutputs(1)
/* X => Y */ /* X => Y */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
OPERATOR_SCHEMA(BiasAddGradient) OPERATOR_SCHEMA(BiasAddGradient)
/* dY */ /* dY */
...@@ -97,7 +97,7 @@ OPERATOR_SCHEMA(BiasAddGradient) ...@@ -97,7 +97,7 @@ OPERATOR_SCHEMA(BiasAddGradient)
/* dX, dB */ /* dX, dB */
.NumOutputs(2) .NumOutputs(2)
/* dY => dX */ /* dY => dX */
.Inplace({{0, 0}}); .AllowInplace({{0, 0}});
REGISTER_GRADIENT(BiasAdd, SimpleGradientMaker); REGISTER_GRADIENT(BiasAdd, SimpleGradientMaker);
......
...@@ -289,7 +289,7 @@ template <class Context> ...@@ -289,7 +289,7 @@ template <class Context>
template <typename T> template <typename T>
void CuDNNConv2dGradientOp<Context>::ResetDesc() { void CuDNNConv2dGradientOp<Context>::ResetDesc() {
auto &X = Input(0), &W = Input(1), &dY = Input(-1); auto &X = Input(0), &W = Input(1), &dY = Input(-1);
auto *dX = Output(0), *dW = Output(1); // auto *dX = Output(0), *dW = Output(1);
bool input_changed = (X.dims() != input_dims_); bool input_changed = (X.dims() != input_dims_);
bool filter_changed = (W.dims() != filter_dims_); bool filter_changed = (W.dims() != filter_dims_);
if (input_changed || filter_changed) { if (input_changed || filter_changed) {
...@@ -328,8 +328,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() { ...@@ -328,8 +328,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
exhaustive_search_data_ = true; exhaustive_search_data_ = true;
exhaustive_search_filter_ = true; exhaustive_search_filter_ = true;
} else { } else {
if (dW->has_name()) {
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
{
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS;
cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos]; cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos];
...@@ -353,20 +353,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() { ...@@ -353,20 +353,8 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found) CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardFilter> " << "\nNo algorithms available for <cudnnConvolutionBackwardFilter> "
<< "under the current desc and workspace limit."; << "under the current desc and workspace limit.";
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
output_desc_,
input_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
if (dX->has_name()) { {
#if CUDNN_VERSION_MIN(7, 0, 0)
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS;
cudnnConvolutionBwdDataAlgoPerf_t stats[num_algos]; cudnnConvolutionBwdDataAlgoPerf_t stats[num_algos];
...@@ -390,7 +378,17 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() { ...@@ -390,7 +378,17 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found) CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardData> " << "\nNo algorithms available for <cudnnConvolutionBackwardData> "
<< "under the current desc and workspace limit."; << "under the current desc and workspace limit.";
}
#else #else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
output_desc_,
input_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm( CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
filter_desc_, filter_desc_,
...@@ -402,7 +400,6 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() { ...@@ -402,7 +400,6 @@ void CuDNNConv2dGradientOp<Context>::ResetDesc() {
&bwd_data_algo_)); &bwd_data_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0) #endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
}
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size
} }
} }
......
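In the hunk above, the dW->has_name() and dX->has_name() guards are removed, so both backward algorithms are queried whenever the input or filter shape changes, rather than only when that gradient output is requested. For reference, a hedged sketch of the cuDNN >= 7.0 heuristic query on that path (take the first candidate whose workspace fits the limit) follows; PickBwdFilterAlgo and workspace_limit are illustrative names, and error handling is elided.

#include <cudnn.h>
#include <cstddef>

// Returns the fastest backward-filter algorithm whose workspace fits under
// `workspace_limit`, falling back to the no-workspace algorithm otherwise.
// The handle and descriptors are assumed to be configured by the caller.
cudnnConvolutionBwdFilterAlgo_t PickBwdFilterAlgo(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t x_desc,
    cudnnTensorDescriptor_t dy_desc,
    cudnnConvolutionDescriptor_t conv_desc,
    cudnnFilterDescriptor_t dw_desc,
    size_t workspace_limit) {
  constexpr int kMaxAlgos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
  cudnnConvolutionBwdFilterAlgoPerf_t perf[kMaxAlgos];
  int returned = 0;
  cudnnGetConvolutionBackwardFilterAlgorithm_v7(
      handle, x_desc, dy_desc, conv_desc, dw_desc,
      kMaxAlgos, &returned, perf);
  for (int i = 0; i < returned; ++i) {
    // Results are ordered by expected performance; take the first viable one.
    if (perf[i].status == CUDNN_STATUS_SUCCESS &&
        perf[i].memory <= workspace_limit) {
      return perf[i].algo;
    }
  }
  return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
}

The transposed-convolution hunk that follows performs the analogous queries, using the forward-algorithm perf structs for the data path, since the data gradient of a transposed convolution is computed as a forward convolution.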
...@@ -287,7 +287,6 @@ template <class Context> ...@@ -287,7 +287,6 @@ template <class Context>
template <typename T> template <typename T>
void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
auto &X = Input(0), &W = Input(1), &dY = Input(-1); auto &X = Input(0), &W = Input(1), &dY = Input(-1);
auto *dX = Output(0), *dW = Output(1);
bool input_changed = (X.dims() != input_dims_); bool input_changed = (X.dims() != input_dims_);
bool filter_changed = (W.dims() != filter_dims_); bool filter_changed = (W.dims() != filter_dims_);
if (input_changed || filter_changed) { if (input_changed || filter_changed) {
...@@ -324,8 +323,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { ...@@ -324,8 +323,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
exhaustive_search_data_ = true; exhaustive_search_data_ = true;
exhaustive_search_filter_ = true; exhaustive_search_filter_ = true;
} else { } else {
if (dW->has_name()) {
#if CUDNN_VERSION_MIN(7, 0, 0) #if CUDNN_VERSION_MIN(7, 0, 0)
{
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS;
cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos]; cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos];
...@@ -349,20 +348,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { ...@@ -349,20 +348,8 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found) CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardFilter> " << "\nNo algorithms available for <cudnnConvolutionBackwardFilter> "
<< "under the current desc and workspace limit."; << "under the current desc and workspace limit.";
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
if (dX->has_name()) { {
#if CUDNN_VERSION_MIN(7, 0, 0)
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS;
cudnnConvolutionFwdAlgoPerf_t stats[num_algos]; cudnnConvolutionFwdAlgoPerf_t stats[num_algos];
...@@ -386,7 +373,17 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { ...@@ -386,7 +373,17 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
CHECK(algo_is_found) CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionForward> " << "\nNo algorithms available for <cudnnConvolutionForward> "
<< "under the current desc and workspace limit."; << "under the current desc and workspace limit.";
}
#else #else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm( CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
input_desc_, input_desc_,
...@@ -398,7 +395,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() { ...@@ -398,7 +395,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
&bwd_data_algo_)); &bwd_data_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0) #endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
}
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size
} }
} }
......
...@@ -50,11 +50,11 @@ def set_optimization(level=1): ...@@ -50,11 +50,11 @@ def set_optimization(level=1):
* level = ``0``: Do nothing. * level = ``0``: Do nothing.
* level = ``1``: Prune the redundant nodes. * level = ``1``: Eliminate the unused outputs and operators.
* level = ``2``: Add the inplace to outputs. * level = ``2``: Apply inplace to the inputs if available.
* level = ``3``: Allocate the buffer for outputs. * level = ``3``: Allocate shared buffer for the outputs.
Parameters Parameters
---------- ----------
......
...@@ -78,7 +78,7 @@ class GradientMaker(object): ...@@ -78,7 +78,7 @@ class GradientMaker(object):
if not is_skip: if not is_skip:
for input, grad_input in zip(op_def.input, grad_inputs): for input, grad_input in zip(op_def.input, grad_inputs):
inputs_to_grads[input] = grad_input inputs_to_grads[input] = grad_input
# Add def for ``GradientGenerateOp`` # Add ``GradientGenerateOp``
if len(gen_grads) > 0: if len(gen_grads) > 0:
inputs, outputs, values = [], [], [] inputs, outputs, values = [], [], []
for name, i in gen_grads: for name, i in gen_grads:
...@@ -94,7 +94,7 @@ class GradientMaker(object): ...@@ -94,7 +94,7 @@ class GradientMaker(object):
device_option=op_def.device_option device_option=op_def.device_option
if op_def.HasField('device_option') else None) if op_def.HasField('device_option') else None)
backward_defs.append(gen_op) backward_defs.append(gen_op)
# Add def for ``GradientOp`` # Add ``GradientOp``
for grad_def in grad_defs: for grad_def in grad_defs:
grad_def.name = OpDef.get_name() grad_def.name = OpDef.get_name()
backward_defs.append(grad_def) backward_defs.append(grad_def)
......
...@@ -130,7 +130,7 @@ def affine(inputs, axis=1, num_axes=1, **kwargs): ...@@ -130,7 +130,7 @@ def affine(inputs, axis=1, num_axes=1, **kwargs):
return op_lib.blend(**args) return op_lib.blend(**args)
@OpSchema.num_inputs(1, 2147483647) @OpSchema.num_inputs(1)
def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs): def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
r"""Compute the element-wise addition from input to output. r"""Compute the element-wise addition from input to output.
...@@ -140,10 +140,10 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs): ...@@ -140,10 +140,10 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
Parameters Parameters
---------- ----------
inputs : Union[dragon.Tensor, Sequence[dragon.Tensor]] inputs : dragon.Tensor
The input tensor(s). The input tensor.
outputs : Union[dragon.Tensor, Sequence[dragon.Tensor]], optional outputs : dragon.Tensor, optional
The output tensor(s). The output tensor.
alpha : number, optional, default=1. alpha : number, optional, default=1.
The value to :math:`\alpha`. The value to :math:`\alpha`.
beta : number, optional, default=1. beta : number, optional, default=1.
...@@ -151,23 +151,17 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs): ...@@ -151,23 +151,17 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
Returns Returns
------- -------
Union[dragon.Tensor, Sequence[dragon.Tensor]] dragon.Tensor
The output tensor(s). The output tensor.
""" """
args = parse_args(locals()) args = parse_args(locals())
args['alpha'], args['beta'] = float(alpha), float(beta) args['alpha'], args['beta'] = float(alpha), float(beta)
if types.is_tensor(inputs):
inputs = [inputs]
if outputs is not None and types.is_tensor(outputs):
args['outputs'] = [outputs]
op_lib = math_ops_lib.Axpby op_lib = math_ops_lib.Axpby
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(alpha=args['alpha'], beta=args['beta']) \
alpha=args['alpha'], .apply([inputs], [outputs])
beta=args['beta'],
).apply(inputs, args['outputs'])
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
......
...@@ -65,8 +65,7 @@ class _BatchNorm(Module): ...@@ -65,8 +65,7 @@ class _BatchNorm(Module):
.format(**self.__dict__) .format(**self.__dict__)
def forward(self, input): def forward(self, input):
training = self.training or \ training = self.training or not self.track_running_stats
not self.track_running_stats
return F.batch_norm( return F.batch_norm(
input, *self.inputs, input, *self.inputs,
training=training, training=training,
......
...@@ -146,10 +146,8 @@ def cat(seq, dim=0, out=None): ...@@ -146,10 +146,8 @@ def cat(seq, dim=0, out=None):
""" """
return _functions.Concat \ return _functions.Concat \
.instantiate( .instantiate(seq[0].device, axis=dim) \
seq[0].device, .apply(seq, out)
axis=dim,
).apply(seq, out)
def channel_normalize( def channel_normalize(
...@@ -618,10 +616,7 @@ def nonzero(input, out=None): ...@@ -618,10 +616,7 @@ def nonzero(input, out=None):
The output tensor. The output tensor.
""" """
return _functions.NonZero \ return _functions.NonZero.instantiate(input.device).apply(input, out)
.instantiate(
input.device,
).apply(input, out)
def one_hot(input, depth): def one_hot(input, depth):
...@@ -647,8 +642,7 @@ def one_hot(input, depth): ...@@ -647,8 +642,7 @@ def one_hot(input, depth):
The output tensor. The output tensor.
""" """
return _functions.OneHot \ return _functions.OneHot.instantiate(input.device, depth=depth).apply(input)
.instantiate(input.device, depth=depth).apply(input)
def permute(input, dims): def permute(input, dims):
...@@ -715,18 +709,14 @@ def reshape(input, shape, out=None): ...@@ -715,18 +709,14 @@ def reshape(input, shape, out=None):
""" """
shape = nest.flatten(shape) shape = nest.flatten(shape)
return _functions.Reshape \ return _functions.Reshape \
.instantiate( .instantiate(input.device, ndim=len(shape)) \
input.device, .apply(input, shape, out)
ndim=len(shape),
).apply(input, shape, out)
def slice(input, starts, sizes): def slice(input, starts, sizes):
return _functions.Slice \ return _functions.Slice \
.instantiate( .instantiate(input.device, ndim=len(starts)) \
input.device, .apply(input, starts, sizes)
ndim=len(starts),
).apply(input, starts, sizes)
def split(tensor, split_size_or_sections, dim=0): def split(tensor, split_size_or_sections, dim=0):
...@@ -1015,9 +1005,8 @@ def where(condition, x, y): ...@@ -1015,9 +1005,8 @@ def where(condition, x, y):
""" """
return _functions.Where \ return _functions.Where \
.instantiate( .instantiate(utils.unify_devices([condition, x, y])) \
utils.unify_devices([condition, x, y]), .apply(condition, x, y)
).apply(condition, x, y)
def _arg_reduce(input, op_type, dim=None, keepdim=False, out=None): def _arg_reduce(input, op_type, dim=None, keepdim=False, out=None):
......
...@@ -567,10 +567,6 @@ def expand(self, *sizes): ...@@ -567,10 +567,6 @@ def expand(self, *sizes):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.expand(...)`_
""" """
return array_funcs.expand(self, sizes) return array_funcs.expand(self, sizes)
......
...@@ -51,6 +51,5 @@ class GradAccumulate(function.Function): ...@@ -51,6 +51,5 @@ class GradAccumulate(function.Function):
'arguments': {'alpha': 1., 'beta': float(self.momentum)}, 'arguments': {'alpha': 1., 'beta': float(self.momentum)},
} }
def forward(self, grads): def forward(self, grad):
outputs = [grad.id + '[accum]' for grad in grads] return self.dispatch([grad], [grad.id + '[accum]'], no_grad=True)
return self.dispatch(grads, outputs, no_grad=True)
...@@ -14,18 +14,13 @@ from __future__ import absolute_import ...@@ -14,18 +14,13 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.util import nest
from dragon.vm.torch.core.ops.training import _functions from dragon.vm.torch.core.ops.training import _functions
def accumulate_grad(grads, momentum=1): def accumulate_grad(grad, momentum=1):
"""Accumulate the gradients.""" """Accumulate the gradient."""
grads = nest.flatten(grads)
if len(grads) == 0:
return
return _functions.GradAccumulate \ return _functions.GradAccumulate \
.instantiate(grads[0].device, momentum=momentum) \ .instantiate(grad.device, momentum=momentum).apply(grad)
.apply(grads)
def update_param( def update_param(
......
...@@ -97,15 +97,13 @@ class Optimizer(object): ...@@ -97,15 +97,13 @@ class Optimizer(object):
The momentum to the accumulated value. The momentum to the accumulated value.
""" """
grads = []
current_ws = workspace.get_workspace() current_ws = workspace.get_workspace()
for group in self.param_groups: for group in self.param_groups:
group['_internal/grad_accum'] = True group['_internal/grad_accum'] = True
for param in group['params']: for param in group['params']:
grad = self._steal_grad(current_ws, param) grad = self._steal_grad(current_ws, param)
if grad is not None: if grad is not None:
grads.append(grad) training_funcs.accumulate_grad(grad)
training_funcs.accumulate_grad(grads, momentum)
def add_param_group(self, param_group): def add_param_group(self, param_group):
"""Add a new param group into the optimizer. """Add a new param group into the optimizer.
......
...@@ -776,10 +776,6 @@ class Tensor(object): ...@@ -776,10 +776,6 @@ class Tensor(object):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.expand(...)`_
""" """
def expand_as(self, other): def expand_as(self, other):
...@@ -795,10 +791,6 @@ class Tensor(object): ...@@ -795,10 +791,6 @@ class Tensor(object):
dragon.vm.torch.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
See Also
--------
`torch.expand(...)`_
""" """
return self.expand(*other.size()) return self.expand(*other.size())
......