Commit 8dbb73a7 by Ting PAN

Fix the device inference in eager execution

Summary:
This commit computes the correct device for an existing output tensor.
1 parent c0c43218
Showing with 291 additions and 290 deletions
...@@ -6,32 +6,35 @@ regularizers ...@@ -6,32 +6,35 @@ regularizers
Classes Classes
------- -------
`class L1 <regularizers/L1.html>`_
: The L1 regularizer.
`class L1L2 <regularizers/L1L2.html>`_ `class L1L2 <regularizers/L1L2.html>`_
: The L1L2 regularizer. : The L1L2 regularizer.
`class L2 <regularizers/L1.html>`_
: The L1 regularizer.
`class Regularizer <regularizers/Regularizer.html>`_ `class Regularizer <regularizers/Regularizer.html>`_
: The base regularizer class. : The base regularizer class.
Functions Functions
--------- ---------
`l1(...) <regularizers/l1.html>`_ `get(...) <regularizers/get.html>`_
: Create a L1 regularizer. : Return the regularizer callable by identifier.
`l1_l2(...) <regularizers/l1_l2.html>`_ `l1_l2(...) <regularizers/l1_l2.html>`_
: Create a L1L2 regularizer. : Create a L1L2 regularizer.
`l2(...) <regularizers/l2.html>`_
: Create a L2 regularizer.
.. toctree:: .. toctree::
:hidden: :hidden:
regularizers/l1 regularizers/get
regularizers/L1
regularizers/L1L2 regularizers/L1L2
regularizers/l1_l2 regularizers/l1_l2
regularizers/l2 regularizers/L2
regularizers/Regularizer regularizers/Regularizer
.. raw:: html .. raw:: html
......
L1
==
.. autoclass:: dragon.vm.tensorflow.keras.regularizers.L1
__init__
--------
.. automethod:: dragon.vm.tensorflow.keras.regularizers.L1.__init__
.. raw:: html
<style>
h1:before {
content: "tf.keras.regularizers.";
color: #103d3e;
}
</style>
l2 L2
== ==
.. autofunction:: dragon.vm.tensorflow.keras.regularizers.l2 .. autoclass:: dragon.vm.tensorflow.keras.regularizers.L2
__init__
--------
.. automethod:: dragon.vm.tensorflow.keras.regularizers.L2.__init__
.. raw:: html .. raw:: html
......
l1 get
== ===
.. autofunction:: dragon.vm.tensorflow.keras.regularizers.l1 .. autofunction:: dragon.vm.tensorflow.keras.regularizers.get
.. raw:: html .. raw:: html
......
...@@ -35,8 +35,9 @@ void InitializeOp<Context>::RunOnDevice() { ...@@ -35,8 +35,9 @@ void InitializeOp<Context>::RunOnDevice() {
vec64_t out_shape; vec64_t out_shape;
int ndims; int ndims;
dims(0, &ndims); dims(0, &ndims);
for (int i = 0; i < ndims; i++) for (int i = 0; i < ndims; i++) {
out_shape.push_back(dims(i)); out_shape.push_back(dims(i));
}
Output(0)->Reshape(out_shape); Output(0)->Reshape(out_shape);
} }
} }
......
...@@ -4,12 +4,7 @@ namespace dragon { ...@@ -4,12 +4,7 @@ namespace dragon {
template <class Context> template <class Context>
void ShapeOp<Context>::RunOnDevice() { void ShapeOp<Context>::RunOnDevice() {
Output(0)->Reshape({Input(0).ndim()}); Output(0)->template CopyFrom<int64_t>(Input(0).dims());
auto* y = Output(0)->template mutable_data<int64_t, CPUContext>();
for (int i = 0; i < Input(0).ndim(); i++)
y[i] = Input(0).dim(i);
} }
DEPLOY_CPU(Shape); DEPLOY_CPU(Shape);
......
...@@ -23,6 +23,8 @@ class ShapeOp final : public Operator<Context> { ...@@ -23,6 +23,8 @@ class ShapeOp final : public Operator<Context> {
SIMPLE_CTOR_DTOR(ShapeOp); SIMPLE_CTOR_DTOR(ShapeOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void SwitchToDevice() override {}
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -39,7 +39,7 @@ def set_execution(execution='GRAPH_MODE'): ...@@ -39,7 +39,7 @@ def set_execution(execution='GRAPH_MODE'):
""" """
if execution not in ('GRAPH_MODE', 'EAGER_MODE'): if execution not in ('GRAPH_MODE', 'EAGER_MODE'):
raise ValueError('Unsupported execution mode:', execution) raise ValueError('Unsupported execution: ' + execution)
config.config().graph_execution = execution config.config().graph_execution = execution
...@@ -75,7 +75,7 @@ def set_scheduler(scheduler='SIMPLE'): ...@@ -75,7 +75,7 @@ def set_scheduler(scheduler='SIMPLE'):
""" """
if scheduler not in ('SIMPLE', 'FUSION'): if scheduler not in ('SIMPLE', 'FUSION'):
raise ValueError('Unsupported scheduler type:', scheduler) raise ValueError('Unsupported scheduler: ' + scheduler)
if scheduler == 'SIMPLE': if scheduler == 'SIMPLE':
config.config().graph_type = '' config.config().graph_type = ''
elif scheduler == 'FUSION': elif scheduler == 'FUSION':
......
...@@ -40,7 +40,7 @@ class Backend(object): ...@@ -40,7 +40,7 @@ class Backend(object):
if not is_nccl_available(): if not is_nccl_available():
raise ValueError('NCCL backend is not available.') raise ValueError('NCCL backend is not available.')
elif value == Backend.UNDEFINED: elif value == Backend.UNDEFINED:
raise ValueError('Invalid backend:', name) raise ValueError('Invalid backend: ' + name)
return value return value
......
...@@ -44,7 +44,7 @@ def device(device_type, device_index=0): ...@@ -44,7 +44,7 @@ def device(device_type, device_index=0):
""" """
device_type = device_type.lower() device_type = device_type.lower()
if device_type not in mapping.DEVICE_STRING_TO_DEVICE_TYPE: if device_type not in mapping.DEVICE_STRING_TO_DEVICE_TYPE:
raise ValueError('Unsupported device type:', device_type) raise ValueError('Unsupported device type: ' + device_type)
return _GLOBAL_DEVICE_STACK.get_controller({ return _GLOBAL_DEVICE_STACK.get_controller({
'device_type': mapping.DEVICE_STRING_TO_DEVICE_TYPE[device_type], 'device_type': mapping.DEVICE_STRING_TO_DEVICE_TYPE[device_type],
'device_index': device_index, 'device_index': device_index,
......
...@@ -37,8 +37,11 @@ class Operator(object): ...@@ -37,8 +37,11 @@ class Operator(object):
self._arg_device = self._arg_device.SerializeToString() self._arg_device = self._arg_device.SerializeToString()
self._seed = kwargs.get('seed', config.config().random_seed) self._seed = kwargs.get('seed', config.config().random_seed)
def alloc(self): def alloc(self, out=None):
"""Return the executing device to create an output tensor.""" """Return or bind the executing device to output tensor."""
if out is not None:
out._device = self._device.copy()
return out
return self._device.copy() return self._device.copy()
def apply(self, *args, **kwargs): def apply(self, *args, **kwargs):
...@@ -124,11 +127,6 @@ class Operator(object): ...@@ -124,11 +127,6 @@ class Operator(object):
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
def new_leaf(shape, dtype, device, trainable=False):
"""Return a leaf resource."""
return EagerTensor(**locals())
def remove_binary_scalar(inputs): def remove_binary_scalar(inputs):
"""Remove the scalar for binary ops.""" """Remove the scalar for binary ops."""
if types.is_tensor(inputs[0]): if types.is_tensor(inputs[0]):
......
...@@ -211,7 +211,7 @@ class Workspace(backend.Workspace): ...@@ -211,7 +211,7 @@ class Workspace(backend.Workspace):
dtype = value.dtype if dtype is None else dtype dtype = value.dtype if dtype is None else dtype
if hasattr(tensor, 'dtype') and tensor.dtype is not None: if hasattr(tensor, 'dtype') and tensor.dtype is not None:
if tensor.dtype not in mapping.TENSOR_TYPE_TO_NP_TYPE: if tensor.dtype not in mapping.TENSOR_TYPE_TO_NP_TYPE:
raise TypeError('Unsupported data type:', tensor.dtype) raise TypeError('Unsupported data type: ' + tensor.dtype)
dtype = mapping.TENSOR_TYPE_TO_NP_TYPE[tensor.dtype] dtype = mapping.TENSOR_TYPE_TO_NP_TYPE[tensor.dtype]
# Determine the copying device option # Determine the copying device option
if enforce_cpu is True: if enforce_cpu is True:
......
...@@ -61,10 +61,8 @@ def dropout(inputs, prob=0.5, scale=True, **kwargs): ...@@ -61,10 +61,8 @@ def dropout(inputs, prob=0.5, scale=True, **kwargs):
op_lib = activation_ops_lib.Dropout op_lib = activation_ops_lib.Dropout
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(prob=args['prob'], scale=scale) \
prob=args['prob'], .apply([inputs], inplace=inplace)
scale=scale,
).apply([inputs], inplace=inplace)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
......
...@@ -26,7 +26,7 @@ class Activation(Operator): ...@@ -26,7 +26,7 @@ class Activation(Operator):
return {'op_type': self.op_type, 'arguments': {}} return {'op_type': self.op_type, 'arguments': {}}
def forward(self, inputs, inplace=False): def forward(self, inputs, inplace=False):
outputs = [inputs[0] if inplace else self.alloc()] outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs) return self.dispatch(inputs, outputs)
...@@ -46,7 +46,7 @@ class Dropout(Activation): ...@@ -46,7 +46,7 @@ class Dropout(Activation):
} }
class DropBlock2d(Operator): class DropBlock2d(Activation):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(DropBlock2d, self).__init__(key, dev, **kwargs) super(DropBlock2d, self).__init__(key, dev, **kwargs)
self.block_size = kwargs.get('block_size', 7) self.block_size = kwargs.get('block_size', 7)
...@@ -67,10 +67,6 @@ class DropBlock2d(Operator): ...@@ -67,10 +67,6 @@ class DropBlock2d(Operator):
}, },
} }
def forward(self, inputs, inplace=False):
outputs = [inputs[0] if inplace else self.alloc()]
return self.dispatch(inputs, outputs)
class DropPath(Activation): class DropPath(Activation):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
...@@ -96,9 +92,7 @@ class Elu(Activation): ...@@ -96,9 +92,7 @@ class Elu(Activation):
def attributes(self): def attributes(self):
return { return {
'op_type': 'Elu', 'op_type': 'Elu',
'arguments': { 'arguments': {'alpha': float(self.alpha)},
'alpha': float(self.alpha),
}
} }
...@@ -110,9 +104,7 @@ class PRelu(Operator): ...@@ -110,9 +104,7 @@ class PRelu(Operator):
def attributes(self): def attributes(self):
return { return {
'op_type': 'PRelu', 'op_type': 'PRelu',
'arguments': { 'arguments': {'data_format': self.data_format},
'data_format': self.data_format,
}
} }
def forward(self, inputs): def forward(self, inputs):
...@@ -127,9 +119,7 @@ class Relu(Activation): ...@@ -127,9 +119,7 @@ class Relu(Activation):
def attributes(self): def attributes(self):
return { return {
'op_type': 'Relu', 'op_type': 'Relu',
'arguments': { 'arguments': {'alpha': float(self.alpha)},
'alpha': float(self.alpha),
}
} }
...@@ -140,9 +130,7 @@ class Relu6(Activation): ...@@ -140,9 +130,7 @@ class Relu6(Activation):
def attributes(self): def attributes(self):
return { return {
'op_type': 'Relu', 'op_type': 'Relu',
'arguments': { 'arguments': {'max_value': 6.},
'max_value': 6.,
}
} }
...@@ -170,7 +158,5 @@ class Softmax(Activation): ...@@ -170,7 +158,5 @@ class Softmax(Activation):
def attributes(self): def attributes(self):
return { return {
'op_type': 'Softmax', 'op_type': 'Softmax',
'arguments': { 'arguments': {'axis': self.axis},
'axis': self.axis,
}
} }
...@@ -14,6 +14,7 @@ from __future__ import absolute_import ...@@ -14,6 +14,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.framework import device_spec
from dragon.core.framework.ops import Operator from dragon.core.framework.ops import Operator
...@@ -41,13 +42,14 @@ class Arange(Operator): ...@@ -41,13 +42,14 @@ class Arange(Operator):
slice_args[i], 'float32') slice_args[i], 'float32')
def forward(self, slice_args, trainable=False): def forward(self, slice_args, trainable=False):
output = self.dispatch( out = self.dispatch(
[], [self.alloc()], [], [self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, slice_args) self.feed(ws, handle, slice_args),
no_grad=True,
) )
output._requires_grad = trainable out._requires_grad = trainable
return output return out
class ArgReduce(Operator): class ArgReduce(Operator):
...@@ -85,7 +87,7 @@ class Cast(Operator): ...@@ -85,7 +87,7 @@ class Cast(Operator):
if inputs[0].dtype == self.dtype: if inputs[0].dtype == self.dtype:
return inputs[0] return inputs[0]
if inplace: if inplace:
return self.dispatch([], inputs, no_grad=True) return self.dispatch([], [self.alloc(inputs[0])], no_grad=True)
return self.dispatch(inputs, [self.alloc()]) return self.dispatch(inputs, [self.alloc()])
...@@ -122,7 +124,7 @@ class ChannelNormalize(Operator): ...@@ -122,7 +124,7 @@ class ChannelNormalize(Operator):
return self.dispatch( return self.dispatch(
inputs, [self.alloc()], inputs, [self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, perm) self.feed(ws, handle, perm),
) )
...@@ -225,7 +227,7 @@ class ExpandDims(Operator): ...@@ -225,7 +227,7 @@ class ExpandDims(Operator):
} }
def forward(self, inputs, inplace=False): def forward(self, inputs, inplace=False):
outputs = [inputs[0] if inplace else self.alloc()] outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs) return self.dispatch(inputs, outputs)
...@@ -247,7 +249,7 @@ class Flatten(Operator): ...@@ -247,7 +249,7 @@ class Flatten(Operator):
} }
def forward(self, inputs, inplace=False): def forward(self, inputs, inplace=False):
outputs = [inputs[0] if inplace else self.alloc()] outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs) return self.dispatch(inputs, outputs)
...@@ -447,11 +449,10 @@ class Reshape(Operator): ...@@ -447,11 +449,10 @@ class Reshape(Operator):
e, 'int64') e, 'int64')
def forward(self, inputs, shape, inplace=False): def forward(self, inputs, shape, inplace=False):
outputs = [inputs[0] if inplace else self.alloc()]
return self.dispatch( return self.dispatch(
inputs, outputs, inputs, [self.alloc(inputs[0]) if inplace else self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, shape) self.feed(ws, handle, shape),
) )
...@@ -493,12 +494,13 @@ class Slice(Operator): ...@@ -493,12 +494,13 @@ class Slice(Operator):
class Shape(Operator): class Shape(Operator):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Shape, self).__init__(key, dev, **kwargs) super(Shape, self).__init__(key, dev, **kwargs)
self._device = device_spec.DeviceSpec()
def attributes(self): def attributes(self):
return {'op_type': 'Shape', 'arguments': {}} return {'op_type': 'Shape', 'arguments': {}}
def forward(self, inputs): def forward(self, inputs):
return self.dispatch(inputs, [self.alloc()]) return self.dispatch(inputs, [self.alloc()], no_grad=True)
class Split(Operator): class Split(Operator):
...@@ -535,7 +537,7 @@ class Squeeze(Operator): ...@@ -535,7 +537,7 @@ class Squeeze(Operator):
} }
def forward(self, inputs, inplace=False): def forward(self, inputs, inplace=False):
outputs = [inputs[0] if inplace else self.alloc()] outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs) return self.dispatch(inputs, outputs)
......
...@@ -62,7 +62,7 @@ class Copy(Operator): ...@@ -62,7 +62,7 @@ class Copy(Operator):
def forward(self, inputs, outputs): def forward(self, inputs, outputs):
outputs = outputs if outputs else [self.alloc()] outputs = outputs if outputs else [self.alloc()]
return self.dispatch(inputs, outputs) return self.dispatch(inputs, outputs, no_grad=True)
class MaskedAssign(Operator): class MaskedAssign(Operator):
......
...@@ -46,7 +46,7 @@ def all_reduce(inputs, operation='MEAN', group=None, **kwargs): ...@@ -46,7 +46,7 @@ def all_reduce(inputs, operation='MEAN', group=None, **kwargs):
if group is None: if group is None:
raise ValueError('<group> is required.') raise ValueError('<group> is required.')
if operation not in ('MEAN', 'SUM'): if operation not in ('MEAN', 'SUM'):
raise ValueError('Unsupported reduce op:', operation) raise ValueError('Unsupported reduce op: ' + operation)
args.update(group.arguments) args.update(group.arguments)
args.pop('group') args.pop('group')
op_lib = distributed_ops_lib.Collective op_lib = distributed_ops_lib.Collective
......
...@@ -115,11 +115,8 @@ def eye(n, m=None, k=0, dtype='float32', **kwargs): ...@@ -115,11 +115,8 @@ def eye(n, m=None, k=0, dtype='float32', **kwargs):
if types.is_tensor(m): if types.is_tensor(m):
m = int(m.get_value()) m = int(m.get_value())
return op_lib \ return op_lib \
.instantiate( .instantiate(k=k, ndim=2, dtype=dtype) \
k=k, .apply([n, m], trainable=trainable)
ndim=2,
dtype=dtype,
).apply([n, m], trainable=trainable)
else: else:
args['n'] = args['m'] = None args['n'] = args['m'] = None
if types.is_tensor(n) or types.is_tensor(m): if types.is_tensor(n) or types.is_tensor(m):
...@@ -173,14 +170,8 @@ def eye_like(other, k=0, dtype='float32', **kwargs): ...@@ -173,14 +170,8 @@ def eye_like(other, k=0, dtype='float32', **kwargs):
op_lib = init_ops_lib.Eye op_lib = init_ops_lib.Eye
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(k=k, dtype=dtype) \
k=k, .apply([], other, trainable=trainable)
dtype=dtype,
).apply(
shape=[],
shape_like=other,
trainable=trainable,
)
else: else:
args.pop('other') args.pop('other')
return op_lib.blend(inputs=[other], **args) return op_lib.blend(inputs=[other], **args)
...@@ -366,14 +357,8 @@ def ones_like(other, dtype='float32', **kwargs): ...@@ -366,14 +357,8 @@ def ones_like(other, dtype='float32', **kwargs):
op_lib = init_ops_lib.Fill op_lib = init_ops_lib.Fill
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(value=1, dtype=dtype) \
value=1, .apply([], other, trainable=trainable)
dtype=dtype,
).apply(
shape=[],
shape_like=other,
trainable=trainable,
)
else: else:
args.pop('other') args.pop('other')
return op_lib.blend(inputs=[other], value=1., **args) return op_lib.blend(inputs=[other], value=1., **args)
...@@ -453,11 +438,7 @@ def random_normal_like(other, mean=0, std=1, dtype='float32', **kwargs): ...@@ -453,11 +438,7 @@ def random_normal_like(other, mean=0, std=1, dtype='float32', **kwargs):
mean=args['mean'], mean=args['mean'],
std=args['std'], std=args['std'],
dtype=dtype, dtype=dtype,
).apply( ).apply([], other, trainable=trainable)
shape=[],
shape_like=other,
trainable=trainable,
)
else: else:
args.pop('other') args.pop('other')
return op_lib.blend(inputs=[other], **args) return op_lib.blend(inputs=[other], **args)
...@@ -535,11 +516,7 @@ def random_uniform_like(other, low=-1, high=1, dtype='float32', **kwargs): ...@@ -535,11 +516,7 @@ def random_uniform_like(other, low=-1, high=1, dtype='float32', **kwargs):
low=args['low'], low=args['low'],
high=args['high'], high=args['high'],
dtype=dtype, dtype=dtype,
).apply( ).apply([], other, trainable=trainable)
shape=[],
shape_like=other,
trainable=trainable,
)
else: else:
args.pop('other') args.pop('other')
return op_lib.blend(inputs=[other], **args) return op_lib.blend(inputs=[other], **args)
...@@ -641,14 +618,8 @@ def zeros_like(other, dtype='float32', **kwargs): ...@@ -641,14 +618,8 @@ def zeros_like(other, dtype='float32', **kwargs):
op_lib = init_ops_lib.Fill op_lib = init_ops_lib.Fill
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(value=0, dtype=dtype) \
value=0, .apply([], other, trainable=trainable)
dtype=dtype,
).apply(
shape=[],
shape_like=other,
trainable=trainable,
)
else: else:
args.pop('other') args.pop('other')
return op_lib.blend(inputs=[other], value=0., **args) return op_lib.blend(inputs=[other], value=0., **args)
...@@ -14,7 +14,6 @@ from __future__ import absolute_import ...@@ -14,7 +14,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.framework import ops
from dragon.core.framework.ops import Operator from dragon.core.framework.ops import Operator
...@@ -30,25 +29,17 @@ class Initializer(Operator): ...@@ -30,25 +29,17 @@ class Initializer(Operator):
ws, '{}/dims[{}]'.format(handle, i), ws, '{}/dims[{}]'.format(handle, i),
dim, 'int64') dim, 'int64')
def forward( def forward(self, shape, shape_as=None, out=None, trainable=None):
self, out = self.dispatch(
shape, [] if shape_as is None else [shape_as],
out=None, [self.alloc(out)],
shape_like=None,
trainable=False,
):
inputs = [] if shape_like is None else [shape_like]
outputs = [ops.new_leaf(
shape=shape,
dtype=self.dtype,
device=self.alloc(),
trainable=trainable,
) if out is None else out]
return self.dispatch(
inputs, outputs,
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, shape), self.feed(ws, handle, shape),
no_grad=True,
) )
if trainable is not None:
out._requires_grad = trainable
return out
class Eye(Initializer): class Eye(Initializer):
......
...@@ -53,7 +53,8 @@ class Axpby(Operator): ...@@ -53,7 +53,8 @@ class Axpby(Operator):
def forward(self, inputs, outputs=None): def forward(self, inputs, outputs=None):
if outputs is None: if outputs is None:
outputs = [self.alloc() for _ in range(len(inputs))] outputs = [None] * len(inputs)
outputs = [self.alloc(out) for out in outputs]
return self.dispatch(inputs, outputs, no_grad=True) return self.dispatch(inputs, outputs, no_grad=True)
...@@ -65,12 +66,8 @@ class BinaryOp(Operator): ...@@ -65,12 +66,8 @@ class BinaryOp(Operator):
def attributes(self): def attributes(self):
return {'op_type': self.op_type, 'arguments': {}} return {'op_type': self.op_type, 'arguments': {}}
def forward(self, inputs, outputs=None): def forward(self, inputs, outputs=(None,)):
if outputs is None: return self.dispatch(inputs, [self.alloc(outputs[0])])
outputs = [self.alloc()]
else:
outputs[0]._device = self.alloc()
return self.dispatch(inputs, outputs)
class Clip(Operator): class Clip(Operator):
......
...@@ -23,7 +23,7 @@ class Metric(Operator): ...@@ -23,7 +23,7 @@ class Metric(Operator):
self.reduction = kwargs.get('reduction', 'MEAN') self.reduction = kwargs.get('reduction', 'MEAN')
def forward(self, inputs): def forward(self, inputs):
return self.dispatch(inputs, [self.alloc()]) return self.dispatch(inputs, [self.alloc()], no_grad=True)
class Accuracy(Metric): class Accuracy(Metric):
......
...@@ -576,7 +576,7 @@ def uniform(self, low=0, high=1): ...@@ -576,7 +576,7 @@ def uniform(self, low=0, high=1):
).apply(shape, out=self) ).apply(shape, out=self)
def _binary_op(a, b, op_type, outputs=None): def _binary_op(a, b, op_type, outputs=(None,)):
"""Apply the general binary operation.""" """Apply the general binary operation."""
return math_ops_lib.BinaryOp \ return math_ops_lib.BinaryOp \
.instantiate(op_type=op_type) \ .instantiate(op_type=op_type) \
......
...@@ -85,13 +85,11 @@ class BiasAdd(Operator): ...@@ -85,13 +85,11 @@ class BiasAdd(Operator):
def attributes(self): def attributes(self):
return { return {
'op_type': 'BiasAdd', 'op_type': 'BiasAdd',
'arguments': { 'arguments': {'data_format': self.data_format},
'data_format': self.data_format,
},
} }
def forward(self, inputs, inplace=False): def forward(self, inputs, inplace=False):
outputs = [inputs[0] if inplace else self.alloc()] outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs) return self.dispatch(inputs, outputs)
......
...@@ -85,6 +85,16 @@ def getfullargspec(obj): ...@@ -85,6 +85,16 @@ def getfullargspec(obj):
return _getfullargspec(target) return _getfullargspec(target)
def isclass(object):
"""Decorator-aware replacement for ``inspect.isclass``."""
return _inspect.isclass(decorator.unwrap(object)[1])
def isfunction(object):
"""Decorator-aware replacement for ``inspect.isfunction``."""
return _inspect.isfunction(decorator.unwrap(object)[1])
def ismethod(object): def ismethod(object):
"""Decorator-aware replacement for ``inspect.ismethod``.""" """Decorator-aware replacement for ``inspect.ismethod``."""
return _inspect.ismethod(decorator.unwrap(object)[1]) return _inspect.ismethod(decorator.unwrap(object)[1])
...@@ -14,12 +14,15 @@ from __future__ import division as _division ...@@ -14,12 +14,15 @@ from __future__ import division as _division
from __future__ import print_function as _print_function from __future__ import print_function as _print_function
# Classes # Classes
from dragon.vm.tensorflow.core.keras.regularizers import L1
from dragon.vm.tensorflow.core.keras.regularizers import l1
from dragon.vm.tensorflow.core.keras.regularizers import L1L2 from dragon.vm.tensorflow.core.keras.regularizers import L1L2
from dragon.vm.tensorflow.core.keras.regularizers import L2
from dragon.vm.tensorflow.core.keras.regularizers import l2
from dragon.vm.tensorflow.core.keras.regularizers import Regularizer from dragon.vm.tensorflow.core.keras.regularizers import Regularizer
# Functions # Functions
from dragon.vm.tensorflow.core.keras.regularizers import l1 from dragon.vm.tensorflow.core.keras.regularizers import get
from dragon.vm.tensorflow.core.keras.regularizers import l1_l2 from dragon.vm.tensorflow.core.keras.regularizers import l1_l2
from dragon.vm.tensorflow.core.keras.regularizers import l2
__all__ = [_s for _s in dir() if not _s.startswith('_')] __all__ = [_s for _s in dir() if not _s.startswith('_')]
...@@ -7,10 +7,6 @@ ...@@ -7,10 +7,6 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/activations.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
...@@ -18,6 +14,7 @@ from __future__ import division ...@@ -18,6 +14,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.util import six from dragon.core.util import six
from dragon.vm.tensorflow.core.keras.utils import generic_utils
from dragon.vm.tensorflow.core.ops import math_ops from dragon.vm.tensorflow.core.ops import math_ops
from dragon.vm.tensorflow.core.ops import nn from dragon.vm.tensorflow.core.ops import nn
...@@ -272,7 +269,7 @@ def tanh(x, **kwargs): ...@@ -272,7 +269,7 @@ def tanh(x, **kwargs):
def get(identifier): def get(identifier):
"""Return the activation callable by identifier. """Return the activation function by identifier.
Parameters Parameters
---------- ----------
...@@ -282,7 +279,7 @@ def get(identifier): ...@@ -282,7 +279,7 @@ def get(identifier):
Returns Returns
------- -------
callable callable
The activation callable. The activation function.
""" """
if identifier is None: if identifier is None:
...@@ -290,8 +287,9 @@ def get(identifier): ...@@ -290,8 +287,9 @@ def get(identifier):
elif callable(identifier): elif callable(identifier):
return identifier return identifier
elif isinstance(identifier, six.string_types): elif isinstance(identifier, six.string_types):
return globals()[identifier] return generic_utils.deserialize_keras_object(
identifier, globals(), 'activation')
else: else:
raise TypeError( raise TypeError(
'Could not interpret activation identifier: {}.' 'Could not interpret the activation identifier: {}.'
.format(repr(identifier))) .format(identifier))
...@@ -75,7 +75,7 @@ def Input( ...@@ -75,7 +75,7 @@ def Input(
if shape is not None: if shape is not None:
shape = (batch_size,) + tuple(shape) shape = (batch_size,) + tuple(shape)
if kwargs: if kwargs:
raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) raise ValueError('Unrecognized keyword arguments: ' + kwargs.keys())
if dtype is None: if dtype is None:
if tensor is not None: if tensor is not None:
dtype = tensor.dtype dtype = tensor.dtype
......
...@@ -18,6 +18,7 @@ from __future__ import division ...@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.util import six from dragon.core.util import six
from dragon.vm.tensorflow.core.keras.utils import generic_utils
from dragon.vm.tensorflow.core.ops.init_ops import Constant from dragon.vm.tensorflow.core.ops.init_ops import Constant
from dragon.vm.tensorflow.core.ops.init_ops import GlorotNormal from dragon.vm.tensorflow.core.ops.init_ops import GlorotNormal
from dragon.vm.tensorflow.core.ops.init_ops import GlorotUniform from dragon.vm.tensorflow.core.ops.init_ops import GlorotUniform
...@@ -60,8 +61,9 @@ def get(identifier): ...@@ -60,8 +61,9 @@ def get(identifier):
elif callable(identifier): elif callable(identifier):
return identifier return identifier
elif isinstance(identifier, six.string_types): elif isinstance(identifier, six.string_types):
return globals()[identifier] return generic_utils.deserialize_keras_object(
identifier, globals(), 'initializer')
else: else:
raise TypeError( raise TypeError(
'Could not interpret initializer identifier: {}.' 'Could not interpret the initializer identifier: {}.'
.format(repr(identifier))) .format(identifier))
...@@ -96,7 +96,7 @@ class BatchNormalization(Layer): ...@@ -96,7 +96,7 @@ class BatchNormalization(Layer):
def build(self, input_shape): def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape) input_shape = tensor_shape.TensorShape(input_shape)
if not input_shape.ndims: if not input_shape.ndims:
raise ValueError('Input has undefined rank:', input_shape) raise ValueError('Input has undefined rank: ' + input_shape)
param_shape = [input_shape.dims[self.axis]] param_shape = [input_shape.dims[self.axis]]
self.input_spec = InputSpec( self.input_spec = InputSpec(
# Each layer should adapt to the: # Each layer should adapt to the:
......
...@@ -7,10 +7,6 @@ ...@@ -7,10 +7,6 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/losses.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
...@@ -20,6 +16,7 @@ from __future__ import print_function ...@@ -20,6 +16,7 @@ from __future__ import print_function
from dragon.core.framework import context from dragon.core.framework import context
from dragon.core.ops import loss_ops from dragon.core.ops import loss_ops
from dragon.core.util import six from dragon.core.util import six
from dragon.vm.tensorflow.core.keras.utils import generic_utils
from dragon.vm.tensorflow.core.keras.utils import losses_utils from dragon.vm.tensorflow.core.keras.utils import losses_utils
...@@ -523,8 +520,9 @@ def get(identifier): ...@@ -523,8 +520,9 @@ def get(identifier):
elif callable(identifier): elif callable(identifier):
return identifier return identifier
elif isinstance(identifier, six.string_types): elif isinstance(identifier, six.string_types):
return globals()[identifier] return generic_utils.deserialize_keras_object(
identifier, globals(), 'loss')
else: else:
raise TypeError( raise TypeError(
'Could not interpret loss identifier: {}.' 'Could not interpret the loss identifier: {}.'
.format(repr(identifier))) .format(identifier))
...@@ -7,10 +7,6 @@ ...@@ -7,10 +7,6 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/optimizer_v2/optimizer.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
...@@ -51,7 +47,7 @@ class Optimizer(optimizer_v1.Optimizer): ...@@ -51,7 +47,7 @@ class Optimizer(optimizer_v1.Optimizer):
allowed_kwargs = {'scale', 'clipnorm', 'lr'} allowed_kwargs = {'scale', 'clipnorm', 'lr'}
for k in kwargs: for k in kwargs:
if k not in allowed_kwargs: if k not in allowed_kwargs:
raise TypeError('Unexpected keyword argument:', str(k)) raise TypeError('Unexpected keyword argument: ' + str(k))
if kwargs[k] < 0: if kwargs[k] < 0:
raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k])) raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))
...@@ -109,15 +105,17 @@ class Optimizer(optimizer_v1.Optimizer): ...@@ -109,15 +105,17 @@ class Optimizer(optimizer_v1.Optimizer):
for g, v in grads_and_vars: for g, v in grads_and_vars:
if g is not None: if g is not None:
decay_mult = 0. decay_mult = 0.
if hasattr(v, '__regularizer__'): regularizer = getattr(v, '_regularizer', None)
decay_mult = v.__regularizer__.l2 / self.BASE_WEIGHT_DECAY if regularizer is not None:
decay_mult = regularizer.l2 / self.BASE_WEIGHT_DECAY
self._run_update(v, g, decay_mult=decay_mult) self._run_update(v, g, decay_mult=decay_mult)
else: else:
# Store for the lazy compilation. # Store for the lazy compilation.
for g, v in grads_and_vars: for g, v in grads_and_vars:
decay_mult = 0. decay_mult = 0.
if hasattr(v, '__regularizer__'): regularizer = getattr(v, '_regularizer', None)
decay_mult = v.__regularizer__.l2 / self.BASE_WEIGHT_DECAY if regularizer is not None:
decay_mult = regularizer.l2 / self.BASE_WEIGHT_DECAY
self._add_update(v, g, decay_mult=decay_mult) self._add_update(v, g, decay_mult=decay_mult)
# Increase the iterations. # Increase the iterations.
......
...@@ -7,10 +7,6 @@ ...@@ -7,10 +7,6 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/regularizers.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Built-in regularizers.""" """Built-in regularizers."""
...@@ -18,6 +14,9 @@ from __future__ import absolute_import ...@@ -18,6 +14,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.util import six
from dragon.vm.tensorflow.core.keras.utils import generic_utils
class Regularizer(object): class Regularizer(object):
"""The base regularizer class.""" """The base regularizer class."""
...@@ -36,10 +35,33 @@ class Regularizer(object): ...@@ -36,10 +35,33 @@ class Regularizer(object):
The output tensor. The output tensor.
""" """
x.__regularizer__ = self x._regularizer = self
return x return x
class L1(Regularizer):
r"""The L1 regularizer.
The **L1** regularizer is defined as:
.. math:: loss_{reg} = loss + \alpha|w|
"""
def __init__(self, l1=0.01):
r"""Create a ``L1`` regularizer.
Parameters
----------
l1 : float, optional, default=0.01
The value to :math:`\alpha`.
"""
if l1 <= 0.:
raise ValueError('<l1> should be greater than 0.')
self.l1 = l1
class L1L2(Regularizer): class L1L2(Regularizer):
r"""The L1L2 regularizer. r"""The L1L2 regularizer.
...@@ -65,39 +87,27 @@ class L1L2(Regularizer): ...@@ -65,39 +87,27 @@ class L1L2(Regularizer):
self.l1, self.l2 = l1, l2 self.l1, self.l2 = l1, l2
def get(identifier): class L2(Regularizer):
"""Return a regularizer from the identifier.""" r"""The L2 regularizer.
if identifier is None:
return None
elif callable(identifier):
return identifier
else:
raise ValueError(
'Could not interpret regularizer identifier:',
identifier,
)
The **L2** regularizer is defined as:
# Aliases .. math:: loss_{reg} = loss + \frac{\beta}{2}|w|_{2}
def l1(l=0.01):
r"""Create a L1 regularizer.
The **L1** regularizer is defined as: """
.. math:: loss_{reg} = loss + \alpha|w| def __init__(self, l2=0.01):
r"""Create a ``L2`` regularizer.
Parameters Parameters
---------- ----------
l : float, optional, default=0.01 l1 : float, optional, default=0.01
The value to :math:`\alpha`. The value to :math:`\alpha`.
Returns
-------
dragon.vm.tensorflow.keras.regularizers.Regularizer
The regularizer.
""" """
return L1L2(l1=l) if l2 <= 0.:
raise ValueError('<l2> should be greater than 0.')
self.l2 = l2
def l1_l2(l1=0.01, l2=0.01): def l1_l2(l1=0.01, l2=0.01):
...@@ -123,22 +133,35 @@ def l1_l2(l1=0.01, l2=0.01): ...@@ -123,22 +133,35 @@ def l1_l2(l1=0.01, l2=0.01):
return L1L2(l1=l1, l2=l2) return L1L2(l1=l1, l2=l2)
def l2(l=0.01): # Aliases
r"""Create a L2 regularizer. l1 = L1
l2 = L2
The **L2** regularizer is defined as:
.. math:: loss_{reg} = loss + \frac{\beta}{2}|w|_{2} def get(identifier):
"""Return the regularizer callable by identifier.
Parameters Parameters
---------- ----------
l : float, optional, default=0.01 identifier : Union[callable, str]
The value to :math:`\beta`. The identifier.
Returns Returns
------- -------
dragon.vm.tensorflow.keras.regularizers.Regularizer callable
The regularizer. The activation callable.
""" """
return L1L2(l2=l) if identifier is None:
return None
elif callable(identifier):
return identifier
elif isinstance(identifier, six.string_types):
if identifier == 'l1_l2':
return L1L2(l1=0.01, l2=0.01)
return generic_utils.deserialize_keras_object(
identifier, globals(), 'regularizer')
else:
raise TypeError(
'Could not interpret the regularizer identifier: {}.'
.format(identifier))
...@@ -30,7 +30,7 @@ def convert_data_format(data_format, ndim): ...@@ -30,7 +30,7 @@ def convert_data_format(data_format, ndim):
elif ndim == 5: elif ndim == 5:
return 'NDHWC' return 'NDHWC'
else: else:
raise ValueError('Input rank not supported:', ndim) raise ValueError('Input rank not supported: ' + ndim)
elif data_format == 'channels_first': elif data_format == 'channels_first':
if ndim == 3: if ndim == 3:
return 'NCW' return 'NCW'
...@@ -39,9 +39,9 @@ def convert_data_format(data_format, ndim): ...@@ -39,9 +39,9 @@ def convert_data_format(data_format, ndim):
elif ndim == 5: elif ndim == 5:
return 'NCDHW' return 'NCDHW'
else: else:
raise ValueError('Input rank not supported:', ndim) raise ValueError('Input rank not supported: ' + ndim)
else: else:
raise ValueError('Invalid data_format:', data_format) raise ValueError('Invalid data_format: ' + data_format)
def deconv_output_length( def deconv_output_length(
...@@ -94,8 +94,7 @@ def normalize_padding(value): ...@@ -94,8 +94,7 @@ def normalize_padding(value):
if value not in {'valid', 'same'}: if value not in {'valid', 'same'}:
raise ValueError( raise ValueError(
'Excepted <padding> in "valid", "same".\n' 'Excepted <padding> in "valid", "same".\n'
'Received: ' + str(value) 'Received: ' + str(value))
)
return value return value
......
...@@ -19,6 +19,32 @@ from __future__ import print_function ...@@ -19,6 +19,32 @@ from __future__ import print_function
import re import re
from dragon.core.util import inspect
from dragon.core.util import six
def deserialize_keras_object(
identifier,
module_objects,
printable_module_name='object',
):
"""Deserialize the keras object."""
if isinstance(identifier, six.string_types):
object_name = identifier
obj = module_objects.get(object_name)
if obj is None:
raise ValueError(
'Unknown ' + printable_module_name + ': ' + object_name)
if inspect.isclass(obj):
return obj()
return obj
elif inspect.isfunction(identifier):
return identifier
else:
raise TypeError(
'Could not interpret the {} identifier: {}.'
.format(printable_module_name, identifier))
def to_snake_case(name): def to_snake_case(name):
"""Convert the name from camel-style to snake-style.""" """Convert the name from camel-style to snake-style."""
......
...@@ -7,10 +7,6 @@ ...@@ -7,10 +7,6 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/utils/losses_utils.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import
......
...@@ -249,7 +249,7 @@ class VarianceScaling(Initializer): ...@@ -249,7 +249,7 @@ class VarianceScaling(Initializer):
raise ValueError('<scale> must be positive float.') raise ValueError('<scale> must be positive float.')
mode = mode.lower() mode = mode.lower()
if mode not in {'fan_in', 'fan_out', 'fan_avg'}: if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
raise ValueError('Invalid <mode> argument:', mode) raise ValueError('Invalid <mode> argument: ' + mode)
distribution = distribution.lower() distribution = distribution.lower()
if distribution not in {'normal', 'uniform'}: if distribution not in {'normal', 'uniform'}:
raise ValueError("Invalid `distribution` argument:", distribution) raise ValueError("Invalid `distribution` argument:", distribution)
...@@ -434,10 +434,12 @@ def glorot_normal_initializer(dtype='float32'): ...@@ -434,10 +434,12 @@ def glorot_normal_initializer(dtype='float32'):
# Aliases # Aliases
zeros_initializer = Zeros zeros_initializer = zero = zeros = Zeros
ones_initializer = Ones ones_initializer = one = ones = Ones
constant_initializer = Constant constant_initializer = constant = Constant
random_uniform_initializer = RandomUniform random_uniform_initializer = uniform = random_uniform = RandomUniform
random_normal_initializer = RandomNormal random_normal_initializer = normal = random_normal = RandomNormal
truncated_normal_initializer = TruncatedNormal truncated_normal_initializer = truncated_normal = TruncatedNormal
variance_scaling_initializer = VarianceScaling variance_scaling_initializer = VarianceScaling
glorot_normal = GlorotNormal
glorot_uniform = GlorotUniform
...@@ -26,7 +26,7 @@ def convert_data_format(data_format): ...@@ -26,7 +26,7 @@ def convert_data_format(data_format):
elif data_format == 'channels_first': elif data_format == 'channels_first':
return 'NCHW' return 'NCHW'
else: else:
raise ValueError('Invalid data_format:', data_format) raise ValueError('Invalid data_format: ' + data_format)
def normalize_data_format(value): def normalize_data_format(value):
......
...@@ -70,15 +70,23 @@ class Function(object): ...@@ -70,15 +70,23 @@ class Function(object):
self._arg_device = self._arg_device.SerializeToString() self._arg_device = self._arg_device.SerializeToString()
self._seed = kwargs.get('seed', config.config().random_seed) self._seed = kwargs.get('seed', config.config().random_seed)
def alloc(self): def alloc(self, out=None):
"""Return the executing device to create an output tensor. """Return or bind the executing device to output tensor.
Parameters
----------
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
Returns Returns
------- -------
dragon.vm.torch.device Union[dragon.vm.torch.device, dragon.vm.torch.Tensor]
The device spec. The executing device or output tensor.
""" """
if out is not None:
out._device = self._device.copy()
return out
return self._device.copy() return self._device.copy()
def apply(self, *args, **kwargs): def apply(self, *args, **kwargs):
...@@ -106,7 +114,7 @@ class Function(object): ...@@ -106,7 +114,7 @@ class Function(object):
"""Dispatch the execution.""" """Dispatch the execution."""
if self._def is None: if self._def is None:
self._gen_def() self._gen_def()
if check_device: if len(inputs) > 1 and check_device:
self._check_device(inputs) self._check_device(inputs)
return execute.run_operator( return execute.run_operator(
op_def=self._def, op_def=self._def,
......
...@@ -66,7 +66,7 @@ def calculate_gain(nonlinearity, param=None): ...@@ -66,7 +66,7 @@ def calculate_gain(nonlinearity, param=None):
raise ValueError("Negative slope {} is not a valid number".format(param)) raise ValueError("Negative slope {} is not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2)) return math.sqrt(2.0 / (1 + negative_slope ** 2))
else: else:
raise ValueError('Unsupported nonlinearity:', nonlinearity) raise ValueError('Unsupported nonlinearity: ' + nonlinearity)
def constant_(tensor, val): def constant_(tensor, val):
......
...@@ -340,8 +340,7 @@ class LpNormalize(function.Function): ...@@ -340,8 +340,7 @@ class LpNormalize(function.Function):
} }
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)])
return self.dispatch([input], [out])
class LSTMCell(function.Function): class LSTMCell(function.Function):
...@@ -564,11 +563,7 @@ class RNNParamSet(function.Function): ...@@ -564,11 +563,7 @@ class RNNParamSet(function.Function):
} }
def forward(self, param, weights): def forward(self, param, weights):
return self.dispatch( return self.dispatch([param], [weights], no_grad=True)
[param], [weights],
no_grad=True,
check_device=False,
)
class SigmoidCrossEntropy(_Loss): class SigmoidCrossEntropy(_Loss):
......
...@@ -34,8 +34,7 @@ class ArgReduce(function.Function): ...@@ -34,8 +34,7 @@ class ArgReduce(function.Function):
} }
def forward(self, input, out=None): def forward(self, input, out=None):
outputs = [out] if out else [self.alloc()] return self.dispatch([input], [self.alloc(out)], no_grad=True)
return self.dispatch([input], outputs, no_grad=True)
class Assign(function.Function): class Assign(function.Function):
...@@ -148,8 +147,7 @@ class ChannelShuffle(function.Function): ...@@ -148,8 +147,7 @@ class ChannelShuffle(function.Function):
} }
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)])
return self.dispatch([input], [out])
class Concat(function.Function): class Concat(function.Function):
...@@ -164,8 +162,7 @@ class Concat(function.Function): ...@@ -164,8 +162,7 @@ class Concat(function.Function):
} }
def forward(self, seq, out=None): def forward(self, seq, out=None):
out = out if out else self.alloc() return self.dispatch(seq, [self.alloc(out)])
return self.dispatch(seq, [out])
class Cumulative(function.Function): class Cumulative(function.Function):
...@@ -187,8 +184,7 @@ class Cumulative(function.Function): ...@@ -187,8 +184,7 @@ class Cumulative(function.Function):
} }
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)])
return self.dispatch([input], [out])
class Expand(function.Function): class Expand(function.Function):
...@@ -235,9 +231,8 @@ class IndexSelect(function.Function): ...@@ -235,9 +231,8 @@ class IndexSelect(function.Function):
}, },
} }
def forward(self, input, indices, out=None): def forward(self, input, index, out=None):
out = out if out else self.alloc() return self.dispatch([input, index], [self.alloc(out)])
return self.dispatch([input, indices], [out])
class MaskedAssign(function.Function): class MaskedAssign(function.Function):
...@@ -248,7 +243,7 @@ class MaskedAssign(function.Function): ...@@ -248,7 +243,7 @@ class MaskedAssign(function.Function):
return {'op_type': 'MaskedAssign', 'arguments': {}} return {'op_type': 'MaskedAssign', 'arguments': {}}
def forward(self, out, mask, input): def forward(self, out, mask, input):
return self.dispatch([input, mask], [out]) return self.dispatch([input, mask], [self.alloc(out)])
class MaskedSelect(function.Function): class MaskedSelect(function.Function):
...@@ -259,8 +254,7 @@ class MaskedSelect(function.Function): ...@@ -259,8 +254,7 @@ class MaskedSelect(function.Function):
return {'op_type': 'MaskedSelect', 'arguments': {}} return {'op_type': 'MaskedSelect', 'arguments': {}}
def forward(self, input, mask, out=None): def forward(self, input, mask, out=None):
out = out if out else self.alloc() return self.dispatch([input, mask], [self.alloc(out)])
return self.dispatch([input, mask], [out])
class Multinomial(function.Function): class Multinomial(function.Function):
...@@ -280,8 +274,7 @@ class Multinomial(function.Function): ...@@ -280,8 +274,7 @@ class Multinomial(function.Function):
} }
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)], no_grad=True)
return self.dispatch([input], [out], no_grad=True)
class NonZero(function.Function): class NonZero(function.Function):
...@@ -292,8 +285,7 @@ class NonZero(function.Function): ...@@ -292,8 +285,7 @@ class NonZero(function.Function):
return {'op_type': 'NonZero', 'arguments': {}} return {'op_type': 'NonZero', 'arguments': {}}
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)], no_grad=True)
return self.dispatch([input], [out], no_grad=True)
class OneHot(function.Function): class OneHot(function.Function):
...@@ -330,8 +322,7 @@ class Reduce(function.Function): ...@@ -330,8 +322,7 @@ class Reduce(function.Function):
} }
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)])
return self.dispatch([input], [out])
class Reshape(function.Function): class Reshape(function.Function):
...@@ -356,9 +347,8 @@ class Reshape(function.Function): ...@@ -356,9 +347,8 @@ class Reshape(function.Function):
shape[i], 'int64') shape[i], 'int64')
def forward(self, input, shape, out=None): def forward(self, input, shape, out=None):
out = out if out else self.alloc()
return self.dispatch( return self.dispatch(
[input], [out], [input], [self.alloc(out)],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, shape), self.feed(ws, handle, shape),
) )
...@@ -433,8 +423,7 @@ class Stack(function.Function): ...@@ -433,8 +423,7 @@ class Stack(function.Function):
} }
def forward(self, seq, out=None): def forward(self, seq, out=None):
out = out if out else self.alloc() return self.dispatch(seq, [self.alloc(out)])
return self.dispatch(seq, [out])
class Squeeze(function.Function): class Squeeze(function.Function):
...@@ -451,8 +440,7 @@ class Squeeze(function.Function): ...@@ -451,8 +440,7 @@ class Squeeze(function.Function):
} }
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)])
return self.dispatch([input], [out])
class Tile(function.Function): class Tile(function.Function):
...@@ -534,8 +522,8 @@ class TopK(function.Function): ...@@ -534,8 +522,8 @@ class TopK(function.Function):
} }
} }
def forward(self, input, outputs=None): def forward(self, input, outputs=(None, None)):
outputs = [self.alloc(), self.alloc()] if outputs is None else outputs outputs = [self.alloc(outputs[0]), self.alloc(outputs[1])]
return self.dispatch([input], outputs, no_grad=True) return self.dispatch([input], outputs, no_grad=True)
...@@ -553,8 +541,7 @@ class UnSqueeze(function.Function): ...@@ -553,8 +541,7 @@ class UnSqueeze(function.Function):
} }
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)])
return self.dispatch([input], [out])
class Where(function.Function): class Where(function.Function):
......
...@@ -943,7 +943,7 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None): ...@@ -943,7 +943,7 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None):
axis=dim, axis=dim,
largest=largest, largest=largest,
sorted=sorted, sorted=sorted,
).apply(input, out) ).apply(input, out if out else (None, None))
def unsqueeze(input, dim, out=None): def unsqueeze(input, dim, out=None):
......
...@@ -42,7 +42,7 @@ def all_reduce(tensor, op='SUM', group=None): ...@@ -42,7 +42,7 @@ def all_reduce(tensor, op='SUM', group=None):
if group is None: if group is None:
raise ValueError('<group> is required.') raise ValueError('<group> is required.')
if op not in ('MEAN', 'SUM'): if op not in ('MEAN', 'SUM'):
raise ValueError('Unsupported reduce op:', op) raise ValueError('Unsupported reduce op: ' + op)
tensors = nest.flatten(tensor) tensors = nest.flatten(tensor)
return _functions.Collective \ return _functions.Collective \
.instantiate( .instantiate(
......
...@@ -34,6 +34,7 @@ class _Initializer(function.Function): ...@@ -34,6 +34,7 @@ class _Initializer(function.Function):
[] if shape_like is None else [shape_like], [out], [] if shape_like is None else [shape_like], [out],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, shape), self.feed(ws, handle, shape),
no_grad=True,
) )
...@@ -62,9 +63,10 @@ class Arange(function.Function): ...@@ -62,9 +63,10 @@ class Arange(function.Function):
def forward(self, slice_args, out=None): def forward(self, slice_args, out=None):
return self.dispatch( return self.dispatch(
[], [out if out else self.alloc()], [], [self.alloc(out)],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, slice_args) self.feed(ws, handle, slice_args),
no_grad=True,
) )
......
...@@ -121,7 +121,7 @@ def eye( ...@@ -121,7 +121,7 @@ def eye(
Returns Returns
------- -------
dragon.Tensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
""" """
......
...@@ -33,8 +33,7 @@ class Axpby(function.Function): ...@@ -33,8 +33,7 @@ class Axpby(function.Function):
} }
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)], no_grad=True)
return self.dispatch([input], [out], no_grad=True)
class BinaryFunc(function.Function): class BinaryFunc(function.Function):
...@@ -46,8 +45,7 @@ class BinaryFunc(function.Function): ...@@ -46,8 +45,7 @@ class BinaryFunc(function.Function):
return {'op_type': self.op_type, 'arguments': {}} return {'op_type': self.op_type, 'arguments': {}}
def forward(self, input, value, out=None): def forward(self, input, value, out=None):
out = out if out else self.alloc() return self.dispatch([input, value], [self.alloc(out)])
return self.dispatch([input, value], [out])
class Clip(function.Function): class Clip(function.Function):
...@@ -70,8 +68,7 @@ class Clip(function.Function): ...@@ -70,8 +68,7 @@ class Clip(function.Function):
} }
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)])
return self.dispatch([input], [out])
class UnaryFunc(function.Function): class UnaryFunc(function.Function):
...@@ -83,8 +80,7 @@ class UnaryFunc(function.Function): ...@@ -83,8 +80,7 @@ class UnaryFunc(function.Function):
return {'op_type': self.op_type, 'arguments': {}} return {'op_type': self.op_type, 'arguments': {}}
def forward(self, input, out=None): def forward(self, input, out=None):
out = out if out else self.alloc() return self.dispatch([input], [self.alloc(out)])
return self.dispatch([input], [out])
class MatMul(function.Function): class MatMul(function.Function):
...@@ -103,5 +99,4 @@ class MatMul(function.Function): ...@@ -103,5 +99,4 @@ class MatMul(function.Function):
} }
def forward(self, mat1, mat2, out=None): def forward(self, mat1, mat2, out=None):
out = out if out else self.alloc() return self.dispatch([mat1, mat2], [self.alloc(out)])
return self.dispatch([mat1, mat2], [out])
...@@ -37,11 +37,7 @@ class ParamUpdate(function.Function): ...@@ -37,11 +37,7 @@ class ParamUpdate(function.Function):
def forward(self, param, grad): def forward(self, param, grad):
self._check_device([param, grad]) self._check_device([param, grad])
return self.dispatch( return self.dispatch([grad], [param], no_grad=True)
[grad], [param],
no_grad=True,
check_device=False,
)
class GradAccumulate(function.Function): class GradAccumulate(function.Function):
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!