Commit 94863c22 by Ting PAN

Fix the disorder while compiling ops

1 parent f4e789be
@@ -17,28 +17,30 @@ class RepeatOp : public Operator<Context> {
     RepeatOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", -1)),
-          repeats(OperatorBase::GetSingleArg<int>("repeats", 1)) {}
+          repeats_desc(OperatorBase::GetSingleArg<string>("repeats", "")) {}

     void RunOnDevice() override;
     template<typename T> void RunWithType();

 protected:
-    TIndex axis, repeats, outer_dim, dim, inner_dim;
+    TIndex axis, outer_dim, dim, inner_dim, reps;
+    string repeats_desc;
 };

 template <class Context>
 class RepeatGradientOp : public Operator<Context> {
 public:
     RepeatGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", -1)),
-          repeats(OperatorBase::GetSingleArg<int>("repeats", 1)) {}
+          repeats_desc(OperatorBase::GetSingleArg<string>("repeats", "")) {}

     void RunOnDevice() override;
     template<typename T> void RunWithType();

 protected:
-    TIndex axis, repeats, outer_dim, dim, inner_dim;
+    TIndex axis, outer_dim, dim, inner_dim, reps;
+    string repeats_desc;
 };

 }    // namespace dragon
...
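Note: the hunk above replaces the compile-time int argument `repeats` with a string `repeats_desc` naming an int32 tensor in the workspace, so the repeat count is resolved at run time instead of being baked into the OperatorDef. A minimal Python sketch of this "desc" pattern (the class and tensor names here are illustrative, not Dragon's actual API):

# The workspace maps tensor names to values; the op stores only the *name*.
workspace = {}

class RepeatOpSketch(object):
    def __init__(self, repeats_desc):
        # like repeats_desc(OperatorBase::GetSingleArg<string>("repeats", ""))
        self.repeats_desc = repeats_desc

    def run_on_device(self):
        # like ws()->GetTensor(repeats_desc): fetched on every run
        reps = workspace[self.repeats_desc]
        assert isinstance(reps, int), 'the type of repeats should be int32'
        return reps

workspace['repeats/value'] = 3
op = RepeatOpSketch('repeats/value')
print(op.run_on_device())       # 3
workspace['repeats/value'] = 5  # re-feed without rebuilding the graph
print(op.run_on_device())       # 5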
@@ -33,20 +33,21 @@ class L2NormOp final : public Operator<Context> {
     TIndex outer_dim, dim, inner_dim, spatial_dim;
 };

 template <class Context>
 class L2NormGradientOp final : public Operator<Context> {
 public:
     L2NormGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 0)),
-          num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)) {}
+          num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
+          mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}

     void RunOnDevice() override;
     template <typename T> void RunWithType();

 protected:
     TIndex axis, num_axes, end_axis;
+    string mode;
     bool across_inner;
     Tensor* norm, *multiplier, *buffer, *buffer_inner;
     TIndex outer_dim, dim, inner_dim;
...
@@ -279,7 +279,6 @@ def Reduce(inputs, axis=-1, operation='NONE', keep_dims=False, **kwargs):
                 output.shape[i] = 1
         else: output.shape = [1]
     else:
         if keep_dims: output.shape[axis] = 1
         else: del output.shape[axis]
@@ -445,7 +444,7 @@ def Repeat(inputs, axis=-1, repeats=1, **kwargs):
         The input tensor.
     axis : int
         The axis to repeat. Defaults is ``-1`` (Repeat as Scalar).
-    repeats : int
+    repeats : int or Tensor
         The magnitude of repeating.

     Returns
@@ -456,12 +455,17 @@ def Repeat(inputs, axis=-1, repeats=1, **kwargs):
     """
     CheckInputs(inputs, 1)
     arguments = ParseArguments(locals())
+    arguments['extra_inputs'] = [Tensor.Convert(repeats, dtype='int32')]
+    arguments['repeats'] = arguments['extra_inputs'][0].name

     output = Tensor.CreateOperator(nout=1, op_type='Repeat', **arguments)

-    if inputs.shape is not None:
+    if inputs.shape is not None and \
+            not isinstance(repeats, Tensor):
         if axis == -1:
-            total_count = np.prod(inputs.shape)
+            fake_shape = inputs.shape[:]
+            fake_shape = [1 if dim is None else dim for dim in fake_shape]
+            total_count = np.prod(fake_shape)
             output.shape = [total_count * repeats]
         else:
             output.shape = inputs.shape[:]
...
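Note: the Python-side shape inference above now runs only when `repeats` is a static int; with a Tensor the count is unknown until run time. A standalone numpy sketch of that bookkeeping (the helper name is made up; the axis branch mirrors the C++ Reshape logic shown further down):

import numpy as np

def infer_repeat_shape(in_shape, axis, repeats):
    if axis == -1:
        # unknown (None) dims are treated as 1, as the fake_shape trick does
        fake_shape = [1 if dim is None else dim for dim in in_shape]
        return [int(np.prod(fake_shape)) * repeats]
    out = list(in_shape)
    out[axis] = out[axis] * repeats
    return out

print(infer_repeat_shape([2, None, 3], axis=-1, repeats=2))  # [12]
print(infer_repeat_shape([2, 3], axis=1, repeats=4))         # [2, 12]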
@@ -552,7 +552,6 @@ class Net(object):
         """
         return list(self._net_outputs)

     def replace(self, A, B):
         """Replace the A as B.
...
@@ -262,7 +262,6 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
             all_exprs = dict(all_exprs, **output.expressions)
             all_extra_targets = all_extra_targets.union(output.extra_targets)
             if len(output.grad_wrts) > 0: existing_grads = True
-    for extra_target in all_extra_targets: meta_graph.target.extend([extra_target])

     # we should sort out the topology of these operators before using
     all_exprs = sorted(all_exprs.items(), key=lambda d: d[0])
@@ -280,9 +279,10 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
                 external_input_exprs = OrderedDict(external_input_exprs, **new_tensor.expressions)
             else:
                 external_input_exprs = dict(external_input_exprs, **new_tensor.expressions)
-            external_input_exprs = OrderedDict(sorted(external_input_exprs.items(), lambda x, y: cmp(x[1], y[1])))
+            external_input_exprs = OrderedDict(sorted(external_input_exprs.items(), key=lambda A: A[0]))
         elif isinstance(new_tensor, np.ndarray):
             ws.FeedTensor(new_tensor, GetTensorName())
+        all_extra_targets = all_extra_targets.union(new_tensor.extra_targets)
     external_input_ops = [v for k, v in external_input_exprs.items()]
     for op in forward_ops:
         op.input.extend([name_dict[input] if input in name_dict
@@ -298,8 +298,15 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
         forward_ops, grad_ops = GraphGradientMaker.Make(forward_ops, targets)
     else:
         grad_ops = []
+
+    # Write Ops
     meta_graph.op.extend(forward_ops + grad_ops)
+
+    # Write Extra Targets
+    for extra_target in all_extra_targets:
+        meta_graph.target.extend([extra_target])
+
+    # Write Misc
     if len(outputs) > 0:
         GraphDef_Device(meta_graph)
         GraphDef_Opt(meta_graph)
...
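Note: the two hunks above are the actual fix for the compile-order bug named in the commit title. Expression dicts are now sorted by their keys with a deterministic key= function instead of Python 2's cmp over protobuf values (which have no meaningful ordering), and extra targets are written into the graph only after all ops. A small sketch of why sorting by key restores the intended topology, assuming keys are creation-ordered expression indices:

from collections import OrderedDict

# values stand in for OperatorDef protobufs, which do not compare meaningfully
exprs = {3: 'op_c', 1: 'op_a', 2: 'op_b'}

ordered = OrderedDict(sorted(exprs.items(), key=lambda A: A[0]))
print(list(ordered.values()))  # ['op_a', 'op_b', 'op_c'] -- deterministic order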
@@ -36,7 +36,7 @@ find_packages('dragon')
 find_modules()

 setup(name = 'dragon',
-      version='0.2.1.6',
+      version='0.2.1.7',
       description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
       url='https://github.com/neopenx/Dragon',
       author='Ting Pan',
...
@@ -12,7 +12,7 @@ void RepeatOp<Context>::RunWithType() {
                                outer_dim,
                                dim,
                                inner_dim,
-                               repeats,
+                               reps,
                                Xdata,
                                Ydata,
                                &ctx());
@@ -20,16 +20,20 @@ void RepeatOp<Context>::RunWithType() {
 template <class Context>
 void RepeatOp<Context>::RunOnDevice() {
+    // parse repeats from desc
+    Tensor* repeats = ws()->GetTensor(repeats_desc);
+    CHECK(repeats->IsType<int>()) << "\nThe type of repeats should be int32.";
+    reps = repeats->template data<int, CPUContext>()[0];
     if (axis == -1) {
         outer_dim = inner_dim = 1;
         dim = input(0).count();
-        output(0)->Reshape(vector<TIndex>(1, dim * repeats));
+        output(0)->Reshape(vector<TIndex>(1, dim * reps));
     } else {
         outer_dim = input(0).count(0, axis);
         dim = input(0).dim(axis);
         inner_dim = input(0).count(axis + 1);
         vector<TIndex> dims = input(0).dims();
-        dims[axis] *= repeats;
+        dims[axis] *= reps;
         output(0)->Reshape(dims);
     }
@@ -51,7 +55,7 @@ void RepeatGradientOp<Context>::RunWithType() {
                                outer_dim,
                                dim,
                                inner_dim,
-                               repeats,
+                               reps,
                                dYdata,
                                dXdata,
                                &ctx());
@@ -59,6 +63,10 @@ void RepeatGradientOp<Context>::RunWithType() {
 template <class Context>
 void RepeatGradientOp<Context>::RunOnDevice() {
+    // parse repeats from desc
+    Tensor* repeats = ws()->GetTensor(repeats_desc);
+    CHECK(repeats->IsType<int>()) << "\nThe type of repeats should be int32.";
+    reps = repeats->template data<int, CPUContext>()[0];
     if (axis == -1) {
         outer_dim = inner_dim = 1;
         dim = input(0).count();
...
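Note: both the forward and gradient ops now read `reps` from the workspace before computing the (outer_dim, dim, inner_dim) decomposition. A numpy reference of what the Repeat kernel computes under that decomposition (a sketch assuming the kernel repeats each slice of the middle axis `reps` times, which is what the shape arithmetic implies; not the actual kernel code):

import numpy as np

def repeat_ref(x, axis, reps):
    if axis == -1:
        outer_dim, dim, inner_dim = 1, x.size, 1
    else:
        outer_dim = int(np.prod(x.shape[:axis], dtype=np.int64))
        dim = x.shape[axis]
        inner_dim = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
    v = x.reshape(outer_dim, dim, inner_dim)
    y = np.repeat(v, reps, axis=1)   # repeat each middle-axis slice reps times
    if axis == -1:
        return y.reshape(-1)         # flat output of length dim * reps
    return y.reshape(x.shape[:axis] + (dim * reps,) + x.shape[axis + 1:])

x = np.arange(6).reshape(2, 3)
print(np.array_equal(repeat_ref(x, 1, 2), np.repeat(x, 2, axis=1)))  # True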
@@ -116,6 +116,7 @@ void L2NormGradientOp<Context>::RunWithType() {
         if (across_inner) {
             Ndata = norm->template data<T, CPUContext>();
             T sum_of_x_mul_dy = math::Dot<T, Context>(buffer->count(), Xdata, dYdata);
+            if (mode == "MEAN") sum_of_x_mul_dy = sum_of_x_mul_dy / dim;
             math::Scale<T, Context>(buffer->count(), sum_of_x_mul_dy / Ndata[n] / Ndata[n], Xdata, dXdata);
             math::Sub<T, Context>(buffer->count(), dYdata, dXdata, dXdata);
             math::Scal<T, Context>(buffer->count(), T(1.0 / Ndata[n]), dXdata);
@@ -123,7 +124,7 @@ void L2NormGradientOp<Context>::RunWithType() {
             // compute \sum_{i} x_{i, j}dy_{i, j}
             math::Mul<T, Context>(buffer->count(), Xdata, dYdata, Bdata);
             math::Gemv<T, Context>(CblasTrans, dim, inner_dim,
-                                   1.0,
+                                   mode == "MEAN" ? 1.0 / dim : 1.0,
                                    Bdata, DMuldata,
                                    0.0,
                                    BInnerdata);
...
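Note: the MEAN branch divides the x-dot-dy term (and the Gemv accumulation in the across-spatial path) by `dim`. That is the extra factor the chain rule produces if the forward pass of this mode normalizes by sqrt(mean(x^2)) rather than sqrt(sum(x^2)); the forward change is not part of this hunk, so the sketch below treats that as an assumption. A numpy finite-difference check of the across_inner branch:

import numpy as np

def fwd(x, mode):
    # assumed forward: "MEAN" normalizes by sqrt(mean(x^2)), "SUM" by sqrt(sum(x^2))
    n2 = (x * x).sum() / (x.size if mode == 'MEAN' else 1.0)
    return x / np.sqrt(n2)

def l2norm_grad(x, dy, mode='SUM'):
    n = np.sqrt((x * x).sum() / (x.size if mode == 'MEAN' else 1.0))
    s = x.dot(dy)
    if mode == 'MEAN': s = s / x.size      # the scaling added in this commit
    # dX = (dY - X * s / N^2) / N, as the Scale/Sub/Scal sequence computes
    return (dy - x * s / (n * n)) / n

rng = np.random.RandomState(0)
x, dy = rng.randn(8), rng.randn(8)
eps = 1e-6
num = np.array([(fwd(x + eps * np.eye(8)[j], 'MEAN').dot(dy) -
                 fwd(x - eps * np.eye(8)[j], 'MEAN').dot(dy)) / (2 * eps)
                for j in range(8)])
print(np.allclose(num, l2norm_grad(x, dy, 'MEAN'), atol=1e-5))  # True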