Commit df172cc8 by Ting PAN

Simplify the operation executor

Summary:
This commit removes the redundant workspace reference
when executing a tensor operation.
1 parent b37d4e5e
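The substance of the executor change is that the workspace is resolved once in `run_operator` and threaded through `pre_callback`/`feed_arg`, instead of each callback looking it up again via `workspace.get_workspace()`. A minimal, self-contained sketch of the before/after calling pattern (the `_Workspace` stub and the `Reshape_1` handle below are illustrative stand-ins, not the real bindings):

import numpy

class _Workspace(object):
    """Illustrative stand-in for the workspace object passed to callbacks."""

    def FeedTensor(self, name, value, device=None):
        print('feed', name, value.tolist())

    def RunOperator(self, op_def, verbose=False):
        print('run', op_def)

def feed_arg(ws, name, value, dtype='int64'):
    # After this commit the caller supplies ws explicitly;
    # previously: workspace.get_workspace().FeedTensor(...)
    ws.FeedTensor(name, numpy.array(value, dtype))

def run_operator(op_def, pre_callback=None):
    ws = _Workspace()  # resolved once per dispatch
    if pre_callback is not None:
        pre_callback(ws, op_def)  # previously: pre_callback(op_def)
    ws.RunOperator(op_def)  # previously: workspace.run_operator(op_def)

run_operator('Reshape_1', pre_callback=lambda ws, handle:
             feed_arg(ws, '{}/dims[{}]'.format(handle, 0), 4))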
@@ -37,6 +37,7 @@ void GraphGradientMaker::Make(
   Map<string, int> inputs_count, grads_count;
   Set<string> all_split_grads, targets_set;
   Map<string, string> targets_to_grads;
   // PLAY for the forward
   for (auto* op : forward_ops) {
     if (NoGradientRegistry()->Has(op->type())) continue;
@@ -51,6 +52,7 @@ void GraphGradientMaker::Make(
       if (!input_in_outputs) inputs_count[input]++;
     }
   }
   // PLAY for the backward
   for (int i = 0; i < targets.size(); ++i) {
     // Set the gradient of targets
@@ -59,17 +61,19 @@ void GraphGradientMaker::Make(
     }
     targets_set.insert(targets[i]);
   }
-  for (int i = (int)forward_ops.size() - 1; i >= 0; --i) {
+  for (int op_idx = (int)forward_ops.size() - 1; op_idx >= 0; --op_idx) {
     // Collect inputs and outputs, generate raw gradient ops
-    const OperatorDef& op = *forward_ops[i];
+    const OperatorDef& op = *forward_ops[op_idx];
     vector<pair<string, int>> gen_grads;
     bool is_skip = CheckGrad(op, targets_set, gen_grads);
     vector<string> g_outputs;
-    for (auto& output : op.output()) {
+    for (const auto& output : op.output()) {
       string g_output = "";
-      if (inputs_to_grads_.count(output) > 0)
+      if (inputs_to_grads_.count(output) > 0) {
        g_output = inputs_to_grads_[output];
-      g_outputs.emplace_back(g_output);
+      }
+      g_outputs.push_back(g_output);
     }
     auto grad = MakeGradientForOp(op, g_outputs);
@@ -82,15 +86,21 @@ void GraphGradientMaker::Make(
     for (int i = 0; i < grad_op.output_size(); ++i) {
       auto* output = grad_op.mutable_output(i);
       int original_idx = -1;
-      for (int j = 0; j < grad.g_inputs.size(); ++j)
+      for (int j = 0; j < grad.g_inputs.size(); ++j) {
-        if (grad_op.output(i) == grad.g_inputs[j]) original_idx = j;
+        if (grad_op.output(i) == grad.g_inputs[j]) {
+          original_idx = j;
+        }
+      }
       // Ignore unused && in-placee GI
       if (original_idx == -1) continue;
       bool output_in_inputs = false;
-      for (const auto& input : grad_op.input())
+      for (const auto& input : grad_op.input()) {
-        if (grad_op.output(i) == input) output_in_inputs = true;
+        if (grad_op.output(i) == input) {
+          output_in_inputs = true;
+        }
+      }
       if (output_in_inputs) continue;
-      // Found a split branch
+      // Find a split branch
       const auto& original_name = op.input(original_idx);
       if (inputs_count[original_name] > 1) {
         // Split
@@ -120,7 +130,7 @@ void GraphGradientMaker::Make(
     // Now, append the required ops
     if (!is_skip) {
-      // 1) GradientGenerateOp
+      // GradientGenerateOp
       if (gen_grads.size() > 0) {
         vector<string> op_inputs, op_outputs;
         Argument arg_defaults;
@@ -141,15 +151,18 @@ void GraphGradientMaker::Make(
         }
         backward_def.add_op()->CopyFrom(generate_op);
       }
-      // 2) GradientOp
+      // GradientOp
-      for (const auto& grad_op : grad.ops)
+      for (const auto& grad_op : grad.ops) {
        backward_def.add_op()->CopyFrom(grad_op);
      }
-      // 3) GradientGatherOp
+    }
-      for (const auto& gather_op : gather_ops)
+    // GradientGatherOp
+    for (const auto& gather_op : gather_ops) {
      backward_def.add_op()->CopyFrom(gather_op);
+    }
-    // Done!
+    // Done
    if (!is_skip) {
      for (int i = 0; i < op.input_size(); ++i) {
        if (!grad.g_inputs[i].empty())
@@ -162,32 +175,40 @@ void GraphGradientMaker::Make(
 GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
   Set<int> invalid_ops;
   Map<string, int> ref_count;
-  Map<string, pair<int, string>> ssa_map;
+  Map<string, pair<int, string>> gather_map;
-  // Count the refs for detecting leaf nodes
   for (int op_idx = 0; op_idx < input_def.op_size(); ++op_idx) {
     const auto& op = input_def.op(op_idx);
     if (!str::find(op.type(), "Gradient")) continue;
+    // Flag the gathering gradients
     if (op.type() == "GradientGather") {
       invalid_ops.insert(op_idx);
       if (ignored_grads_.count(op.output(0))) {
-        for (const auto& input : op.input())
+        for (const auto& input : op.input()) {
          ignored_grads_.insert(input);
+        }
        continue;
      } else {
-        string head;
+        string first_input;
        for (const auto& input : op.input()) {
          if (!input.empty()) {
-            if (head.empty()) head = input;
+            if (first_input.empty()) first_input = input;
-            ssa_map[input] = {op_idx, head};
+            gather_map[input] = {op_idx, first_input};
+          }
         }
       }
     }
+    // Count the references to detect leafs
+    for (const auto& input : op.input()) {
+      if (str::find(input, "grad")) {
+        ref_count[input] += 1;
+      }
+    }
   }
-    for (const auto& input : op.input())
-      if (str::find(input, "grad")) ref_count[input] += 1;
   }
-  // Decompose the <GradientGather> in SSA format
+  // Decompose the <GradientGather> into <GradientAdd>
+  // This trick accumulates the split to target right after computing,
+  // which helps to reduce the total number of buffers.
   GraphDef output_def(input_def);
   output_def.clear_op();
   for (int op_idx = 0; op_idx < input_def.op_size(); ++op_idx) {
@@ -196,20 +217,20 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
     output_def.add_op()->CopyFrom(op);
     if (!str::find(op.type(), "Gradient")) continue;
     for (const auto& output : op.output()) {
-      const auto& find_iter = ssa_map.find(output);
+      const auto& find_iter = gather_map.find(output);
-      if (find_iter != ssa_map.end()) {
+      if (find_iter != gather_map.end()) {
         const auto& gather_op = input_def.op(find_iter->second.first);
-        auto acc_op(gather_op);
+        auto add_op(gather_op);
-        acc_op.clear_input();
+        add_op.clear_input();
         if (output != find_iter->second.second) {
-          acc_op.set_type("GradientAdd");
+          add_op.set_type("GradientAdd");
           // Make an in-place to avoid a new buffer
-          acc_op.add_input(gather_op.output(0));
+          add_op.add_input(gather_op.output(0));
           const auto& ref_iter = ref_count.find(gather_op.output(0));
           if (ref_iter != ref_count.end()) ref_iter->second++;
         }
-        acc_op.add_input(output);
+        add_op.add_input(output);
-        output_def.add_op()->CopyFrom(acc_op);
+        output_def.add_op()->CopyFrom(add_op);
       }
     }
   }
...
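The `Share` changes above replace the single `GradientGather` over all split gradients with an incremental `GradientAdd` into the target right after each split is produced, which is what the new comment about reducing the total number of buffers refers to. A rough Python sketch of that rewrite over a toy op list (the dict-based op records stand in for `OperatorDef`; the op and tensor names are made up):

def decompose_gather(ops):
    """Fold GradientGather ops into incremental GradientAdd ops (toy model)."""
    # Map every gathered split to (index of its gather op, first non-empty split).
    gather_map = {}
    for idx, op in enumerate(ops):
        if op['type'] == 'GradientGather':
            first = next(x for x in op['inputs'] if x)
            for x in op['inputs']:
                if x:
                    gather_map[x] = (idx, first)
    rewritten = []
    for op in ops:
        if op['type'] == 'GradientGather':
            continue  # dropped; replaced by the per-split ops appended below
        rewritten.append(op)
        for y in op['outputs']:
            if y in gather_map:
                gather_idx, first = gather_map[y]
                target = ops[gather_idx]['outputs'][0]
                if y == first:
                    # The first split initializes the target buffer.
                    rewritten.append({'type': 'GradientGather',
                                      'inputs': [y], 'outputs': [target]})
                else:
                    # Later splits accumulate in place: target += split.
                    rewritten.append({'type': 'GradientAdd',
                                      'inputs': [target, y], 'outputs': [target]})
    return rewritten

ops = [
    {'type': 'FooGradient', 'inputs': ['y_grad'], 'outputs': ['x_grad_split_0']},
    {'type': 'BarGradient', 'inputs': ['y_grad'], 'outputs': ['x_grad_split_1']},
    {'type': 'GradientGather',
     'inputs': ['x_grad_split_0', 'x_grad_split_1'], 'outputs': ['x_grad']},
]
# After the rewrite each split is folded into 'x_grad' as soon as it exists,
# so the split buffers never need to stay alive until a final gather.
print(decompose_gather(ops))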
@@ -164,7 +164,7 @@ class DRAGON_API UnifiedMemory {
   /*! \brief The binding cpu tensor for cnml */
   cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr;
-  /*! \brief Tje binding mlu tensor for cnml */
+  /*! \brief The binding mlu tensor for cnml */
   cnmlTensor_t cnml_mlu_tensor_ = nullptr;
 };
...
@@ -236,14 +236,18 @@ PYBIND11_MODULE(libdragon_python, m) {
            const bool verbose) {
          GraphDef backward_ops;
          GraphGradientMaker maker;
-         for (auto& name : ignored_grads)
+         for (const auto& name : ignored_grads) {
            maker.add_ignored_grad(name);
-         for (auto& name : sources)
+         }
+         for (const auto& name : sources) {
            maker.add_hooked_grad(name + "_grad");
+         }
          maker.Make(forward_ops, targets, input_grads, backward_ops);
          py::gil_scoped_release g;
-         if (is_sharing) backward_ops = maker.Share(backward_ops);
-         for (auto& def : backward_ops.op()) {
+         if (is_sharing) {
+           backward_ops = maker.Share(backward_ops);
+         }
+         for (const auto& def : backward_ops.op()) {
            if (verbose) {
              auto msg = string("\n") + def.DebugString();
              msg.pop_back();
@@ -268,8 +272,9 @@ PYBIND11_MODULE(libdragon_python, m) {
                  << "Can't be used in C++.";
              break;
            case 1: // CaffeModel
-             for (const auto& e : tensors)
+             for (const auto& name : tensors) {
-               refs.emplace_back(self->GetTensor(e));
+               refs.emplace_back(self->GetTensor(name));
+             }
              SavaCaffeModel(filename, refs);
              break;
            default:
...
@@ -50,7 +50,7 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
     // Allocate for the reserve space
     size_t reserve_size;
     CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(input_desc_, &reserve_size));
-    auto* X_mask = Buffer("mask")->Reshape({(int64_t)reserve_size});
+    auto* X_mask = Buffer("X_mask")->Reshape({(int64_t)reserve_size});
     CUDNN_CHECK(cudnnDropoutForward(
         ctx()->cudnn_handle(),
@@ -106,7 +106,7 @@ void CuDNNDropoutGradientOp<Context>::DoRunWithType() {
     // Check the reserve space
     size_t reserve_size;
     CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(input_desc_, &reserve_size));
-    auto* X_mask = Buffer("mask");
+    auto* X_mask = Buffer("X_mask");
     CHECK_EQ(X_mask->size(), reserve_size);
     CUDNN_CHECK(cudnnDropoutBackward(
...
@@ -9,19 +9,7 @@
 #
 # ------------------------------------------------------------
-"""The basic idea of directly run operators comes from ``caffe2``,
+"""Execute tensor operations. """
-it spends much more time on Python frontend than C++ backend,
-which should not be taken for running computation-intensive operators.
-We extend a new ``PERSISTENT`` engine, that hashes the arguments
-as many as possible, i.e., creates a operator once while running
-with arbitrary inputs and outputs many times.
-Note that it is still a challenge to persist the operators which
-take the argument with uncertain numerical bounds. In this case,
-our engine will still create lots of duplicates.
-"""
 from __future__ import absolute_import
 from __future__ import division
@@ -30,9 +18,9 @@ from __future__ import print_function
 from dragon.core.eager import backprop
 from dragon.core.eager.tensor import EagerTensor
 from dragon.core.framework import device_spec
+from dragon.core.framework import config
 from dragon.core.framework import context
 from dragon.core.framework import workspace
-from dragon.core.util import nest
 from dragon.core.util import six
@@ -43,15 +31,6 @@ def run_operator(
     no_grad=False,
     pre_callback=None,
 ):
-    inputs = nest.flatten(inputs)
-    outputs = nest.flatten(outputs)
-    if len(outputs) == 0:
-        raise ValueError(
-            'The number of <outputs> should be '
-            'at least 1. Got {}.'.format(len(outputs))
-        )
     requires_grad = False
     input_names, output_names = [], []
     tape = backprop.get_default_tape()
@@ -68,6 +47,7 @@ def run_operator(
             requires_grad = True
     # Allocate outputs.
+    cfg = config.config()
     ws = workspace.get_workspace()
     output_scope = context.get_eager_scope(requires_grad)
     gc = ws.collectors # Garbage collectors
@@ -100,8 +80,8 @@ def run_operator(
     # Dispatch the computation.
     if pre_callback is not None:
-        pre_callback(op_def.name)
+        pre_callback(ws, op_def.name)
-    workspace.run_operator(op_def)
+    ws.RunOperator(op_def, cfg.graph_verbosity > 0)
     # Return the outputs.
     return outputs if len(outputs) > 1 else outputs[0]
@@ -68,13 +68,9 @@ class Operator(object):
             pre_callback=callback,
         )
-    def feed_arg(self, name, value, dtype='int64'):
+    def feed_arg(self, ws, name, value, dtype='int64'):
-        """Set the value of tensors argument."""
+        """Set the value of tensor argument."""
-        workspace.get_workspace().FeedTensor(
+        ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device)
-            name,
-            numpy.array(value, dtype),
-            self._arg_device,
-        )
     @classmethod
     def instantiate(cls, **kwargs):
...
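Every `feed`/`callback` change in the files below follows the same pattern enabled by the new `feed_arg(ws, ...)` signature: the dispatcher invokes `callback(ws, handle)` and the callback feeds each scalar argument into that workspace under a handle-scoped name. A small illustrative sketch (the `_Workspace` stub, `_ReshapeLike` class, and `Reshape_1` handle are hypothetical):

import numpy

class _Workspace(object):
    """Hypothetical stand-in for the workspace handed to dispatch callbacks."""

    def __init__(self):
        self.tensors = {}

    def FeedTensor(self, name, value, device=None):
        self.tensors[name] = value

class _ReshapeLike(object):
    """Toy operator following the updated feed/callback pattern."""

    _arg_device = None

    def feed_arg(self, ws, name, value, dtype='int64'):
        # Mirrors Operator.feed_arg after this commit.
        ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device)

    def feed(self, ws, handle, shape):
        # Scalar arguments are fed under the operator handle, e.g.
        # 'Reshape_1/dims[0]', 'Reshape_1/dims[1]', ...
        for i, dim in enumerate(shape):
            self.feed_arg(ws, '{}/dims[{}]'.format(handle, i), dim, 'int64')

ws, op = _Workspace(), _ReshapeLike()
op.feed(ws, 'Reshape_1', (2, 3, 4))
print(sorted(ws.tensors))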
@@ -23,19 +23,19 @@ class Arange(Operator):
             }
         }
-    def feed(self, handle, slice_args):
+    def feed(self, ws, handle, slice_args):
         for i in range(len(slice_args)):
             self.feed_arg(
-                '{}/slice[{}]'
+                ws,
-                .format(handle, i),
+                '{}/slice[{}]'.format(handle, i),
                 slice_args[i], 'float32'
             )
     def forward(self, slice_args, trainable=False):
         output = self.dispatch(
             [], [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, slice_args)
+                self.feed(ws, handle, slice_args)
         )
         output._requires_grad = trainable
         return output
@@ -109,19 +109,19 @@ class ChannelNormalize(Operator):
             }
         }
-    def feed(self, handle, perm):
+    def feed(self, ws, handle, perm):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/perm[{}]'
+                ws,
-                .format(handle, i),
+                '{}/perm[{}]'.format(handle, i),
                 perm[i], 'int64'
             )
     def forward(self, inputs, perm):
         return self.dispatch(
             inputs, [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, perm)
+                self.feed(ws, handle, perm)
         )
@@ -199,19 +199,19 @@ class Expand(Operator):
             }
         }
-    def feed(self, handle, dims):
+    def feed(self, ws, handle, dims):
         for i, d in enumerate(dims):
             self.feed_arg(
-                '{}/dims[{}]'
+                ws,
-                .format(handle, i),
+                '{}/dims[{}]'.format(handle, i),
                 d, 'int64'
             )
     def forward(self, inputs, dims):
         return self.dispatch(
             inputs, [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, dims)
+                self.feed(ws, handle, dims)
         )
@@ -301,8 +301,7 @@ class Moments(Operator):
         }
     def forward(self, inputs):
-        outputs = [self.alloc(), self.alloc()]
-        return self.dispatch(inputs, outputs)
+        return self.dispatch(inputs, [self.alloc(), self.alloc()])
 class Multinomial(Operator):
@@ -378,19 +377,19 @@ class Pad(Operator):
             }
         }
-    def feed(self, handle, pads):
+    def feed(self, ws, handle, pads):
         for i, e in enumerate(pads):
             self.feed_arg(
-                '{}/pads[{}]'
+                ws,
-                .format(handle, i),
+                '{}/pads[{}]'.format(handle, i),
                 e, 'int64'
             )
     def forward(self, inputs, pads):
         return self.dispatch(
             inputs, [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, pads)
+                self.feed(ws, handle, pads)
         )
@@ -449,11 +448,11 @@ class Reshape(Operator):
             }
         }
-    def feed(self, handle, shape):
+    def feed(self, ws, handle, shape):
         for i, e in enumerate(shape):
             self.feed_arg(
-                '{}/dims[{}]'
+                ws,
-                .format(handle, i),
+                '{}/dims[{}]'.format(handle, i),
                 e, 'int64'
             )
@@ -461,8 +460,8 @@ class Reshape(Operator):
         outputs = [inputs[0] if inplace else self.alloc()]
         return self.dispatch(
             inputs, outputs,
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, shape)
+                self.feed(ws, handle, shape)
         )
@@ -486,24 +485,24 @@ class Slice(Operator):
             }
         }
-    def feed(self, handle, starts, sizes):
+    def feed(self, ws, handle, starts, sizes):
         for i, e in enumerate(starts):
             self.feed_arg(
-                '{}/starts[{}]'
+                ws,
-                .format(handle, i),
+                '{}/starts[{}]'.format(handle, i),
                 e, 'int64'
             )
             self.feed_arg(
-                '{}/sizes[{}]'
+                ws,
-                .format(handle, i),
+                '{}/sizes[{}]'.format(handle, i),
                 sizes[i], 'int64'
             )
     def forward(self, inputs, starts, sizes):
         return self.dispatch(
             inputs, [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, starts, sizes)
+                self.feed(ws, handle, starts, sizes)
         )
@@ -591,19 +590,19 @@ class Tile(Operator):
             }
         }
-    def feed(self, handle, multiples):
+    def feed(self, ws, handle, multiples):
         for i, d in enumerate(multiples):
             self.feed_arg(
-                '{}/multiples[{}]'
+                ws,
-                .format(handle, i),
+                '{}/multiples[{}]'.format(handle, i),
                 d, 'int64'
            )
     def forward(self, inputs, multiples):
         return self.dispatch(
             inputs, [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, multiples)
+                self.feed(ws, handle, multiples)
         )
@@ -623,19 +622,19 @@ class Transpose(Operator):
             }
         }
-    def feed(self, handle, perm):
+    def feed(self, ws, handle, perm):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/perm[{}]'
+                ws,
-                .format(handle, i),
+                '{}/perm[{}]'.format(handle, i),
                 perm[i], 'int64'
             )
     def forward(self, inputs, perm):
         return self.dispatch(
             inputs, [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, perm)
+                self.feed(ws, handle, perm)
         )
...
@@ -34,24 +34,24 @@ class Assign(Operator):
             },
         }
-    def feed(self, handle, starts, sizes):
+    def feed(self, ws, handle, starts, sizes):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/starts[{}]'
+                ws,
-                .format(handle, i),
+                '{}/starts[{}]'.format(handle, i),
                 starts[i], 'int64',
             )
             self.feed_arg(
-                '{}/sizes[{}]'
+                ws,
-                .format(handle, i),
+                '{}/sizes[{}]'.format(handle, i),
                 sizes[i], 'int64',
             )
-    def forward(self, inputs, starts, sizes):
+    def forward(self, ws, inputs, starts, sizes):
         return self.dispatch(
             [inputs[1]], [inputs[0]],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, starts, sizes),
+                self.feed(ws, handle, starts, sizes),
             no_grad=True,
         )
...
@@ -23,11 +23,11 @@ class Initializer(Operator):
         self.ndim = kwargs.get('ndim', 0)
         self.dtype = kwargs.get('dtype', 'float32')
-    def feed(self, handle, shape):
+    def feed(self, ws, handle, shape):
         for i, e in enumerate(shape):
             self.feed_arg(
-                '{}/dims[{}]'
+                ws,
-                .format(handle, i),
+                '{}/dims[{}]'.format(handle, i),
                 e, 'int64'
             )
@@ -49,8 +49,8 @@ class Initializer(Operator):
         ]
         return self.dispatch(
             inputs, outputs,
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, shape)
+                self.feed(ws, handle, shape)
         )
...
@@ -43,8 +43,7 @@ class _ConvNd(Operator):
         }
     def forward(self, inputs):
-        outputs = [self.alloc()]
-        return self.dispatch(inputs, outputs)
+        return self.dispatch(inputs, [self.alloc()])
 class _PoolNd(Operator):
@@ -75,8 +74,7 @@ class _PoolNd(Operator):
         }
     def forward(self, inputs):
-        outputs = [self.alloc()]
-        return self.dispatch(inputs, outputs)
+        return self.dispatch(inputs, [self.alloc()])
 class BiasAdd(Operator):
@@ -184,25 +182,25 @@ class Resize(Operator):
             }
         }
-    def feed(self, handle, sizes, scales):
+    def feed(self, ws, handle, sizes, scales):
         for i in range(self.num_sizes):
             self.feed_arg(
-                '{}/sizes[{}]'
+                ws,
-                .format(handle, i),
+                '{}/sizes[{}]'.format(handle, i),
                 sizes[i], 'int64',
             )
         for i in range(self.num_scales):
             self.feed_arg(
-                '{}/scales[{}]'
+                ws,
-                .format(handle, i),
+                '{}/scales[{}]'.format(handle, i),
                 scales[i], 'float32',
             )
     def forward(self, inputs, sizes=None, scales=None):
         return self.dispatch(
             inputs, [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, sizes, scales)
+                self.feed(ws, handle, sizes, scales)
         )
...
@@ -17,7 +17,6 @@ import numpy
 from dragon.core.framework import config
 from dragon.core.framework import proto_util
-from dragon.core.framework import workspace
 from dragon.vm.torch import executor
@@ -64,13 +63,9 @@ class Function(object):
             pre_callback=callback,
         )
-    def feed_arg(self, name, value, dtype='int64'):
+    def feed_arg(self, ws, name, value, dtype='int64'):
         """Set the value of tensor argument."""
-        workspace.get_workspace().FeedTensor(
+        ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device)
-            name,
-            numpy.array(value, dtype),
-            self._arg_device,
-        )
     @classmethod
     def instantiate(cls, device, **kwargs):
...
@@ -9,27 +9,15 @@
 #
 # ------------------------------------------------------------
-"""The basic idea of directly run operators comes from ``caffe2``,
+"""Execute tensor operations. """
-it spends much more time on Python frontend than C++ backend,
-which should not be taken for running computation-intensive operators.
-We extend a new ``PERSISTENT`` engine, that hashes the arguments
-as many as possible, i.e., creates a operator once while running
-with arbitrary inputs and outputs many times.
-Note that it is still a challenge to persist the operators which
-take the argument with uncertain numerical bounds. In this case,
-our engine will still create lots of duplicates.
-"""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from dragon.core.framework import config
 from dragon.core.framework import context
 from dragon.core.framework import workspace
-from dragon.core.util import nest
 from dragon.core.util import six
 from dragon.vm.torch.autograd import grad_mode
 from dragon.vm.torch.cpp import device as Device
@@ -44,15 +32,6 @@ def run_operator(
     no_grad=False,
     pre_callback=None,
 ):
-    inputs = nest.flatten(inputs)
-    outputs = nest.flatten(outputs)
-    if len(outputs) == 0:
-        raise ValueError(
-            'The number of <outputs> should be '
-            'at least 1. Got {}.'.format(len(outputs))
-        )
     requires_grad = False
     input_names, output_names = [], []
@@ -64,6 +43,7 @@ def run_operator(
     requires_grad = requires_grad and grad_mode.is_grad_enabled()
     # Allocate outputs.
+    cfg = config.config()
     ws = workspace.get_workspace()
     output_scope = context.get_eager_scope(requires_grad)
     gc = ws.collectors # Garbage collectors
@@ -110,8 +90,8 @@ def run_operator(
     # Dispatch the computation.
     if pre_callback is not None:
-        pre_callback(op_def.name)
+        pre_callback(ws, op_def.name)
-    workspace.run_operator(op_def)
+    ws.RunOperator(op_def, cfg.graph_verbosity > 0)
     # Return the outputs.
     return outputs if len(outputs) > 1 else outputs[0]
@@ -60,8 +60,7 @@ class _ConvNd(function.Function):
     def forward(self, input, weight, bias=None):
         inputs = [input, weight] + ([bias] if bias else [])
-        outputs = [self.alloc()]
-        return self.dispatch(inputs, outputs)
+        return self.dispatch(inputs, [self.alloc()])
 class _Loss(function.Function):
@@ -292,8 +291,7 @@ class GroupNorm(function.Function):
         }
     def forward(self, input, weight, bias):
-        inputs = [input, weight, bias]
-        return self.dispatch(inputs, [self.alloc()])
+        return self.dispatch([input, weight, bias], [self.alloc()])
 class L1Loss(_Loss):
@@ -396,19 +394,19 @@ class Pad(function.Function):
             }
         }
-    def feed(self, handle, pads):
+    def feed(self, ws, handle, pads):
         for i, e in enumerate(pads):
             self.feed_arg(
-                '{}/pads[{}]'
+                ws,
-                .format(handle, i),
+                '{}/pads[{}]'.format(handle, i),
                 e, 'int64'
             )
     def forward(self, input, pads):
         return self.dispatch(
             [input], [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, pads),
+                self.feed(ws, handle, pads),
         )
@@ -524,25 +522,25 @@ class Resize(function.Function):
             }
         }
-    def feed(self, handle, sizes, scales):
+    def feed(self, ws, handle, sizes, scales):
         for i in range(self.num_sizes):
             self.feed_arg(
-                '{}/sizes[{}]'
+                ws,
-                .format(handle, i),
+                '{}/sizes[{}]'.format(handle, i),
                 sizes[i], 'int64',
             )
         for i in range(self.num_scales):
             self.feed_arg(
-                '{}/scales[{}]'
+                ws,
-                .format(handle, i),
+                '{}/scales[{}]'.format(handle, i),
                 scales[i], 'float32',
             )
     def forward(self, input, sizes=None, scales=None):
         return self.dispatch(
             [input], [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, sizes, scales)
+                self.feed(ws, handle, sizes, scales)
         )
...
@@ -58,16 +58,16 @@ class Assign(function.Function):
             },
         }
-    def feed(self, handle, starts, sizes):
+    def feed(self, ws, handle, starts, sizes):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/starts[{}]'
+                ws,
-                .format(handle, i),
+                '{}/starts[{}]'.format(handle, i),
                 starts[i], 'int64',
             )
             self.feed_arg(
-                '{}/sizes[{}]'
+                ws,
-                .format(handle, i),
+                '{}/sizes[{}]'.format(handle, i),
                 sizes[i], 'int64',
             )
@@ -75,8 +75,8 @@ class Assign(function.Function):
         self._check_device([input, out])
         return self.dispatch(
             [input], [out],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, starts, sizes),
+                self.feed(ws, handle, starts, sizes),
             no_grad=True,
             check_device=False,
         )
@@ -127,19 +127,19 @@ class ChannelNormalize(function.Function):
             }
         }
-    def feed(self, handle, perm):
+    def feed(self, ws, handle, perm):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/perm[{}]'
+                ws,
-                .format(handle, i),
+                '{}/perm[{}]'.format(handle, i),
                 perm[i], 'int64',
             )
     def forward(self, input, perm):
         return self.dispatch(
             [input], [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, perm),
+                self.feed(ws, handle, perm),
         )
@@ -220,19 +220,19 @@ class Expand(function.Function):
             },
         }
-    def feed(self, handle, times):
+    def feed(self, ws, handle, times):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/dims[{}]'
+                ws,
-                .format(handle, i),
+                '{}/dims[{}]'.format(handle, i),
                 times[i], 'int64',
             )
     def forward(self, input, dims):
         return self.dispatch(
             [input], [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, dims),
+                self.feed(ws, handle, dims),
         )
@@ -366,11 +366,11 @@ class Reshape(function.Function):
             },
         }
-    def feed(self, handle, shape):
+    def feed(self, ws, handle, shape):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/dims[{}]'
+                ws,
-                .format(handle, i),
+                '{}/dims[{}]'.format(handle, i),
                 shape[i], 'int64',
             )
@@ -378,8 +378,8 @@ class Reshape(function.Function):
         out = out if out else self.alloc()
         return self.dispatch(
             [input], [out],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, shape),
+                self.feed(ws, handle, shape),
         )
@@ -403,24 +403,24 @@ class Slice(function.Function):
             },
         }
-    def feed(self, handle, starts, sizes):
+    def feed(self, ws, handle, starts, sizes):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/starts[{}]'
+                ws,
-                .format(handle, i),
+                '{}/starts[{}]'.format(handle, i),
                 starts[i], 'int64',
             )
             self.feed_arg(
-                '{}/sizes[{}]'
+                ws,
-                .format(handle, i),
+                '{}/sizes[{}]'.format(handle, i),
                 sizes[i], 'int64',
             )
     def forward(self, input, starts, sizes):
         return self.dispatch(
             [input], [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, starts, sizes)
+                self.feed(ws, handle, starts, sizes)
         )
@@ -496,19 +496,19 @@ class Tile(function.Function):
             },
         }
-    def feed(self, handle, times):
+    def feed(self, ws, handle, times):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/multiples[{}]'
+                ws,
-                .format(handle, i),
+                '{}/multiples[{}]'.format(handle, i),
                 times[i], 'int64',
             )
     def forward(self, input, times):
         return self.dispatch(
             [input], [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, times),
+                self.feed(ws, handle, times),
         )
@@ -528,19 +528,19 @@ class Transpose(function.Function):
             },
         }
-    def feed(self, handle, perm):
+    def feed(self, ws, handle, perm):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/perm[{}]'
+                ws,
-                .format(handle, i),
+                '{}/perm[{}]'.format(handle, i),
                 perm[i], 'int64',
             )
     def forward(self, input, perm):
         return self.dispatch(
             [input], [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, perm),
+                self.feed(ws, handle, perm),
         )
...
@@ -22,19 +22,19 @@ class _Initializer(function.Function):
         self.ndim = kwargs.get('ndim', 0)
         self.dtype = kwargs.get('dtype', 'float32')
-    def feed(self, handle, shape):
+    def feed(self, ws, handle, shape):
         for i in range(self.ndim):
             self.feed_arg(
-                '{}/dims[{}]'
+                ws,
-                .format(handle, i),
+                '{}/dims[{}]'.format(handle, i),
                 shape[i], 'int64',
             )
     def forward(self, out, shape, shape_like=None):
         return self.dispatch(
             [] if shape_like is None else [shape_like], [out],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, shape),
+                self.feed(ws, handle, shape),
         )
@@ -56,19 +56,19 @@ class Arange(function.Function):
             }
         }
-    def feed(self, handle, slice_args):
+    def feed(self, ws, handle, slice_args):
         for i in range(len(slice_args)):
             self.feed_arg(
-                '{}/slice[{}]'
+                ws,
-                .format(handle, i),
+                '{}/slice[{}]'.format(handle, i),
                 slice_args[i], 'float32'
             )
     def forward(self, slice_args):
         return self.dispatch(
             [], [self.alloc()],
-            callback=lambda handle:
+            callback=lambda ws, handle:
-                self.feed(handle, slice_args)
+                self.feed(ws, handle, slice_args)
         )
...