Commit df172cc8 by Ting PAN

Simplify the operation executor

Summary:
This commit removes the redundant workspace lookups when executing a tensor
operation: the executor resolves the workspace once and passes it to the
argument-feeding callbacks.
1 parent b37d4e5e
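In the new executor the workspace is resolved once and threaded through the
argument-feeding callbacks, instead of each callback calling
workspace.get_workspace() on its own. A minimal sketch of the resulting
dispatch pattern, with simplified signatures (the real run_operator() also
allocates outputs and handles tapes and garbage collection, as in the diff
below):

    # Minimal sketch; signatures are simplified relative to the diff below.
    from dragon.core.framework import config
    from dragon.core.framework import workspace

    def run_operator(op_def, pre_callback=None):
        cfg = config.config()
        ws = workspace.get_workspace()     # resolved once by the executor
        if pre_callback is not None:
            pre_callback(ws, op_def.name)  # callbacks now receive the workspace
        ws.RunOperator(op_def, cfg.graph_verbosity > 0)

    # Operator classes build their callbacks accordingly, e.g.
    #   callback=lambda ws, handle: self.feed(ws, handle, dims)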
......@@ -37,6 +37,7 @@ void GraphGradientMaker::Make(
Map<string, int> inputs_count, grads_count;
Set<string> all_split_grads, targets_set;
Map<string, string> targets_to_grads;
// Forward pass
for (auto* op : forward_ops) {
if (NoGradientRegistry()->Has(op->type())) continue;
......@@ -51,6 +52,7 @@ void GraphGradientMaker::Make(
if (!input_in_outputs) inputs_count[input]++;
}
}
// Backward pass
for (int i = 0; i < targets.size(); ++i) {
// Set the gradient of targets
......@@ -59,17 +61,19 @@ void GraphGradientMaker::Make(
}
targets_set.insert(targets[i]);
}
for (int i = (int)forward_ops.size() - 1; i >= 0; --i) {
for (int op_idx = (int)forward_ops.size() - 1; op_idx >= 0; --op_idx) {
// Collect inputs and outputs, generate raw gradient ops
const OperatorDef& op = *forward_ops[i];
const OperatorDef& op = *forward_ops[op_idx];
vector<pair<string, int>> gen_grads;
bool is_skip = CheckGrad(op, targets_set, gen_grads);
vector<string> g_outputs;
for (auto& output : op.output()) {
for (const auto& output : op.output()) {
string g_output = "";
if (inputs_to_grads_.count(output) > 0)
if (inputs_to_grads_.count(output) > 0) {
g_output = inputs_to_grads_[output];
g_outputs.emplace_back(g_output);
}
g_outputs.push_back(g_output);
}
auto grad = MakeGradientForOp(op, g_outputs);
......@@ -82,15 +86,21 @@ void GraphGradientMaker::Make(
for (int i = 0; i < grad_op.output_size(); ++i) {
auto* output = grad_op.mutable_output(i);
int original_idx = -1;
for (int j = 0; j < grad.g_inputs.size(); ++j)
if (grad_op.output(i) == grad.g_inputs[j]) original_idx = j;
for (int j = 0; j < grad.g_inputs.size(); ++j) {
if (grad_op.output(i) == grad.g_inputs[j]) {
original_idx = j;
}
}
// Ignore unused && in-place GI
if (original_idx == -1) continue;
bool output_in_inputs = false;
for (const auto& input : grad_op.input())
if (grad_op.output(i) == input) output_in_inputs = true;
for (const auto& input : grad_op.input()) {
if (grad_op.output(i) == input) {
output_in_inputs = true;
}
}
if (output_in_inputs) continue;
// Found a split branch
// Find a split branch
const auto& original_name = op.input(original_idx);
if (inputs_count[original_name] > 1) {
// Split
......@@ -120,7 +130,7 @@ void GraphGradientMaker::Make(
// Now, append the required ops
if (!is_skip) {
// 1) GradientGenerateOp
// GradientGenerateOp
if (gen_grads.size() > 0) {
vector<string> op_inputs, op_outputs;
Argument arg_defaults;
......@@ -141,15 +151,18 @@ void GraphGradientMaker::Make(
}
backward_def.add_op()->CopyFrom(generate_op);
}
// 2) GradientOp
for (const auto& grad_op : grad.ops)
// GradientOp
for (const auto& grad_op : grad.ops) {
backward_def.add_op()->CopyFrom(grad_op);
}
// 3) GradientGatherOp
for (const auto& gather_op : gather_ops)
}
// GradientGatherOp
for (const auto& gather_op : gather_ops) {
backward_def.add_op()->CopyFrom(gather_op);
}
// Done!
// Done
if (!is_skip) {
for (int i = 0; i < op.input_size(); ++i) {
if (!grad.g_inputs[i].empty())
......@@ -162,32 +175,40 @@ void GraphGradientMaker::Make(
GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
Set<int> invalid_ops;
Map<string, int> ref_count;
Map<string, pair<int, string>> ssa_map;
// Count the refs for detecting leaf nodes
Map<string, pair<int, string>> gather_map;
for (int op_idx = 0; op_idx < input_def.op_size(); ++op_idx) {
const auto& op = input_def.op(op_idx);
if (!str::find(op.type(), "Gradient")) continue;
// Flag the gathering gradients
if (op.type() == "GradientGather") {
invalid_ops.insert(op_idx);
if (ignored_grads_.count(op.output(0))) {
for (const auto& input : op.input())
for (const auto& input : op.input()) {
ignored_grads_.insert(input);
}
continue;
} else {
string head;
string first_input;
for (const auto& input : op.input()) {
if (!input.empty()) {
if (head.empty()) head = input;
ssa_map[input] = {op_idx, head};
if (first_input.empty()) first_input = input;
gather_map[input] = {op_idx, first_input};
}
}
}
}
// Count the references to detect leaves
for (const auto& input : op.input()) {
if (str::find(input, "grad")) {
ref_count[input] += 1;
}
}
for (const auto& input : op.input())
if (str::find(input, "grad")) ref_count[input] += 1;
}
// Decompose the <GradientGather> in SSA format
// Decompose the <GradientGather> into <GradientAdd>
// This trick accumulates each split gradient into its target right after
// it is computed, which helps reduce the total number of buffers.
GraphDef output_def(input_def);
output_def.clear_op();
for (int op_idx = 0; op_idx < input_def.op_size(); ++op_idx) {
......@@ -196,20 +217,20 @@ GraphDef GraphGradientMaker::Share(const GraphDef& input_def) {
output_def.add_op()->CopyFrom(op);
if (!str::find(op.type(), "Gradient")) continue;
for (const auto& output : op.output()) {
const auto& find_iter = ssa_map.find(output);
if (find_iter != ssa_map.end()) {
const auto& find_iter = gather_map.find(output);
if (find_iter != gather_map.end()) {
const auto& gather_op = input_def.op(find_iter->second.first);
auto acc_op(gather_op);
acc_op.clear_input();
auto add_op(gather_op);
add_op.clear_input();
if (output != find_iter->second.second) {
acc_op.set_type("GradientAdd");
add_op.set_type("GradientAdd");
// Accumulate in place to avoid allocating a new buffer
acc_op.add_input(gather_op.output(0));
add_op.add_input(gather_op.output(0));
const auto& ref_iter = ref_count.find(gather_op.output(0));
if (ref_iter != ref_count.end()) ref_iter->second++;
}
acc_op.add_input(output);
output_def.add_op()->CopyFrom(acc_op);
add_op.add_input(output);
output_def.add_op()->CopyFrom(add_op);
}
}
}
......
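The Share() pass above rewrites each <GradientGather> so that every split
gradient is folded into its target as soon as it is computed, rather than
gathering all splits at the end. A minimal Python sketch of that rewrite,
using plain dicts in place of OperatorDef protobufs (the tensor names here
are illustrative, not taken from the codebase):

    def fold_split(target, split, is_first):
        """Return the op that folds one split gradient into its target."""
        if is_first:
            # The first split keeps the gather type and writes the target directly.
            return {'type': 'GradientGather', 'inputs': [split], 'output': target}
        # Later splits are accumulated in place, so no extra buffer is needed.
        return {'type': 'GradientAdd', 'inputs': [target, split], 'output': target}

    ops = [fold_split('x_grad', s, i == 0)
           for i, s in enumerate(['x_grad_split_0', 'x_grad_split_1'])]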
......@@ -164,7 +164,7 @@ class DRAGON_API UnifiedMemory {
/*! \brief The binding cpu tensor for cnml */
cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr;
/*! \brief Tje binding mlu tensor for cnml */
/*! \brief The binding mlu tensor for cnml */
cnmlTensor_t cnml_mlu_tensor_ = nullptr;
};
......
......@@ -236,14 +236,18 @@ PYBIND11_MODULE(libdragon_python, m) {
const bool verbose) {
GraphDef backward_ops;
GraphGradientMaker maker;
for (auto& name : ignored_grads)
for (const auto& name : ignored_grads) {
maker.add_ignored_grad(name);
for (auto& name : sources)
}
for (const auto& name : sources) {
maker.add_hooked_grad(name + "_grad");
}
maker.Make(forward_ops, targets, input_grads, backward_ops);
py::gil_scoped_release g;
if (is_sharing) backward_ops = maker.Share(backward_ops);
for (auto& def : backward_ops.op()) {
if (is_sharing) {
backward_ops = maker.Share(backward_ops);
}
for (const auto& def : backward_ops.op()) {
if (verbose) {
auto msg = string("\n") + def.DebugString();
msg.pop_back();
......@@ -268,8 +272,9 @@ PYBIND11_MODULE(libdragon_python, m) {
<< "Can't be used in C++.";
break;
case 1: // CaffeModel
for (const auto& e : tensors)
refs.emplace_back(self->GetTensor(e));
for (const auto& name : tensors) {
refs.emplace_back(self->GetTensor(name));
}
SavaCaffeModel(filename, refs);
break;
default:
......
......@@ -50,7 +50,7 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
// Allocate for the reserve space
size_t reserve_size;
CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(input_desc_, &reserve_size));
auto* X_mask = Buffer("mask")->Reshape({(int64_t)reserve_size});
auto* X_mask = Buffer("X_mask")->Reshape({(int64_t)reserve_size});
CUDNN_CHECK(cudnnDropoutForward(
ctx()->cudnn_handle(),
......@@ -106,7 +106,7 @@ void CuDNNDropoutGradientOp<Context>::DoRunWithType() {
// Check the reserve space
size_t reserve_size;
CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(input_desc_, &reserve_size));
auto* X_mask = Buffer("mask");
auto* X_mask = Buffer("X_mask");
CHECK_EQ(X_mask->size(), reserve_size);
CUDNN_CHECK(cudnnDropoutBackward(
......
......@@ -9,19 +9,7 @@
#
# ------------------------------------------------------------
"""The basic idea of directly run operators comes from ``caffe2``,
it spends much more time on Python frontend than C++ backend,
which should not be taken for running computation-intensive operators.
We extend a new ``PERSISTENT`` engine, that hashes the arguments
as many as possible, i.e., creates a operator once while running
with arbitrary inputs and outputs many times.
Note that it is still a challenge to persist the operators which
take the argument with uncertain numerical bounds. In this case,
our engine will still create lots of duplicates.
"""
"""Execute tensor operations. """
from __future__ import absolute_import
from __future__ import division
......@@ -30,9 +18,9 @@ from __future__ import print_function
from dragon.core.eager import backprop
from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import device_spec
from dragon.core.framework import config
from dragon.core.framework import context
from dragon.core.framework import workspace
from dragon.core.util import nest
from dragon.core.util import six
......@@ -43,15 +31,6 @@ def run_operator(
no_grad=False,
pre_callback=None,
):
inputs = nest.flatten(inputs)
outputs = nest.flatten(outputs)
if len(outputs) == 0:
raise ValueError(
'The number of <outputs> should be '
'at least 1. Got {}.'.format(len(outputs))
)
requires_grad = False
input_names, output_names = [], []
tape = backprop.get_default_tape()
......@@ -68,6 +47,7 @@ def run_operator(
requires_grad = True
# Allocate outputs.
cfg = config.config()
ws = workspace.get_workspace()
output_scope = context.get_eager_scope(requires_grad)
gc = ws.collectors # Garbage collectors
......@@ -100,8 +80,8 @@ def run_operator(
# Dispatch the computation.
if pre_callback is not None:
pre_callback(op_def.name)
workspace.run_operator(op_def)
pre_callback(ws, op_def.name)
ws.RunOperator(op_def, cfg.graph_verbosity > 0)
# Return the outputs.
return outputs if len(outputs) > 1 else outputs[0]
......@@ -68,13 +68,9 @@ class Operator(object):
pre_callback=callback,
)
def feed_arg(self, name, value, dtype='int64'):
"""Set the value of tensors argument."""
workspace.get_workspace().FeedTensor(
name,
numpy.array(value, dtype),
self._arg_device,
)
def feed_arg(self, ws, name, value, dtype='int64'):
"""Set the value of tensor argument."""
ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device)
@classmethod
def instantiate(cls, **kwargs):
......
......@@ -23,19 +23,19 @@ class Arange(Operator):
}
}
def feed(self, handle, slice_args):
def feed(self, ws, handle, slice_args):
for i in range(len(slice_args)):
self.feed_arg(
'{}/slice[{}]'
.format(handle, i),
ws,
'{}/slice[{}]'.format(handle, i),
slice_args[i], 'float32'
)
def forward(self, slice_args, trainable=False):
output = self.dispatch(
[], [self.alloc()],
callback=lambda handle:
self.feed(handle, slice_args)
callback=lambda ws, handle:
self.feed(ws, handle, slice_args)
)
output._requires_grad = trainable
return output
......@@ -109,19 +109,19 @@ class ChannelNormalize(Operator):
}
}
def feed(self, handle, perm):
def feed(self, ws, handle, perm):
for i in range(self.ndim):
self.feed_arg(
'{}/perm[{}]'
.format(handle, i),
ws,
'{}/perm[{}]'.format(handle, i),
perm[i], 'int64'
)
def forward(self, inputs, perm):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda handle:
self.feed(handle, perm)
callback=lambda ws, handle:
self.feed(ws, handle, perm)
)
......@@ -199,19 +199,19 @@ class Expand(Operator):
}
}
def feed(self, handle, dims):
def feed(self, ws, handle, dims):
for i, d in enumerate(dims):
self.feed_arg(
'{}/dims[{}]'
.format(handle, i),
ws,
'{}/dims[{}]'.format(handle, i),
d, 'int64'
)
def forward(self, inputs, dims):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda handle:
self.feed(handle, dims)
callback=lambda ws, handle:
self.feed(ws, handle, dims)
)
......@@ -301,8 +301,7 @@ class Moments(Operator):
}
def forward(self, inputs):
outputs = [self.alloc(), self.alloc()]
return self.dispatch(inputs, outputs)
return self.dispatch(inputs, [self.alloc(), self.alloc()])
class Multinomial(Operator):
......@@ -378,19 +377,19 @@ class Pad(Operator):
}
}
def feed(self, handle, pads):
def feed(self, ws, handle, pads):
for i, e in enumerate(pads):
self.feed_arg(
'{}/pads[{}]'
.format(handle, i),
ws,
'{}/pads[{}]'.format(handle, i),
e, 'int64'
)
def forward(self, inputs, pads):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda handle:
self.feed(handle, pads)
callback=lambda ws, handle:
self.feed(ws, handle, pads)
)
......@@ -449,11 +448,11 @@ class Reshape(Operator):
}
}
def feed(self, handle, shape):
def feed(self, ws, handle, shape):
for i, e in enumerate(shape):
self.feed_arg(
'{}/dims[{}]'
.format(handle, i),
ws,
'{}/dims[{}]'.format(handle, i),
e, 'int64'
)
......@@ -461,8 +460,8 @@ class Reshape(Operator):
outputs = [inputs[0] if inplace else self.alloc()]
return self.dispatch(
inputs, outputs,
callback=lambda handle:
self.feed(handle, shape)
callback=lambda ws, handle:
self.feed(ws, handle, shape)
)
......@@ -486,24 +485,24 @@ class Slice(Operator):
}
}
def feed(self, handle, starts, sizes):
def feed(self, ws, handle, starts, sizes):
for i, e in enumerate(starts):
self.feed_arg(
'{}/starts[{}]'
.format(handle, i),
ws,
'{}/starts[{}]'.format(handle, i),
e, 'int64'
)
self.feed_arg(
'{}/sizes[{}]'
.format(handle, i),
ws,
'{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64'
)
def forward(self, inputs, starts, sizes):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda handle:
self.feed(handle, starts, sizes)
callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes)
)
......@@ -591,19 +590,19 @@ class Tile(Operator):
}
}
def feed(self, handle, multiples):
def feed(self, ws, handle, multiples):
for i, d in enumerate(multiples):
self.feed_arg(
'{}/multiples[{}]'
.format(handle, i),
ws,
'{}/multiples[{}]'.format(handle, i),
d, 'int64'
)
def forward(self, inputs, multiples):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda handle:
self.feed(handle, multiples)
callback=lambda ws, handle:
self.feed(ws, handle, multiples)
)
......@@ -623,19 +622,19 @@ class Transpose(Operator):
}
}
def feed(self, handle, perm):
def feed(self, ws, handle, perm):
for i in range(self.ndim):
self.feed_arg(
'{}/perm[{}]'
.format(handle, i),
ws,
'{}/perm[{}]'.format(handle, i),
perm[i], 'int64'
)
def forward(self, inputs, perm):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda handle:
self.feed(handle, perm)
callback=lambda ws, handle:
self.feed(ws, handle, perm)
)
......
......@@ -34,24 +34,24 @@ class Assign(Operator):
},
}
def feed(self, handle, starts, sizes):
def feed(self, ws, handle, starts, sizes):
for i in range(self.ndim):
self.feed_arg(
'{}/starts[{}]'
.format(handle, i),
ws,
'{}/starts[{}]'.format(handle, i),
starts[i], 'int64',
)
self.feed_arg(
'{}/sizes[{}]'
.format(handle, i),
ws,
'{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64',
)
def forward(self, inputs, starts, sizes):
def forward(self, ws, inputs, starts, sizes):
return self.dispatch(
[inputs[1]], [inputs[0]],
callback=lambda handle:
self.feed(handle, starts, sizes),
callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes),
no_grad=True,
)
......
......@@ -23,11 +23,11 @@ class Initializer(Operator):
self.ndim = kwargs.get('ndim', 0)
self.dtype = kwargs.get('dtype', 'float32')
def feed(self, handle, shape):
def feed(self, ws, handle, shape):
for i, e in enumerate(shape):
self.feed_arg(
'{}/dims[{}]'
.format(handle, i),
ws,
'{}/dims[{}]'.format(handle, i),
e, 'int64'
)
......@@ -49,8 +49,8 @@ class Initializer(Operator):
]
return self.dispatch(
inputs, outputs,
callback=lambda handle:
self.feed(handle, shape)
callback=lambda ws, handle:
self.feed(ws, handle, shape)
)
......
......@@ -43,8 +43,7 @@ class _ConvNd(Operator):
}
def forward(self, inputs):
outputs = [self.alloc()]
return self.dispatch(inputs, outputs)
return self.dispatch(inputs, [self.alloc()])
class _PoolNd(Operator):
......@@ -75,8 +74,7 @@ class _PoolNd(Operator):
}
def forward(self, inputs):
outputs = [self.alloc()]
return self.dispatch(inputs, outputs)
return self.dispatch(inputs, [self.alloc()])
class BiasAdd(Operator):
......@@ -184,25 +182,25 @@ class Resize(Operator):
}
}
def feed(self, handle, sizes, scales):
def feed(self, ws, handle, sizes, scales):
for i in range(self.num_sizes):
self.feed_arg(
'{}/sizes[{}]'
.format(handle, i),
ws,
'{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64',
)
for i in range(self.num_scales):
self.feed_arg(
'{}/scales[{}]'
.format(handle, i),
ws,
'{}/scales[{}]'.format(handle, i),
scales[i], 'float32',
)
def forward(self, inputs, sizes=None, scales=None):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda handle:
self.feed(handle, sizes, scales)
callback=lambda ws, handle:
self.feed(ws, handle, sizes, scales)
)
......
......@@ -17,7 +17,6 @@ import numpy
from dragon.core.framework import config
from dragon.core.framework import proto_util
from dragon.core.framework import workspace
from dragon.vm.torch import executor
......@@ -64,13 +63,9 @@ class Function(object):
pre_callback=callback,
)
def feed_arg(self, name, value, dtype='int64'):
def feed_arg(self, ws, name, value, dtype='int64'):
"""Set the value of tensor argument."""
workspace.get_workspace().FeedTensor(
name,
numpy.array(value, dtype),
self._arg_device,
)
ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device)
@classmethod
def instantiate(cls, device, **kwargs):
......
......@@ -9,27 +9,15 @@
#
# ------------------------------------------------------------
"""The basic idea of directly run operators comes from ``caffe2``,
it spends much more time on Python frontend than C++ backend,
which should not be taken for running computation-intensive operators.
We extend a new ``PERSISTENT`` engine, that hashes the arguments
as many as possible, i.e., creates a operator once while running
with arbitrary inputs and outputs many times.
Note that it is still a challenge to persist the operators which
take the argument with uncertain numerical bounds. In this case,
our engine will still create lots of duplicates.
"""
"""Execute tensor operations. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.framework import config
from dragon.core.framework import context
from dragon.core.framework import workspace
from dragon.core.util import nest
from dragon.core.util import six
from dragon.vm.torch.autograd import grad_mode
from dragon.vm.torch.cpp import device as Device
......@@ -44,15 +32,6 @@ def run_operator(
no_grad=False,
pre_callback=None,
):
inputs = nest.flatten(inputs)
outputs = nest.flatten(outputs)
if len(outputs) == 0:
raise ValueError(
'The number of <outputs> should be '
'at least 1. Got {}.'.format(len(outputs))
)
requires_grad = False
input_names, output_names = [], []
......@@ -64,6 +43,7 @@ def run_operator(
requires_grad = requires_grad and grad_mode.is_grad_enabled()
# Allocate outputs.
cfg = config.config()
ws = workspace.get_workspace()
output_scope = context.get_eager_scope(requires_grad)
gc = ws.collectors # Garbage collectors
......@@ -110,8 +90,8 @@ def run_operator(
# Dispatch the computation.
if pre_callback is not None:
pre_callback(op_def.name)
workspace.run_operator(op_def)
pre_callback(ws, op_def.name)
ws.RunOperator(op_def, cfg.graph_verbosity > 0)
# Return the outputs.
return outputs if len(outputs) > 1 else outputs[0]
......@@ -60,8 +60,7 @@ class _ConvNd(function.Function):
def forward(self, input, weight, bias=None):
inputs = [input, weight] + ([bias] if bias else [])
outputs = [self.alloc()]
return self.dispatch(inputs, outputs)
return self.dispatch(inputs, [self.alloc()])
class _Loss(function.Function):
......@@ -292,8 +291,7 @@ class GroupNorm(function.Function):
}
def forward(self, input, weight, bias):
inputs = [input, weight, bias]
return self.dispatch(inputs, [self.alloc()])
return self.dispatch([input, weight, bias], [self.alloc()])
class L1Loss(_Loss):
......@@ -396,19 +394,19 @@ class Pad(function.Function):
}
}
def feed(self, handle, pads):
def feed(self, ws, handle, pads):
for i, e in enumerate(pads):
self.feed_arg(
'{}/pads[{}]'
.format(handle, i),
ws,
'{}/pads[{}]'.format(handle, i),
e, 'int64'
)
def forward(self, input, pads):
return self.dispatch(
[input], [self.alloc()],
callback=lambda handle:
self.feed(handle, pads),
callback=lambda ws, handle:
self.feed(ws, handle, pads),
)
......@@ -524,25 +522,25 @@ class Resize(function.Function):
}
}
def feed(self, handle, sizes, scales):
def feed(self, ws, handle, sizes, scales):
for i in range(self.num_sizes):
self.feed_arg(
'{}/sizes[{}]'
.format(handle, i),
ws,
'{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64',
)
for i in range(self.num_scales):
self.feed_arg(
'{}/scales[{}]'
.format(handle, i),
ws,
'{}/scales[{}]'.format(handle, i),
scales[i], 'float32',
)
def forward(self, input, sizes=None, scales=None):
return self.dispatch(
[input], [self.alloc()],
callback=lambda handle:
self.feed(handle, sizes, scales)
callback=lambda ws, handle:
self.feed(ws, handle, sizes, scales)
)
......
......@@ -58,16 +58,16 @@ class Assign(function.Function):
},
}
def feed(self, handle, starts, sizes):
def feed(self, ws, handle, starts, sizes):
for i in range(self.ndim):
self.feed_arg(
'{}/starts[{}]'
.format(handle, i),
ws,
'{}/starts[{}]'.format(handle, i),
starts[i], 'int64',
)
self.feed_arg(
'{}/sizes[{}]'
.format(handle, i),
ws,
'{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64',
)
......@@ -75,8 +75,8 @@ class Assign(function.Function):
self._check_device([input, out])
return self.dispatch(
[input], [out],
callback=lambda handle:
self.feed(handle, starts, sizes),
callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes),
no_grad=True,
check_device=False,
)
......@@ -127,19 +127,19 @@ class ChannelNormalize(function.Function):
}
}
def feed(self, handle, perm):
def feed(self, ws, handle, perm):
for i in range(self.ndim):
self.feed_arg(
'{}/perm[{}]'
.format(handle, i),
ws,
'{}/perm[{}]'.format(handle, i),
perm[i], 'int64',
)
def forward(self, input, perm):
return self.dispatch(
[input], [self.alloc()],
callback=lambda handle:
self.feed(handle, perm),
callback=lambda ws, handle:
self.feed(ws, handle, perm),
)
......@@ -220,19 +220,19 @@ class Expand(function.Function):
},
}
def feed(self, handle, times):
def feed(self, ws, handle, times):
for i in range(self.ndim):
self.feed_arg(
'{}/dims[{}]'
.format(handle, i),
ws,
'{}/dims[{}]'.format(handle, i),
times[i], 'int64',
)
def forward(self, input, dims):
return self.dispatch(
[input], [self.alloc()],
callback=lambda handle:
self.feed(handle, dims),
callback=lambda ws, handle:
self.feed(ws, handle, dims),
)
......@@ -366,11 +366,11 @@ class Reshape(function.Function):
},
}
def feed(self, handle, shape):
def feed(self, ws, handle, shape):
for i in range(self.ndim):
self.feed_arg(
'{}/dims[{}]'
.format(handle, i),
ws,
'{}/dims[{}]'.format(handle, i),
shape[i], 'int64',
)
......@@ -378,8 +378,8 @@ class Reshape(function.Function):
out = out if out else self.alloc()
return self.dispatch(
[input], [out],
callback=lambda handle:
self.feed(handle, shape),
callback=lambda ws, handle:
self.feed(ws, handle, shape),
)
......@@ -403,24 +403,24 @@ class Slice(function.Function):
},
}
def feed(self, handle, starts, sizes):
def feed(self, ws, handle, starts, sizes):
for i in range(self.ndim):
self.feed_arg(
'{}/starts[{}]'
.format(handle, i),
ws,
'{}/starts[{}]'.format(handle, i),
starts[i], 'int64',
)
self.feed_arg(
'{}/sizes[{}]'
.format(handle, i),
ws,
'{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64',
)
def forward(self, input, starts, sizes):
return self.dispatch(
[input], [self.alloc()],
callback=lambda handle:
self.feed(handle, starts, sizes)
callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes)
)
......@@ -496,19 +496,19 @@ class Tile(function.Function):
},
}
def feed(self, handle, times):
def feed(self, ws, handle, times):
for i in range(self.ndim):
self.feed_arg(
'{}/multiples[{}]'
.format(handle, i),
ws,
'{}/multiples[{}]'.format(handle, i),
times[i], 'int64',
)
def forward(self, input, times):
return self.dispatch(
[input], [self.alloc()],
callback=lambda handle:
self.feed(handle, times),
callback=lambda ws, handle:
self.feed(ws, handle, times),
)
......@@ -528,19 +528,19 @@ class Transpose(function.Function):
},
}
def feed(self, handle, perm):
def feed(self, ws, handle, perm):
for i in range(self.ndim):
self.feed_arg(
'{}/perm[{}]'
.format(handle, i),
ws,
'{}/perm[{}]'.format(handle, i),
perm[i], 'int64',
)
def forward(self, input, perm):
return self.dispatch(
[input], [self.alloc()],
callback=lambda handle:
self.feed(handle, perm),
callback=lambda ws, handle:
self.feed(ws, handle, perm),
)
......
......@@ -22,19 +22,19 @@ class _Initializer(function.Function):
self.ndim = kwargs.get('ndim', 0)
self.dtype = kwargs.get('dtype', 'float32')
def feed(self, handle, shape):
def feed(self, ws, handle, shape):
for i in range(self.ndim):
self.feed_arg(
'{}/dims[{}]'
.format(handle, i),
ws,
'{}/dims[{}]'.format(handle, i),
shape[i], 'int64',
)
def forward(self, out, shape, shape_like=None):
return self.dispatch(
[] if shape_like is None else [shape_like], [out],
callback=lambda handle:
self.feed(handle, shape),
callback=lambda ws, handle:
self.feed(ws, handle, shape),
)
......@@ -56,19 +56,19 @@ class Arange(function.Function):
}
}
def feed(self, handle, slice_args):
def feed(self, ws, handle, slice_args):
for i in range(len(slice_args)):
self.feed_arg(
'{}/slice[{}]'
.format(handle, i),
ws,
'{}/slice[{}]'.format(handle, i),
slice_args[i], 'float32'
)
def forward(self, slice_args):
return self.dispatch(
[], [self.alloc()],
callback=lambda handle:
self.feed(handle, slice_args)
callback=lambda ws, handle:
self.feed(ws, handle, slice_args)
)
......