Commit ad83f4e4 by Ting PAN

Simplify the parsing of tensor arguments

Summary:
This commit moves the parser into ArgHelper which designed
to add the descriptors before.
1 parent 9fc5249b
...@@ -49,6 +49,10 @@ copy ...@@ -49,6 +49,10 @@ copy
#### ####
.. automethod:: dragon.EagerTensor.copy .. automethod:: dragon.EagerTensor.copy
from_value
##########
.. automethod:: dragon.EagerTensor.from_value
get_value get_value
######### #########
.. automethod:: dragon.EagerTensor.get_value .. automethod:: dragon.EagerTensor.get_value
......
...@@ -49,6 +49,10 @@ copy ...@@ -49,6 +49,10 @@ copy
#### ####
.. automethod:: dragon.Tensor.copy .. automethod:: dragon.Tensor.copy
from_value
##########
.. automethod:: dragon.Tensor.from_value
get_value get_value
########## ##########
.. automethod:: dragon.Tensor.get_value .. automethod:: dragon.Tensor.get_value
......
...@@ -9,6 +9,9 @@ vm.tensorflow ...@@ -9,6 +9,9 @@ vm.tensorflow
`class GradientTape <tensorflow/GradientTape.html>`_ `class GradientTape <tensorflow/GradientTape.html>`_
: Record the operations for auto differentiation. : Record the operations for auto differentiation.
`class Module <tensorflow/Module.html>`_
: The base class of neural network modules.
`class TensorShape <tensorflow/TensorShape.html>`_ `class TensorShape <tensorflow/TensorShape.html>`_
: Represent the a sequence of dimensions. : Represent the a sequence of dimensions.
...@@ -36,6 +39,9 @@ vm.tensorflow ...@@ -36,6 +39,9 @@ vm.tensorflow
`constant(...) <tensorflow/constant.html>`_ `constant(...) <tensorflow/constant.html>`_
: Return a tensor initialized from the value. : Return a tensor initialized from the value.
`convert_to_tensor(...) <tensorflow/convert_to_tensor.html>`_
: Convert the given value to a tensor.
`device(...) <tensorflow/device.html>`_ `device(...) <tensorflow/device.html>`_
: Context-manager to nest the device spec. : Context-manager to nest the device spec.
...@@ -118,6 +124,7 @@ vm.tensorflow ...@@ -118,6 +124,7 @@ vm.tensorflow
:hidden: :hidden:
tensorflow/GradientTape tensorflow/GradientTape
tensorflow/Module
tensorflow/TensorShape tensorflow/TensorShape
tensorflow/TensorSpec tensorflow/TensorSpec
tensorflow/argsort tensorflow/argsort
...@@ -126,6 +133,7 @@ vm.tensorflow ...@@ -126,6 +133,7 @@ vm.tensorflow
tensorflow/clip_by_value tensorflow/clip_by_value
tensorflow/concat tensorflow/concat
tensorflow/constant tensorflow/constant
tensorflow/convert_to_tensor
tensorflow/device tensorflow/device
tensorflow/expand_dims tensorflow/expand_dims
tensorflow/eye tensorflow/eye
......
Module
======
.. autoclass:: dragon.vm.tensorflow.Module
__init__
--------
.. automethod:: dragon.vm.tensorflow.Module.__init__
Properties
----------
name
####
.. autoattribute:: dragon.vm.tensorflow.Module.name
name_scope
##########
.. autoattribute:: dragon.vm.tensorflow.Module.name_scope
submodules
##########
.. autoattribute:: dragon.vm.tensorflow.Module.submodules
trainable_variables
###################
.. autoattribute:: dragon.vm.tensorflow.Module.trainable_variables
variables
#########
.. autoattribute:: dragon.vm.tensorflow.Module.variables
Methods
-------
flatten
#######
.. automethod:: dragon.vm.tensorflow.Module.flatten
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
convert_to_tensor
=================
.. autofunction:: dragon.vm.tensorflow.convert_to_tensor
.. _tf.constant(...): constant.html
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
...@@ -1007,10 +1007,10 @@ def tile_spec(args, inputs, outputs): ...@@ -1007,10 +1007,10 @@ def tile_spec(args, inputs, outputs):
@register('Transpose') @register('Transpose')
def transpose_spec(args, inputs, outputs): def transpose_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
perm = args['perm']
if 'perm_desc' in args or 'perm_descs' in args: if 'perm_desc' in args or 'perm_descs' in args:
return outputs return outputs
try: try:
perm = args['perm']
if perm is None: if perm is None:
perm = list(range(((len(inputs[0].shape)) - 1), -1, -1)) perm = list(range(((len(inputs[0].shape)) - 1), -1, -1))
out_shape = list(inputs[0].shape[:]) out_shape = list(inputs[0].shape[:])
......
...@@ -27,22 +27,24 @@ from dragon.core.util import nest ...@@ -27,22 +27,24 @@ from dragon.core.util import nest
class Tensor(types.TensorMetaclass): class Tensor(types.TensorMetaclass):
"""Tensor abstraction for graph executing.""" """Tensor abstraction for graph executing."""
def __init__(self, shape=None, dtype=None, name=None): def __init__(self, shape, dtype='float32', name=None):
"""Create a ``Tensor``. """Create a ``Tensor``.
Parameters Parameters
---------- ----------
shape : Sequence[int], optional shape : Sequence[int]
The optional tensor shape. The tensor shape.
dtype : str, optional dtype : str, optional, default='float32'
The optional data type. The optional data type.
name : str, optional name : str, optional
The optional tensor name. The optional tensor name.
""" """
self._op, self._grad = None, None self._shape = self.shape = shape
self._name, self._shape, self._dtype = None, None, None self._dtype = self.dtype = dtype
self.name, self.shape, self.dtype = name, shape, dtype self._name = self.name = name
self._op = None
self._grad = None
@property @property
def dtype(self): def dtype(self):
...@@ -229,6 +231,34 @@ class Tensor(types.TensorMetaclass): ...@@ -229,6 +231,34 @@ class Tensor(types.TensorMetaclass):
""" """
@classmethod
def from_value(cls, value, dtype=None, name=None):
"""Return a tensor converted from the given value.
Parameters
----------
value : array_like
The value to convert.
dtype: str, optional
The optional data type.
name: str, optional
The optional tensor name.
Returns
-------
dragon.Tensor
The output tensor.
"""
if not isinstance(value, numpy.ndarray):
value = numpy.array(value, dtype if dtype else 'float32')
name = workspace.get_workspace().unique_name(
name=context.get_name_scope() + (name if name else 'Const'),
suffix=':0',
namespace='Tensor')
ref = TensorRef(name, list(value.shape), str(value.dtype))
return ref.set_value(value)
def get_value(self): def get_value(self):
"""Return the value of implementation. """Return the value of implementation.
...@@ -394,36 +424,6 @@ class Tensor(types.TensorMetaclass): ...@@ -394,36 +424,6 @@ class Tensor(types.TensorMetaclass):
""" """
return self._register_as('uniform', low=low, high=high) return self._register_as('uniform', low=low, high=high)
@classmethod
def convert_to(cls, value, dtype=None, name=None):
"""Convert the given ``value`` to a ``dragon.Tensor``.
Parameters
----------
value : array_like
The value to convert.
dtype: str, optional
The optional data type.
name: str, optional
The optional name for this tensor.
Returns
-------
dragon.Tensor
The constant contains the value.
"""
if not isinstance(value, numpy.ndarray):
value = numpy.array(value, dtype if dtype else 'float32')
return TensorRef(
name=workspace.get_workspace().unique_name(
name=context.get_name_scope() + (name if name else 'Const'),
suffix=':0',
namespace='Tensor'),
shape=list(value.shape),
dtype=str(value.dtype),
).set_value(value)
def _register_as(self, type, **kwargs): def _register_as(self, type, **kwargs):
"""Fill self with the specific type of filler.""" """Fill self with the specific type of filler."""
filler = dragon_pb2.FillerInfo() filler = dragon_pb2.FillerInfo()
......
...@@ -58,11 +58,13 @@ class EagerTensor(Tensor): ...@@ -58,11 +58,13 @@ class EagerTensor(Tensor):
if shape is not None: if shape is not None:
self._from_shape(shape, kwargs.get('dtype', 'float32')) self._from_shape(shape, kwargs.get('dtype', 'float32'))
elif len(args) == 1: elif len(args) == 1:
self._from_numpy( if not isinstance(args[0], numpy.ndarray):
args[0] if isinstance(args[0], numpy.ndarray) dtype = kwargs.get('dtype', 'float32')
else numpy.array(args[0], kwargs.get('dtype', 'float32')), self._from_array(numpy.array(args[0], dtype))
kwargs.get('copy', True), else:
) dtype = kwargs.get('dtype', None)
self._from_array(numpy.array(
args[0], dtype, copy=kwargs.get('copy', True)))
else: else:
raise ValueError('Excepted at most one argument.') raise ValueError('Excepted at most one argument.')
...@@ -227,6 +229,29 @@ class EagerTensor(Tensor): ...@@ -227,6 +229,29 @@ class EagerTensor(Tensor):
""" """
@classmethod
def from_value(cls, value, dtype=None, name=None):
"""Return a tensor converted from the given value.
The input ``value``
Parameters
----------
value : array_like
The value to convert.
dtype: str, optional
The optional data type.
name: str, optional
The optional tensor name.
Returns
-------
dragon.EagerTensor
The output tensor.
"""
return EagerTensor(value, dtype=dtype, name=name, copy=False)
def get_value(self): def get_value(self):
"""Return the value of implementation. """Return the value of implementation.
...@@ -408,17 +433,16 @@ class EagerTensor(Tensor): ...@@ -408,17 +433,16 @@ class EagerTensor(Tensor):
""" """
def _from_numpy(self, array, copy): def _from_array(self, array):
"""Create impl from the numpy array.""" """Create implementation from the array."""
ws = workspace.get_workspace() ws = workspace.get_workspace()
array = array.copy() if copy else array
self._const_size = array.size self._const_size = array.size
self._gc, self._is_leaf = ws.collectors.TENSOR, True self._gc, self._is_leaf = ws.collectors.TENSOR, True
self._impl = ws.create_tensor(self._gc.alloc( self._impl = ws.create_tensor(self._gc.alloc(
context.get_eager_scope())).FromNumpy(array) context.get_eager_scope())).FromNumpy(array)
def _from_shape(self, shape, dtype): def _from_shape(self, shape, dtype):
"""Create impl from the shape and data type.""" """Create implementation from the shape."""
ws = workspace.get_workspace() ws = workspace.get_workspace()
self._gc, self._is_leaf = ws.collectors.TENSOR, True self._gc, self._is_leaf = ws.collectors.TENSOR, True
self._impl = ws.create_tensor(self._gc.alloc( self._impl = ws.create_tensor(self._gc.alloc(
......
...@@ -20,7 +20,6 @@ from dragon.core.ops import math_ops ...@@ -20,7 +20,6 @@ from dragon.core.ops import math_ops
from dragon.core.ops import array_ops from dragon.core.ops import array_ops
from dragon.core.ops.utils import ArgHelper from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
...@@ -53,7 +52,7 @@ def dropout(inputs, ratio=0.5, **kwargs): ...@@ -53,7 +52,7 @@ def dropout(inputs, ratio=0.5, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.Dropout op_lib = activation_ops_lib.Dropout
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -96,7 +95,7 @@ def drop_block2d(inputs, ratio=0.1, block_size=7, data_format='NCHW', **kwargs): ...@@ -96,7 +95,7 @@ def drop_block2d(inputs, ratio=0.1, block_size=7, data_format='NCHW', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.DropBlock2d op_lib = activation_ops_lib.DropBlock2d
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -132,7 +131,7 @@ def drop_path(inputs, ratio=0.2, **kwargs): ...@@ -132,7 +131,7 @@ def drop_path(inputs, ratio=0.2, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.DropPath op_lib = activation_ops_lib.DropPath
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -177,7 +176,7 @@ def elu(inputs, alpha=1.0, **kwargs): ...@@ -177,7 +176,7 @@ def elu(inputs, alpha=1.0, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['alpha'] = float(alpha) args['alpha'] = float(alpha)
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.Elu op_lib = activation_ops_lib.Elu
...@@ -219,7 +218,7 @@ def hardsigmoid(inputs, alpha=0.2, beta=0.5, **kwargs): ...@@ -219,7 +218,7 @@ def hardsigmoid(inputs, alpha=0.2, beta=0.5, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['alpha'] = float(alpha) args['alpha'] = float(alpha)
args['beta'] = float(beta) args['beta'] = float(beta)
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
...@@ -265,7 +264,7 @@ def hardswish(inputs, alpha=0.2, beta=0.5, **kwargs): ...@@ -265,7 +264,7 @@ def hardswish(inputs, alpha=0.2, beta=0.5, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['alpha'] = float(alpha) args['alpha'] = float(alpha)
args['beta'] = float(beta) args['beta'] = float(beta)
op_lib = activation_ops_lib.HardSwish op_lib = activation_ops_lib.HardSwish
...@@ -312,7 +311,7 @@ def leaky_relu(inputs, alpha=0.2, **kwargs): ...@@ -312,7 +311,7 @@ def leaky_relu(inputs, alpha=0.2, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['alpha'] = float(alpha) args['alpha'] = float(alpha)
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.Relu op_lib = activation_ops_lib.Relu
...@@ -354,15 +353,10 @@ def log_softmax(inputs, axis=-1, **kwargs): ...@@ -354,15 +353,10 @@ def log_softmax(inputs, axis=-1, **kwargs):
""" """
return math_ops.sub( return math_ops.sub(
[inputs, [inputs, math_ops.log(array_ops.sum(
math_ops.log(
array_ops.sum(
math_ops.exp(inputs, **kwargs), math_ops.exp(inputs, **kwargs),
axis=[axis], axis=[axis], keep_dims=True, **kwargs), **kwargs)],
keep_dims=True, **kwargs
**kwargs),
**kwargs)
], **kwargs
) )
...@@ -403,7 +397,7 @@ def prelu(inputs, channel_shared=False, data_format='NCHW', **kwargs): ...@@ -403,7 +397,7 @@ def prelu(inputs, channel_shared=False, data_format='NCHW', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = activation_ops_lib.PRelu op_lib = activation_ops_lib.PRelu
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -445,7 +439,7 @@ def relu(inputs, **kwargs): ...@@ -445,7 +439,7 @@ def relu(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.Relu op_lib = activation_ops_lib.Relu
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -488,7 +482,7 @@ def relu6(inputs, **kwargs): ...@@ -488,7 +482,7 @@ def relu6(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.Relu6 op_lib = activation_ops_lib.Relu6
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -533,7 +527,7 @@ def selu(inputs, alpha=1.67326, gamma=1.0507, **kwargs): ...@@ -533,7 +527,7 @@ def selu(inputs, alpha=1.67326, gamma=1.0507, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['alpha'], args['gamma'] = float(alpha), float(gamma) args['alpha'], args['gamma'] = float(alpha), float(gamma)
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.Selu op_lib = activation_ops_lib.Selu
...@@ -573,7 +567,7 @@ def sigmoid(inputs, **kwargs): ...@@ -573,7 +567,7 @@ def sigmoid(inputs, **kwargs):
The output tensor The output tensor
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.Activation op_lib = activation_ops_lib.Activation
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -613,7 +607,7 @@ def softmax(inputs, axis=-1, **kwargs): ...@@ -613,7 +607,7 @@ def softmax(inputs, axis=-1, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.Softmax op_lib = activation_ops_lib.Softmax
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -650,7 +644,7 @@ def tanh(inputs, **kwargs): ...@@ -650,7 +644,7 @@ def tanh(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = activation_ops_lib.Activation op_lib = activation_ops_lib.Activation
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -688,7 +682,7 @@ def swish(inputs, **kwargs): ...@@ -688,7 +682,7 @@ def swish(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = activation_ops_lib.Activation op_lib = activation_ops_lib.Activation
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
......
...@@ -21,7 +21,6 @@ from dragon.core.framework import types ...@@ -21,7 +21,6 @@ from dragon.core.framework import types
from dragon.core.ops import array_ops_lib from dragon.core.ops import array_ops_lib
from dragon.core.ops.utils import ArgHelper from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
from dragon.core.util import nest from dragon.core.util import nest
...@@ -58,7 +57,7 @@ def argmax(inputs, axis=None, keep_dims=False, **kwargs): ...@@ -58,7 +57,7 @@ def argmax(inputs, axis=None, keep_dims=False, **kwargs):
The index of maximum elements. The index of maximum elements.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.ArgReduce op_lib = array_ops_lib.ArgReduce
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -104,7 +103,7 @@ def argmin(inputs, axis=None, keep_dims=False, **kwargs): ...@@ -104,7 +103,7 @@ def argmin(inputs, axis=None, keep_dims=False, **kwargs):
The index of minimum elements. The index of minimum elements.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.ArgReduce op_lib = array_ops_lib.ArgReduce
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -152,7 +151,7 @@ def argsort(inputs, axis=-1, descending=False, **kwargs): ...@@ -152,7 +151,7 @@ def argsort(inputs, axis=-1, descending=False, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Sort op_lib = array_ops_lib.Sort
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -202,7 +201,7 @@ def broadcast_to(inputs, shape, **kwargs): ...@@ -202,7 +201,7 @@ def broadcast_to(inputs, shape, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Expand op_lib = array_ops_lib.Expand
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -237,7 +236,7 @@ def cast(inputs, dtype, **kwargs): ...@@ -237,7 +236,7 @@ def cast(inputs, dtype, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = array_ops_lib.Cast op_lib = array_ops_lib.Cast
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -281,7 +280,7 @@ def channel_affine(inputs, axis=1, num_axes=1, **kwargs): ...@@ -281,7 +280,7 @@ def channel_affine(inputs, axis=1, num_axes=1, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = array_ops_lib.ChannelAffine op_lib = array_ops_lib.ChannelAffine
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -345,7 +344,7 @@ def channel_normalize( ...@@ -345,7 +344,7 @@ def channel_normalize(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.ChannelNormalize op_lib = array_ops_lib.ChannelNormalize
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -380,7 +379,7 @@ def channel_shuffle(inputs, axis=0, group=1, **kwargs): ...@@ -380,7 +379,7 @@ def channel_shuffle(inputs, axis=0, group=1, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.ChannelShuffle op_lib = array_ops_lib.ChannelShuffle
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -425,7 +424,7 @@ def concat(inputs, axis=0, **kwargs): ...@@ -425,7 +424,7 @@ def concat(inputs, axis=0, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Concat op_lib = array_ops_lib.Concat
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(axis=axis).apply(inputs) return op_lib.instantiate(axis=axis).apply(inputs)
...@@ -480,7 +479,7 @@ def cumsum(inputs, axis=0, exclusive=False, reverse=False, **kwargs): ...@@ -480,7 +479,7 @@ def cumsum(inputs, axis=0, exclusive=False, reverse=False, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Cumulative op_lib = array_ops_lib.Cumulative
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -528,7 +527,7 @@ def expand_dims(inputs, axis, **kwargs): ...@@ -528,7 +527,7 @@ def expand_dims(inputs, axis, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args.pop('axis') args.pop('axis')
args['axes'] = None if axis is None else nest.flatten(axis) args['axes'] = None if axis is None else nest.flatten(axis)
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
...@@ -573,7 +572,7 @@ def flatten(inputs, axis=0, num_axes=-1, keep_axes=None, **kwargs): ...@@ -573,7 +572,7 @@ def flatten(inputs, axis=0, num_axes=-1, keep_axes=None, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = array_ops_lib.Flatten op_lib = array_ops_lib.Flatten
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -617,7 +616,7 @@ def identity(inputs, **kwargs): ...@@ -617,7 +616,7 @@ def identity(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = array_ops_lib.Identity op_lib = array_ops_lib.Identity
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -661,7 +660,7 @@ def index_select(inputs, index, axis=0, **kwargs): ...@@ -661,7 +660,7 @@ def index_select(inputs, index, axis=0, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
axes = nest.flatten(axis) axes = nest.flatten(axis)
axes.sort() axes.sort()
if axes[-1] != (axes[0] + len(axes) - 1): if axes[-1] != (axes[0] + len(axes) - 1):
...@@ -677,7 +676,7 @@ def index_select(inputs, index, axis=0, **kwargs): ...@@ -677,7 +676,7 @@ def index_select(inputs, index, axis=0, **kwargs):
).apply([inputs, index]) ).apply([inputs, index])
else: else:
if not isinstance(index, Tensor): if not isinstance(index, Tensor):
index = Tensor.convert_to(index, 'int64') index = Tensor.from_value(index, 'int64')
args['inputs'], args['index'] = \ args['inputs'], args['index'] = \
[args['inputs'], index], None [args['inputs'], index], None
args['axis'], args['num_axes'] = axes[0], len(axes) args['axis'], args['num_axes'] = axes[0], len(axes)
...@@ -719,7 +718,7 @@ def linspace(start, stop, num, dtype='int64', axis=0, **kwargs): ...@@ -719,7 +718,7 @@ def linspace(start, stop, num, dtype='int64', axis=0, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['dtype'] = args['dtype'].lower() args['dtype'] = args['dtype'].lower()
args['start'] = nest.flatten(start) args['start'] = nest.flatten(start)
args['stop'] = nest.flatten(stop) args['stop'] = nest.flatten(stop)
...@@ -758,7 +757,7 @@ def masked_select(inputs, **kwargs): ...@@ -758,7 +757,7 @@ def masked_select(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.MaskedSelect op_lib = array_ops_lib.MaskedSelect
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate().apply(inputs) return op_lib.instantiate().apply(inputs)
...@@ -802,7 +801,7 @@ def max(inputs, axis=None, keep_dims=False, **kwargs): ...@@ -802,7 +801,7 @@ def max(inputs, axis=None, keep_dims=False, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args.pop('axis') args.pop('axis')
args['axes'] = None if axis is None else nest.flatten(axis) args['axes'] = None if axis is None else nest.flatten(axis)
op_lib = array_ops_lib.Reduce op_lib = array_ops_lib.Reduce
...@@ -853,7 +852,7 @@ def mean(inputs, axis=None, keep_dims=False, **kwargs): ...@@ -853,7 +852,7 @@ def mean(inputs, axis=None, keep_dims=False, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args.pop('axis') args.pop('axis')
args['axes'] = None if axis is None else nest.flatten(axis) args['axes'] = None if axis is None else nest.flatten(axis)
op_lib = array_ops_lib.Reduce op_lib = array_ops_lib.Reduce
...@@ -904,7 +903,7 @@ def min(inputs, axis=None, keep_dims=False, **kwargs): ...@@ -904,7 +903,7 @@ def min(inputs, axis=None, keep_dims=False, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args.pop('axis') args.pop('axis')
args['axes'] = None if axis is None else nest.flatten(axis) args['axes'] = None if axis is None else nest.flatten(axis)
op_lib = array_ops_lib.Reduce op_lib = array_ops_lib.Reduce
...@@ -963,7 +962,7 @@ def moments(inputs, axis=None, keep_dims=False, **kwargs): ...@@ -963,7 +962,7 @@ def moments(inputs, axis=None, keep_dims=False, **kwargs):
The variance tensor. The variance tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args.pop('axis') args.pop('axis')
args['axes'] = None if axis is None else nest.flatten(axis) args['axes'] = None if axis is None else nest.flatten(axis)
op_lib = array_ops_lib.Moments op_lib = array_ops_lib.Moments
...@@ -1002,7 +1001,7 @@ def multinomial(inputs, num_samples=1, epsilon=0, normalize=False, **kwargs): ...@@ -1002,7 +1001,7 @@ def multinomial(inputs, num_samples=1, epsilon=0, normalize=False, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['epsilon'] = float(epsilon) args['epsilon'] = float(epsilon)
op_lib = array_ops_lib.Multinomial op_lib = array_ops_lib.Multinomial
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -1033,7 +1032,7 @@ def nonzero(inputs, **kwargs): ...@@ -1033,7 +1032,7 @@ def nonzero(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.NonZero op_lib = array_ops_lib.NonZero
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate().apply([inputs]) return op_lib.instantiate().apply([inputs])
...@@ -1082,7 +1081,7 @@ def one_hot(inputs, depth, on_value=1, off_value=0, **kwargs): ...@@ -1082,7 +1081,7 @@ def one_hot(inputs, depth, on_value=1, off_value=0, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.OneHot op_lib = array_ops_lib.OneHot
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -1140,7 +1139,7 @@ def pad(inputs, pads, mode='constant', value=0, **kwargs): ...@@ -1140,7 +1139,7 @@ def pad(inputs, pads, mode='constant', value=0, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['value'] = float(value) args['value'] = float(value)
args['mode'] = mode.upper() args['mode'] = mode.upper()
pads_begin, pads_end = [], [] pads_begin, pads_end = [], []
...@@ -1186,7 +1185,7 @@ def permutation(limit, dtype='int64', **kwargs): ...@@ -1186,7 +1185,7 @@ def permutation(limit, dtype='int64', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['dtype'] = args['dtype'].lower() args['dtype'] = args['dtype'].lower()
op_lib = array_ops_lib.Permutation op_lib = array_ops_lib.Permutation
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
...@@ -1236,7 +1235,7 @@ def range(start, limit=None, delta=1, dtype='int64', **kwargs): ...@@ -1236,7 +1235,7 @@ def range(start, limit=None, delta=1, dtype='int64', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['dtype'] = args['dtype'].lower() args['dtype'] = args['dtype'].lower()
if limit is None: if limit is None:
args['slice'] = (float(start), float(delta)) args['slice'] = (float(start), float(delta))
...@@ -1278,7 +1277,7 @@ def repeat(inputs, axis=None, repeats=1, **kwargs): ...@@ -1278,7 +1277,7 @@ def repeat(inputs, axis=None, repeats=1, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Repeat op_lib = array_ops_lib.Repeat
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -1325,7 +1324,7 @@ def reshape(inputs, shape, **kwargs): ...@@ -1325,7 +1324,7 @@ def reshape(inputs, shape, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
op_lib = array_ops_lib.Reshape op_lib = array_ops_lib.Reshape
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -1359,7 +1358,7 @@ def shape(inputs, **kwargs): ...@@ -1359,7 +1358,7 @@ def shape(inputs, **kwargs):
The tensor shape. The tensor shape.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Shape op_lib = array_ops_lib.Shape
if isinstance(inputs, EagerTensor): if isinstance(inputs, EagerTensor):
return op_lib.instantiate().apply([inputs]) return op_lib.instantiate().apply([inputs])
...@@ -1408,7 +1407,7 @@ def slice(inputs, starts, sizes, **kwargs): ...@@ -1408,7 +1407,7 @@ def slice(inputs, starts, sizes, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Slice op_lib = array_ops_lib.Slice
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -1453,7 +1452,7 @@ def sort(inputs, axis=-1, descending=False, **kwargs): ...@@ -1453,7 +1452,7 @@ def sort(inputs, axis=-1, descending=False, **kwargs):
The value and index tensor. The value and index tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Sort op_lib = array_ops_lib.Sort
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -1519,7 +1518,7 @@ def split( ...@@ -1519,7 +1518,7 @@ def split(
The outputs. The outputs.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Split op_lib = array_ops_lib.Split
if nest.is_sequence(num_or_size_splits): if nest.is_sequence(num_or_size_splits):
num_splits = len(num_or_size_splits) num_splits = len(num_or_size_splits)
...@@ -1580,7 +1579,7 @@ def squeeze(inputs, axis=None, **kwargs): ...@@ -1580,7 +1579,7 @@ def squeeze(inputs, axis=None, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args.pop('axis') args.pop('axis')
args['axes'] = None if axis is None else nest.flatten(axis) args['axes'] = None if axis is None else nest.flatten(axis)
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
...@@ -1612,7 +1611,7 @@ def stack(inputs, axis=0, **kwargs): ...@@ -1612,7 +1611,7 @@ def stack(inputs, axis=0, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Stack op_lib = array_ops_lib.Stack
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -1658,7 +1657,7 @@ def sum(inputs, axis=None, keep_dims=False, **kwargs): ...@@ -1658,7 +1657,7 @@ def sum(inputs, axis=None, keep_dims=False, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args.pop('axis') args.pop('axis')
args['axes'] = None if axis is None else nest.flatten(axis) args['axes'] = None if axis is None else nest.flatten(axis)
op_lib = array_ops_lib.Reduce op_lib = array_ops_lib.Reduce
...@@ -1691,7 +1690,7 @@ def tile(inputs, repeats, **kwargs): ...@@ -1691,7 +1690,7 @@ def tile(inputs, repeats, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Tile op_lib = array_ops_lib.Tile
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -1730,7 +1729,7 @@ def transpose(inputs, perm=None, **kwargs): ...@@ -1730,7 +1729,7 @@ def transpose(inputs, perm=None, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Transpose op_lib = array_ops_lib.Transpose
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -1780,7 +1779,7 @@ def top_k(inputs, k=1, axis=-1, largest=True, sorted=True, **kwargs): ...@@ -1780,7 +1779,7 @@ def top_k(inputs, k=1, axis=-1, largest=True, sorted=True, **kwargs):
The value and index tensor. The value and index tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.TopK op_lib = array_ops_lib.TopK
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -1836,7 +1835,7 @@ def unique(inputs, return_inverse=False, return_counts=False, **kwargs): ...@@ -1836,7 +1835,7 @@ def unique(inputs, return_inverse=False, return_counts=False, **kwargs):
The counts tensor. The counts tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Unique op_lib = array_ops_lib.Unique
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate( return op_lib.instantiate(
...@@ -1878,7 +1877,7 @@ def where(inputs, **kwargs): ...@@ -1878,7 +1877,7 @@ def where(inputs, **kwargs):
""" """
if types.is_tensor(inputs) or len(inputs) == 1: if types.is_tensor(inputs) or len(inputs) == 1:
return nonzero(inputs, **kwargs) return nonzero(inputs, **kwargs)
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = array_ops_lib.Where op_lib = array_ops_lib.Where
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate().apply(inputs) return op_lib.instantiate().apply(inputs)
......
...@@ -19,7 +19,6 @@ from dragon.core.framework import ops ...@@ -19,7 +19,6 @@ from dragon.core.framework import ops
from dragon.core.ops import control_flow_ops_lib from dragon.core.ops import control_flow_ops_lib
from dragon.core.ops.utils import ArgHelper from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
@OpSchema.num_inputs(2) @OpSchema.num_inputs(2)
...@@ -45,7 +44,7 @@ def assign(inputs, starts=None, sizes=None, **kwargs): ...@@ -45,7 +44,7 @@ def assign(inputs, starts=None, sizes=None, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
inputs[1] = ops.scalar_to_tensor(inputs[1], inputs[0].dtype) inputs[1] = ops.scalar_to_tensor(inputs[1], inputs[0].dtype)
op_lib = control_flow_ops_lib.Assign op_lib = control_flow_ops_lib.Assign
...@@ -80,7 +79,7 @@ def masked_assign(inputs, **kwargs): ...@@ -80,7 +79,7 @@ def masked_assign(inputs, **kwargs):
The input tensor. The input tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inplace = args.pop('inplace') if 'inplace' in args else False inplace = args.pop('inplace') if 'inplace' in args else False
inputs[1] = ops.scalar_to_tensor(inputs[1], inputs[0].dtype) inputs[1] = ops.scalar_to_tensor(inputs[1], inputs[0].dtype)
op_lib = control_flow_ops_lib.MaskedAssign op_lib = control_flow_ops_lib.MaskedAssign
......
...@@ -17,8 +17,8 @@ from __future__ import print_function ...@@ -17,8 +17,8 @@ from __future__ import print_function
from dragon.core import distributed from dragon.core import distributed
from dragon.core.eager import context from dragon.core.eager import context
from dragon.core.ops import distributed_ops_lib from dragon.core.ops import distributed_ops_lib
from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
...@@ -40,7 +40,7 @@ def all_reduce(inputs, operation='MEAN', group=None, **kwargs): ...@@ -40,7 +40,7 @@ def all_reduce(inputs, operation='MEAN', group=None, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if group is None: if group is None:
group = distributed.get_group() group = distributed.get_group()
if group is None: if group is None:
...@@ -80,7 +80,7 @@ def broadcast(inputs, root=0, group=None, **kwargs): ...@@ -80,7 +80,7 @@ def broadcast(inputs, root=0, group=None, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if group is None: if group is None:
group = distributed.get_group() group = distributed.get_group()
if group is None: if group is None:
......
...@@ -16,8 +16,8 @@ from __future__ import print_function ...@@ -16,8 +16,8 @@ from __future__ import print_function
from dragon.core.autograph.op_def import OpDef from dragon.core.autograph.op_def import OpDef
from dragon.core.eager import context from dragon.core.eager import context
from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
def python_plugin( def python_plugin(
...@@ -52,7 +52,7 @@ def python_plugin( ...@@ -52,7 +52,7 @@ def python_plugin(
The outputs. The outputs.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if context.executing_eagerly(): if context.executing_eagerly():
raise RuntimeError('Excepted the graph execution mode.') raise RuntimeError('Excepted the graph execution mode.')
else: else:
...@@ -77,7 +77,7 @@ def stop_gradient(inputs, **kwargs): ...@@ -77,7 +77,7 @@ def stop_gradient(inputs, **kwargs):
An identity of input. An identity of input.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if context.executing_eagerly(): if context.executing_eagerly():
raise RuntimeError('Excepted the graph execution mode.') raise RuntimeError('Excepted the graph execution mode.')
else: else:
......
...@@ -24,7 +24,6 @@ from dragon.core.framework import types ...@@ -24,7 +24,6 @@ from dragon.core.framework import types
from dragon.core.ops import init_ops_lib from dragon.core.ops import init_ops_lib
from dragon.core.ops.utils import ArgHelper from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
def constant(value, dtype=None, shape=None, name=None): def constant(value, dtype=None, shape=None, name=None):
...@@ -55,17 +54,33 @@ def constant(value, dtype=None, shape=None, name=None): ...@@ -55,17 +54,33 @@ def constant(value, dtype=None, shape=None, name=None):
The output tensor. The output tensor.
""" """
if types.is_eager_tensor(value): # Determine the initial value.
value = value.numpy(True) if types.is_tensor(value):
if dtype is not None: initial_value = value.get_value()
value = value.astype(dtype)
else: else:
value = numpy.array(value, dtype=dtype) initial_value = value
value = value.reshape(shape) if shape else value # Determine the data type and shape.
initial_value = numpy.array(initial_value, dtype)
if not hasattr(value, 'dtype'):
# Discard the default 64 bit types.
if initial_value.dtype == numpy.float64:
initial_value = initial_value.astype(numpy.float32)
elif initial_value.dtype == numpy.int64:
initial_value = initial_value.astype(numpy.int32)
# Determine the shape.
if shape is not None:
if initial_value.size == 1:
# Broadcast with scalar value.
scalar = initial_value.flatten()[0]
initial_value = numpy.empty(shape, initial_value.dtype)
initial_value.fill(scalar)
else:
# Reshape.
initial_value = initial_value.reshape(shape)
if context.executing_eagerly(): if context.executing_eagerly():
return EagerTensor(value) return EagerTensor(initial_value, name=name)
else: else:
return Tensor.convert_to(value, str(value.dtype), name) return Tensor.from_value(initial_value, dtype, name)
def eye(n, m=None, k=0, dtype='float32', **kwargs): def eye(n, m=None, k=0, dtype='float32', **kwargs):
...@@ -105,7 +120,7 @@ def eye(n, m=None, k=0, dtype='float32', **kwargs): ...@@ -105,7 +120,7 @@ def eye(n, m=None, k=0, dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
m = n if m is None else m m = n if m is None else m
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
op_lib = init_ops_lib.Eye op_lib = init_ops_lib.Eye
...@@ -165,7 +180,7 @@ def eye_like(other, k=0, dtype='float32', **kwargs): ...@@ -165,7 +180,7 @@ def eye_like(other, k=0, dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
op_lib = init_ops_lib.Eye op_lib = init_ops_lib.Eye
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -198,7 +213,7 @@ def fill(shape, value=0, dtype=None, **kwargs): ...@@ -198,7 +213,7 @@ def fill(shape, value=0, dtype=None, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['value'] = float(value) args['value'] = float(value)
if dtype is None: if dtype is None:
args['dtype'] = str(numpy.array(value).dtype) args['dtype'] = str(numpy.array(value).dtype)
...@@ -242,7 +257,7 @@ def glorot_normal(shape, scale=2.0, mode='fan_in', dtype='float32', **kwargs): ...@@ -242,7 +257,7 @@ def glorot_normal(shape, scale=2.0, mode='fan_in', dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['scale'] = float(scale) args['scale'] = float(scale)
args['mode'] = mode.lower() args['mode'] = mode.lower()
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
...@@ -284,7 +299,7 @@ def glorot_uniform(shape, mode='fan_in', scale=3.0, dtype='float32', **kwargs): ...@@ -284,7 +299,7 @@ def glorot_uniform(shape, mode='fan_in', scale=3.0, dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['scale'] = float(scale) args['scale'] = float(scale)
args['mode'] = mode.lower() args['mode'] = mode.lower()
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
...@@ -352,7 +367,7 @@ def ones_like(other, dtype='float32', **kwargs): ...@@ -352,7 +367,7 @@ def ones_like(other, dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
op_lib = init_ops_lib.Fill op_lib = init_ops_lib.Fill
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -387,7 +402,7 @@ def random_normal(shape, mean=0, std=1, dtype='float32', **kwargs): ...@@ -387,7 +402,7 @@ def random_normal(shape, mean=0, std=1, dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['mean'] = float(mean) args['mean'] = float(mean)
args['std'] = float(std) args['std'] = float(std)
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
...@@ -427,7 +442,7 @@ def random_normal_like(other, mean=0, std=1, dtype='float32', **kwargs): ...@@ -427,7 +442,7 @@ def random_normal_like(other, mean=0, std=1, dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['mean'] = float(mean) args['mean'] = float(mean)
args['std'] = float(std) args['std'] = float(std)
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
...@@ -467,7 +482,7 @@ def random_uniform(shape, low=-1, high=1, dtype='float32', **kwargs): ...@@ -467,7 +482,7 @@ def random_uniform(shape, low=-1, high=1, dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['low'], args['high'] = float(low), float(high) args['low'], args['high'] = float(low), float(high)
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
op_lib = init_ops_lib.RandomUniform op_lib = init_ops_lib.RandomUniform
...@@ -506,7 +521,7 @@ def random_uniform_like(other, low=-1, high=1, dtype='float32', **kwargs): ...@@ -506,7 +521,7 @@ def random_uniform_like(other, low=-1, high=1, dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['low'], args['high'] = float(low), float(high) args['low'], args['high'] = float(low), float(high)
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
op_lib = init_ops_lib.RandomUniform op_lib = init_ops_lib.RandomUniform
...@@ -545,7 +560,7 @@ def truncated_normal(shape, mean=0, std=1, dtype='float32', **kwargs): ...@@ -545,7 +560,7 @@ def truncated_normal(shape, mean=0, std=1, dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['mean'], args['std'] = float(mean), float(std) args['mean'], args['std'] = float(mean), float(std)
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
op_lib = init_ops_lib.TruncatedNormal op_lib = init_ops_lib.TruncatedNormal
...@@ -613,7 +628,7 @@ def zeros_like(other, dtype='float32', **kwargs): ...@@ -613,7 +628,7 @@ def zeros_like(other, dtype='float32', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
trainable = args.pop('trainable') if 'trainable' in args else False trainable = args.pop('trainable') if 'trainable' in args else False
op_lib = init_ops_lib.Fill op_lib = init_ops_lib.Fill
if context.executing_eagerly(): if context.executing_eagerly():
......
...@@ -17,8 +17,8 @@ from __future__ import print_function ...@@ -17,8 +17,8 @@ from __future__ import print_function
from dragon.core.eager import context from dragon.core.eager import context
from dragon.core.ops import activation_ops from dragon.core.ops import activation_ops
from dragon.core.ops import loss_ops_lib from dragon.core.ops import loss_ops_lib
from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
@OpSchema.num_inputs(2) @OpSchema.num_inputs(2)
...@@ -46,7 +46,7 @@ def ctc_loss(inputs, padding_mask=-1, **kwargs): ...@@ -46,7 +46,7 @@ def ctc_loss(inputs, padding_mask=-1, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs[0] = activation_ops.softmax(inputs[0], axis=2) inputs[0] = activation_ops.softmax(inputs[0], axis=2)
op_lib = loss_ops_lib.Operator op_lib = loss_ops_lib.Operator
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -84,7 +84,7 @@ def l1_loss(inputs, reduction='mean', **kwargs): ...@@ -84,7 +84,7 @@ def l1_loss(inputs, reduction='mean', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['reduction'] = reduction.upper() args['reduction'] = reduction.upper()
op_lib = loss_ops_lib.L1Loss op_lib = loss_ops_lib.L1Loss
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -124,7 +124,7 @@ def l2_loss(inputs, reduction='mean', **kwargs): ...@@ -124,7 +124,7 @@ def l2_loss(inputs, reduction='mean', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['reduction'] = reduction.upper() args['reduction'] = reduction.upper()
op_lib = loss_ops_lib.L2Loss op_lib = loss_ops_lib.L2Loss
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -166,7 +166,7 @@ def nll_loss( ...@@ -166,7 +166,7 @@ def nll_loss(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['reduction'] = reduction.upper() args['reduction'] = reduction.upper()
op_lib = loss_ops_lib.NLLLoss op_lib = loss_ops_lib.NLLLoss
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -209,7 +209,7 @@ def sigmoid_cross_entropy(inputs, reduction='valid', **kwargs): ...@@ -209,7 +209,7 @@ def sigmoid_cross_entropy(inputs, reduction='valid', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['reduction'] = reduction.upper() args['reduction'] = reduction.upper()
op_lib = loss_ops_lib.SigmoidCrossEntropy op_lib = loss_ops_lib.SigmoidCrossEntropy
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -258,7 +258,7 @@ def sigmoid_focal_loss( ...@@ -258,7 +258,7 @@ def sigmoid_focal_loss(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['alpha'] = float(args['alpha']) args['alpha'] = float(args['alpha'])
args['gamma'] = float(args['gamma']) args['gamma'] = float(args['gamma'])
args['reduction'] = reduction.upper() args['reduction'] = reduction.upper()
...@@ -305,7 +305,7 @@ def smooth_l1_loss(inputs, beta=1., reduction='mean', **kwargs): ...@@ -305,7 +305,7 @@ def smooth_l1_loss(inputs, beta=1., reduction='mean', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['beta'] = float(args['beta']) args['beta'] = float(args['beta'])
args['reduction'] = reduction.upper() args['reduction'] = reduction.upper()
op_lib = loss_ops_lib.SmoothL1Loss op_lib = loss_ops_lib.SmoothL1Loss
...@@ -350,7 +350,7 @@ def softmax_cross_entropy(inputs, axis=1, reduction='mean', **kwargs): ...@@ -350,7 +350,7 @@ def softmax_cross_entropy(inputs, axis=1, reduction='mean', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['reduction'] = reduction.upper() args['reduction'] = reduction.upper()
op_lib = loss_ops_lib.SoftmaxCrossEntropy op_lib = loss_ops_lib.SoftmaxCrossEntropy
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -402,7 +402,7 @@ def sparse_softmax_cross_entropy( ...@@ -402,7 +402,7 @@ def sparse_softmax_cross_entropy(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['reduction'] = reduction.upper() args['reduction'] = reduction.upper()
op_lib = loss_ops_lib.SparseSoftmaxCrossEntropy op_lib = loss_ops_lib.SparseSoftmaxCrossEntropy
if context.executing_eagerly(): if context.executing_eagerly():
......
...@@ -17,8 +17,8 @@ from __future__ import print_function ...@@ -17,8 +17,8 @@ from __future__ import print_function
from dragon.core.eager import context from dragon.core.eager import context
from dragon.core.framework import ops from dragon.core.framework import ops
from dragon.core.ops import math_ops_lib from dragon.core.ops import math_ops_lib
from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
...@@ -45,7 +45,7 @@ def abs(inputs, **kwargs): ...@@ -45,7 +45,7 @@ def abs(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Abs').apply([inputs]) return op_lib.instantiate(op_type='Abs').apply([inputs])
...@@ -79,7 +79,7 @@ def add(inputs, **kwargs): ...@@ -79,7 +79,7 @@ def add(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -113,7 +113,7 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs): ...@@ -113,7 +113,7 @@ def axpby(inputs, outputs=None, alpha=1., beta=1., **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['alpha'], args['beta'] = float(alpha), float(beta) args['alpha'], args['beta'] = float(alpha), float(beta)
op_lib = math_ops_lib.Axpby op_lib = math_ops_lib.Axpby
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -235,7 +235,7 @@ def ceil(inputs, **kwargs): ...@@ -235,7 +235,7 @@ def ceil(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Ceil').apply([inputs]) return op_lib.instantiate(op_type='Ceil').apply([inputs])
...@@ -264,7 +264,7 @@ def clip(inputs, low=None, high=None, **kwargs): ...@@ -264,7 +264,7 @@ def clip(inputs, low=None, high=None, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if low is not None: if low is not None:
args['low'] = float(args['low']) args['low'] = float(args['low'])
if high is not None: if high is not None:
...@@ -304,7 +304,7 @@ def cos(inputs, **kwargs): ...@@ -304,7 +304,7 @@ def cos(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Cos').apply([inputs]) return op_lib.instantiate(op_type='Cos').apply([inputs])
...@@ -338,7 +338,7 @@ def div(inputs, **kwargs): ...@@ -338,7 +338,7 @@ def div(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -389,7 +389,7 @@ def dot(inputs, **kwargs): ...@@ -389,7 +389,7 @@ def dot(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Dot').apply(inputs) return op_lib.instantiate(op_type='Dot').apply(inputs)
...@@ -424,7 +424,7 @@ def equal(inputs, **kwargs): ...@@ -424,7 +424,7 @@ def equal(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -457,7 +457,7 @@ def exp(inputs, **kwargs): ...@@ -457,7 +457,7 @@ def exp(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Exp').apply([inputs]) return op_lib.instantiate(op_type='Exp').apply([inputs])
...@@ -489,7 +489,7 @@ def floor(inputs, **kwargs): ...@@ -489,7 +489,7 @@ def floor(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Floor').apply([inputs]) return op_lib.instantiate(op_type='Floor').apply([inputs])
...@@ -522,7 +522,7 @@ def fully_connected(inputs, axis=1, transpose_w=True, **kwargs): ...@@ -522,7 +522,7 @@ def fully_connected(inputs, axis=1, transpose_w=True, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.FullyConnected op_lib = math_ops_lib.FullyConnected
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -561,7 +561,7 @@ def greater(inputs, **kwargs): ...@@ -561,7 +561,7 @@ def greater(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -597,7 +597,7 @@ def greater_equal(inputs, **kwargs): ...@@ -597,7 +597,7 @@ def greater_equal(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -635,7 +635,7 @@ def invert(inputs, **kwargs): ...@@ -635,7 +635,7 @@ def invert(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Invert').apply([inputs]) return op_lib.instantiate(op_type='Invert').apply([inputs])
...@@ -667,7 +667,7 @@ def is_inf(inputs, **kwargs): ...@@ -667,7 +667,7 @@ def is_inf(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='IsInf').apply([inputs]) return op_lib.instantiate(op_type='IsInf').apply([inputs])
...@@ -699,7 +699,7 @@ def is_nan(inputs, **kwargs): ...@@ -699,7 +699,7 @@ def is_nan(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='IsNaN').apply([inputs]) return op_lib.instantiate(op_type='IsNaN').apply([inputs])
...@@ -731,7 +731,7 @@ def log(inputs, **kwargs): ...@@ -731,7 +731,7 @@ def log(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Log').apply([inputs]) return op_lib.instantiate(op_type='Log').apply([inputs])
...@@ -766,7 +766,7 @@ def less(inputs, **kwargs): ...@@ -766,7 +766,7 @@ def less(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -802,7 +802,7 @@ def less_equal(inputs, **kwargs): ...@@ -802,7 +802,7 @@ def less_equal(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -855,7 +855,7 @@ def matmul(inputs, transpose_a=False, transpose_b=False, **kwargs): ...@@ -855,7 +855,7 @@ def matmul(inputs, transpose_a=False, transpose_b=False, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.MatMul op_lib = math_ops_lib.MatMul
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
...@@ -886,7 +886,7 @@ def maximum(inputs, **kwargs): ...@@ -886,7 +886,7 @@ def maximum(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -912,7 +912,7 @@ def minimum(inputs, **kwargs): ...@@ -912,7 +912,7 @@ def minimum(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -947,7 +947,7 @@ def mul(inputs, **kwargs): ...@@ -947,7 +947,7 @@ def mul(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -978,7 +978,7 @@ def negative(inputs, **kwargs): ...@@ -978,7 +978,7 @@ def negative(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Neg').apply([inputs]) return op_lib.instantiate(op_type='Neg').apply([inputs])
...@@ -1013,7 +1013,7 @@ def not_equal(inputs, **kwargs): ...@@ -1013,7 +1013,7 @@ def not_equal(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -1048,7 +1048,7 @@ def pow(inputs, **kwargs): ...@@ -1048,7 +1048,7 @@ def pow(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -1081,7 +1081,7 @@ def reciprocal(inputs, **kwargs): ...@@ -1081,7 +1081,7 @@ def reciprocal(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Reciprocal').apply([inputs]) return op_lib.instantiate(op_type='Reciprocal').apply([inputs])
...@@ -1113,7 +1113,7 @@ def round(inputs, **kwargs): ...@@ -1113,7 +1113,7 @@ def round(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Round').apply([inputs]) return op_lib.instantiate(op_type='Round').apply([inputs])
...@@ -1145,7 +1145,7 @@ def rsqrt(inputs, **kwargs): ...@@ -1145,7 +1145,7 @@ def rsqrt(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Rsqrt').apply([inputs]) return op_lib.instantiate(op_type='Rsqrt').apply([inputs])
...@@ -1183,7 +1183,7 @@ def sign(inputs, **kwargs): ...@@ -1183,7 +1183,7 @@ def sign(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Sign').apply([inputs]) return op_lib.instantiate(op_type='Sign').apply([inputs])
...@@ -1215,7 +1215,7 @@ def sin(inputs, **kwargs): ...@@ -1215,7 +1215,7 @@ def sin(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Sin').apply([inputs]) return op_lib.instantiate(op_type='Sin').apply([inputs])
...@@ -1247,7 +1247,7 @@ def sqrt(inputs, **kwargs): ...@@ -1247,7 +1247,7 @@ def sqrt(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Sqrt').apply([inputs]) return op_lib.instantiate(op_type='Sqrt').apply([inputs])
...@@ -1279,7 +1279,7 @@ def square(inputs, **kwargs): ...@@ -1279,7 +1279,7 @@ def square(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = math_ops_lib.UnaryOp op_lib = math_ops_lib.UnaryOp
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate(op_type='Square').apply([inputs]) return op_lib.instantiate(op_type='Square').apply([inputs])
...@@ -1313,7 +1313,7 @@ def sub(inputs, **kwargs): ...@@ -1313,7 +1313,7 @@ def sub(inputs, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
inputs = ops.remove_binary_scalar(inputs) inputs = ops.remove_binary_scalar(inputs)
op_lib = math_ops_lib.BinaryOp op_lib = math_ops_lib.BinaryOp
if context.executing_eagerly(): if context.executing_eagerly():
......
...@@ -16,8 +16,8 @@ from __future__ import print_function ...@@ -16,8 +16,8 @@ from __future__ import print_function
from dragon.core.eager import context from dragon.core.eager import context
from dragon.core.ops import metric_ops_lib from dragon.core.ops import metric_ops_lib
from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
@OpSchema.num_inputs(2) @OpSchema.num_inputs(2)
...@@ -41,7 +41,7 @@ def accuracy(inputs, axis=-1, top_k=1, ignore_index=None, **kwargs): ...@@ -41,7 +41,7 @@ def accuracy(inputs, axis=-1, top_k=1, ignore_index=None, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = metric_ops_lib.Accuracy op_lib = metric_ops_lib.Accuracy
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
......
...@@ -17,8 +17,8 @@ from __future__ import print_function ...@@ -17,8 +17,8 @@ from __future__ import print_function
from dragon.core import distributed from dragon.core import distributed
from dragon.core.eager import context from dragon.core.eager import context
from dragon.core.ops import normalization_ops_lib from dragon.core.ops import normalization_ops_lib
from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
from dragon.core.util import nest from dragon.core.util import nest
...@@ -61,7 +61,7 @@ def batch_norm( ...@@ -61,7 +61,7 @@ def batch_norm(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['momentum'], args['epsilon'] = float(momentum), float(epsilon) args['momentum'], args['epsilon'] = float(momentum), float(epsilon)
op_lib = normalization_ops_lib.BatchNorm op_lib = normalization_ops_lib.BatchNorm
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -105,7 +105,7 @@ def group_norm(inputs, axis=-1, group=32, epsilon=1e-5, **kwargs): ...@@ -105,7 +105,7 @@ def group_norm(inputs, axis=-1, group=32, epsilon=1e-5, **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['epsilon'] = float(epsilon) args['epsilon'] = float(epsilon)
op_lib = normalization_ops_lib.GroupNorm op_lib = normalization_ops_lib.GroupNorm
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -217,7 +217,7 @@ def lp_normalize(inputs, axis=None, p=2, epsilon=1e-12, reduction='sum', **kwarg ...@@ -217,7 +217,7 @@ def lp_normalize(inputs, axis=None, p=2, epsilon=1e-12, reduction='sum', **kwarg
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if axis is None: if axis is None:
args['axis'], args['num_axes'] = 0, -1 args['axis'], args['num_axes'] = 0, -1
else: else:
...@@ -284,7 +284,7 @@ def local_response_norm( ...@@ -284,7 +284,7 @@ def local_response_norm(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: %s' % data_format) raise ValueError('Unsupported data format: %s' % data_format)
args['alpha'], args['beta'], args['bias'] = \ args['alpha'], args['beta'], args['bias'] = \
...@@ -345,7 +345,7 @@ def sync_batch_norm( ...@@ -345,7 +345,7 @@ def sync_batch_norm(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['momentum'], args['epsilon'] = float(momentum), float(epsilon) args['momentum'], args['epsilon'] = float(momentum), float(epsilon)
if process_group is None: if process_group is None:
process_group = distributed.get_group() process_group = distributed.get_group()
......
...@@ -22,8 +22,8 @@ from dragon.core.eager import context ...@@ -22,8 +22,8 @@ from dragon.core.eager import context
from dragon.core.eager.tensor import EagerTensor from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import config from dragon.core.framework import config
from dragon.core.ops import rnn_ops_lib from dragon.core.ops import rnn_ops_lib
from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
from dragon.core.util import nest from dragon.core.util import nest
...@@ -363,7 +363,7 @@ def LSTMCell(inputs, **kwargs): ...@@ -363,7 +363,7 @@ def LSTMCell(inputs, **kwargs):
The **h** and **c**. The **h** and **c**.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
op_lib = rnn_ops_lib.LSTMCell op_lib = rnn_ops_lib.LSTMCell
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib.instantiate().apply(inputs) return op_lib.instantiate().apply(inputs)
......
...@@ -34,8 +34,7 @@ class OpSchema(object): ...@@ -34,8 +34,7 @@ class OpSchema(object):
raise ValueError( raise ValueError(
'The number of <inputs> is {}, ' 'The number of <inputs> is {}, '
'not in range: [min={}, max={}].' 'not in range: [min={}, max={}].'
.format(len(inputs), min_num, max_num) .format(len(inputs), min_num, max_num))
)
def decorated(inner_function): def decorated(inner_function):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
...@@ -49,10 +48,11 @@ class OpSchema(object): ...@@ -49,10 +48,11 @@ class OpSchema(object):
class ArgHelper(object): class ArgHelper(object):
"""Generate the descriptor for dynamic arguments.""" """Generate and parse the descriptor for tensor arguments."""
@classmethod @staticmethod
def desc(cls, name, as_target=True): def desc(name, as_target=True):
"""Add desc for a single argument."""
def decorated(inner_function): def decorated(inner_function):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
def generator(arguments): def generator(arguments):
...@@ -60,23 +60,32 @@ class ArgHelper(object): ...@@ -60,23 +60,32 @@ class ArgHelper(object):
if arg is None: if arg is None:
return arguments return arguments
if types.is_tensor(arg): if types.is_tensor(arg):
if context.executing_eagerly(): ArgHelper._convert_to_desc(arguments, name, arg, as_target)
arguments[name] = arg.get_value().tolist()
return arguments
if as_target:
if 'extra_inputs' not in arguments:
arguments['extra_inputs'] = []
arguments['extra_inputs'].extend([arg])
arguments.pop(name)
arguments[name + '_desc'] = arg.id
return arguments return arguments
kwargs.update({'gen_desc_{}'.format(name): generator}) kwargs.update({'gen_desc_{}'.format(name): generator})
return inner_function(*args, **kwargs) return inner_function(*args, **kwargs)
return decorator.make_decorator(inner_function, wrapper) return decorator.make_decorator(inner_function, wrapper)
return decorated return decorated
@classmethod @staticmethod
def repeated_desc(cls, name, name_v2=None, dtype='int64', as_target=True): def parse(locals):
"""Parse all the arguments into a dict."""
__all__ = locals
kwargs = __all__['kwargs']
del __all__['kwargs']
desc_generators = {}
for k, v in kwargs.items():
if 'gen_desc' in k:
desc_generators[k] = v
for k in desc_generators.keys():
kwargs.pop(k)
for v in desc_generators.values():
__all__ = v(__all__)
return dict(__all__, **kwargs)
@staticmethod
def repeated_desc(name, name_v2=None, dtype='int64', as_target=True):
"""Add desc for a repeated argument."""
def decorated(inner_function): def decorated(inner_function):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
def generator(arguments): def generator(arguments):
...@@ -87,61 +96,50 @@ class ArgHelper(object): ...@@ -87,61 +96,50 @@ class ArgHelper(object):
if name_v2: if name_v2:
arguments.pop(name) arguments.pop(name)
if types.is_tensor(arg): if types.is_tensor(arg):
ArgHelper._convert_to_desc(arguments, key, arg, as_target)
else:
if any([types.is_tensor(ele) for ele in arg]):
ArgHelper._convert_to_descs(arguments, dtype, key, arg, as_target)
else:
arguments[key] = arg
return arguments
kwargs.update({'gen_desc_{}'.format(name): generator})
return inner_function(*args, **kwargs)
return decorator.make_decorator(inner_function, wrapper)
return decorated
@staticmethod
def _convert_to_desc(arguments, name, arg, as_target=False):
"""Convert the argument to a desc."""
if context.executing_eagerly(): if context.executing_eagerly():
arguments[key] = arg.get_value().tolist() arguments[name] = arg.get_value().tolist()
return arguments return arguments
arguments.pop(key)
arguments[key + '_desc'] = arg.id
if as_target: if as_target:
if 'extra_inputs' not in arguments: if 'extra_inputs' not in arguments:
arguments['extra_inputs'] = [] arguments['extra_inputs'] = []
arguments['extra_inputs'] += [arg] arguments['extra_inputs'] += [arg]
else: arguments.pop(name)
has_tensor = False arguments[name + '_desc'] = arg.id
arg = nest.flatten(arg) return arguments
for e in arg:
if types.is_tensor(e): @staticmethod
has_tensor = True def _convert_to_descs(arguments, dtype, name, arg, as_target=False):
break """Convert the argument to a sequence of descs."""
if has_tensor:
if context.executing_eagerly(): if context.executing_eagerly():
for i, e in enumerate(arg): for i, ele in enumerate(arg):
if types.is_tensor(e): if types.is_tensor(ele):
arg[i] = e.get_value().tolist() arg[i] = ele.get_value().tolist()
arguments[key] = arg arguments[name] = arg
else: else:
descs = [] descs = []
for i, e in enumerate(arg): for i, ele in enumerate(arg):
if types.is_tensor(e): if types.is_tensor(ele):
if as_target: if as_target:
if 'extra_inputs' not in arguments: if 'extra_inputs' not in arguments:
arguments['extra_inputs'] = [] arguments['extra_inputs'] = []
arguments['extra_inputs'] += [e] arguments['extra_inputs'] += [ele]
descs.append(e.id) descs.append(ele.id)
else: else:
descs.append(Tensor.convert_to(e, dtype).id) descs.append(Tensor.from_value(ele, dtype, 'DescConst').id)
arguments.pop(key) arguments.pop(name)
arguments[key + '_descs'] = descs arguments[name + '_descs'] = descs
else:
arguments[key] = arg
return arguments
kwargs.update({'gen_desc_{}'.format(name): generator})
return inner_function(*args, **kwargs)
return decorator.make_decorator(inner_function, wrapper)
return decorated
def parse_args(locals):
"""Parse all the arguments into a dict."""
__all__ = locals
kwargs = __all__['kwargs']
del __all__['kwargs']
desc_generators = {}
for k, v in kwargs.items():
if 'gen_desc' in k:
desc_generators[k] = v
for k in desc_generators.keys():
kwargs.pop(k)
for v in desc_generators.values():
__all__ = v(__all__)
return dict(__all__, **kwargs)
...@@ -18,7 +18,6 @@ from dragon.core.eager import context ...@@ -18,7 +18,6 @@ from dragon.core.eager import context
from dragon.core.ops import vision_ops_lib from dragon.core.ops import vision_ops_lib
from dragon.core.ops.utils import ArgHelper from dragon.core.ops.utils import ArgHelper
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args
from dragon.core.util import nest from dragon.core.util import nest
...@@ -39,7 +38,7 @@ def bias_add(inputs, data_format='NCHW', **kwargs): ...@@ -39,7 +38,7 @@ def bias_add(inputs, data_format='NCHW', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: %s' % data_format) raise ValueError('Unsupported data format: %s' % data_format)
op_lib = vision_ops_lib.BiasAdd op_lib = vision_ops_lib.BiasAdd
...@@ -92,7 +91,7 @@ def conv2d( ...@@ -92,7 +91,7 @@ def conv2d(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if padding not in ('VALID', 'SAME', 'SAME_UPPER', 'SAME_LOWER'): if padding not in ('VALID', 'SAME', 'SAME_UPPER', 'SAME_LOWER'):
raise ValueError('Unsupported padding algorithm: %s' % padding) raise ValueError('Unsupported padding algorithm: %s' % padding)
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
...@@ -172,7 +171,7 @@ def conv2d_transpose( ...@@ -172,7 +171,7 @@ def conv2d_transpose(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if padding not in ('VALID', 'SAME', 'SAME_UPPER', 'SAME_LOWER'): if padding not in ('VALID', 'SAME', 'SAME_UPPER', 'SAME_LOWER'):
raise ValueError('Unsupported padding algorithm: %s' % padding) raise ValueError('Unsupported padding algorithm: %s' % padding)
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
...@@ -246,7 +245,7 @@ def depthwise_conv2d( ...@@ -246,7 +245,7 @@ def depthwise_conv2d(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if padding not in ('VALID', 'SAME', 'SAME_UPPER', 'SAME_LOWER'): if padding not in ('VALID', 'SAME', 'SAME_UPPER', 'SAME_LOWER'):
raise ValueError('Unsupported padding algorithm: %s' % padding) raise ValueError('Unsupported padding algorithm: %s' % padding)
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
...@@ -306,7 +305,7 @@ def depth_to_space(inputs, block_size, data_format='NCHW', **kwargs): ...@@ -306,7 +305,7 @@ def depth_to_space(inputs, block_size, data_format='NCHW', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: %s' % data_format) raise ValueError('Unsupported data format: %s' % data_format)
op_lib = vision_ops_lib.DepthToSpace op_lib = vision_ops_lib.DepthToSpace
...@@ -366,7 +365,7 @@ def pool2d( ...@@ -366,7 +365,7 @@ def pool2d(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['mode'] = mode.upper() args['mode'] = mode.upper()
if args['mode'] not in ('MAX', 'AVG'): if args['mode'] not in ('MAX', 'AVG'):
raise ValueError('Unsupported pooling mode: %s' % mode) raise ValueError('Unsupported pooling mode: %s' % mode)
...@@ -453,7 +452,7 @@ def resize( ...@@ -453,7 +452,7 @@ def resize(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['mode'] = mode.upper() args['mode'] = mode.upper()
if sizes is None and scales is None: if sizes is None and scales is None:
raise ValueError('Specify either <sizes> or <scales>.') raise ValueError('Specify either <sizes> or <scales>.')
...@@ -513,7 +512,7 @@ def roi_align( ...@@ -513,7 +512,7 @@ def roi_align(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['spatial_scale'] = float(spatial_scale) args['spatial_scale'] = float(spatial_scale)
op_lib = vision_ops_lib.RoiAlign op_lib = vision_ops_lib.RoiAlign
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -561,7 +560,7 @@ def roi_pool( ...@@ -561,7 +560,7 @@ def roi_pool(
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
args['spatial_scale'] = float(spatial_scale) args['spatial_scale'] = float(spatial_scale)
op_lib = vision_ops_lib.RoiPool op_lib = vision_ops_lib.RoiPool
if context.executing_eagerly(): if context.executing_eagerly():
...@@ -605,7 +604,7 @@ def space_to_depth(inputs, block_size, data_format='NCHW', **kwargs): ...@@ -605,7 +604,7 @@ def space_to_depth(inputs, block_size, data_format='NCHW', **kwargs):
The output tensor. The output tensor.
""" """
args = parse_args(locals()) args = ArgHelper.parse(locals())
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: %s' % data_format) raise ValueError('Unsupported data format: %s' % data_format)
op_lib = vision_ops_lib.SpaceToDepth op_lib = vision_ops_lib.SpaceToDepth
......
...@@ -180,7 +180,7 @@ def roi_align(op_def, context): ...@@ -180,7 +180,7 @@ def roi_align(op_def, context):
node, const_tensors = export_util.translate(**locals()) node, const_tensors = export_util.translate(**locals())
# Make a dummy "batch_indices". # Make a dummy "batch_indices".
batch_indices = helper.from_array( batch_indices = helper.from_array(
numpy.array([1], 'int64'), numpy.zeros((context.blob_shapes[node.input[1]][0],), 'int64'),
context.unique_name(op_def.input[0] + '/roi_align/batch_indices'), context.unique_name(op_def.input[0] + '/roi_align/batch_indices'),
) )
node.input.extend([batch_indices.name]) node.input.extend([batch_indices.name])
......
...@@ -33,6 +33,7 @@ from dragon.vm.tensorflow._api.keras import optimizers ...@@ -33,6 +33,7 @@ from dragon.vm.tensorflow._api.keras import optimizers
from dragon.vm.tensorflow.core.eager.backprop import GradientTape from dragon.vm.tensorflow.core.eager.backprop import GradientTape
from dragon.vm.tensorflow.core.framework.tensor_shape import TensorShape from dragon.vm.tensorflow.core.framework.tensor_shape import TensorShape
from dragon.vm.tensorflow.core.framework.tensor_spec import TensorSpec from dragon.vm.tensorflow.core.framework.tensor_spec import TensorSpec
from dragon.vm.tensorflow.core.module.module import Module
# Functions # Functions
from dragon.vm.tensorflow.core.eager.def_function import function from dragon.vm.tensorflow.core.eager.def_function import function
......
...@@ -13,13 +13,7 @@ from __future__ import absolute_import ...@@ -13,13 +13,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy from dragon.core.ops import init_ops
from dragon.core.autograph.tensor import TensorRef
from dragon.core.eager import context as eager_context
from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import context
from dragon.core.framework import workspace
def constant(value, dtype=None, shape=None, name='Const'): def constant(value, dtype=None, shape=None, name='Const'):
...@@ -51,39 +45,4 @@ def constant(value, dtype=None, shape=None, name='Const'): ...@@ -51,39 +45,4 @@ def constant(value, dtype=None, shape=None, name='Const'):
""" """
dtype = str(dtype) if dtype else None dtype = str(dtype) if dtype else None
if dtype is not None: return init_ops.constant(value, dtype, shape, name=name)
if isinstance(value, numpy.ndarray):
value = value.astype(dtype)
else:
value = numpy.array(value, dtype)
else:
if not isinstance(value, numpy.ndarray):
value = numpy.array(value)
# Discard the default 64bit types.
if value.dtype == numpy.float64:
value = value.astype(numpy.float32)
elif value.dtype == numpy.int64:
value = value.astype(numpy.int32)
# Determine the shape.
if shape is not None:
if value.size == 1:
# Case 1: Broadcast with scalar value.
scalar = value.flatten()[0]
value = numpy.empty(shape, value.dtype)
value.fill(scalar)
else:
# Case 2: Reshape directly.
value = value.reshape(shape)
# Return a named tensor with value copied.
name = context.get_name_scope() + name
if eager_context.executing_eagerly():
return EagerTensor(value, name=name + ':0')
else:
return TensorRef(
name=workspace.get_workspace().unique_name(
name, ':0', 'Tensor'),
shape=list(value.shape),
dtype=str(value.dtype),
).set_value(value)
...@@ -23,13 +23,20 @@ from dragon.vm.tensorflow.core.framework import constant_op ...@@ -23,13 +23,20 @@ from dragon.vm.tensorflow.core.framework import constant_op
def convert_to_tensor(value, dtype=None, name=None): def convert_to_tensor(value, dtype=None, name=None):
"""Converts the given value to a Tensor. """Convert the given value to a tensor.
Examples:
```python
x = tf.convert_to_tensor([1, 2])
y = tf.constant([1, 2]) # Equivalent
```
Parameters Parameters
---------- ----------
value : Union[number, Sequence, numpy.ndarray] value : Union[number, Sequence, numpy.ndarray]
The value to convert. The value to convert.
dtype : dragon.vm.tensorflow.dtypes.DType, optional dtype : str, optional
The optional data type. The optional data type.
name : str, optional name : str, optional
The Optional name. The Optional name.
...@@ -39,6 +46,10 @@ def convert_to_tensor(value, dtype=None, name=None): ...@@ -39,6 +46,10 @@ def convert_to_tensor(value, dtype=None, name=None):
dragon.Tensor dragon.Tensor
The output tensor. The output tensor.
See Also
--------
`tf.constant(...)`_
""" """
if types.is_tensor(value): if types.is_tensor(value):
return value return value
......
...@@ -26,7 +26,7 @@ from dragon.vm.tensorflow.core.ops import variables ...@@ -26,7 +26,7 @@ from dragon.vm.tensorflow.core.ops import variables
class Module(object): class Module(object):
"""The base class of modules. """The base class of neural network modules.
Inherit this class to design a new module: Inherit this class to design a new module:
...@@ -222,19 +222,22 @@ def flatten_module( ...@@ -222,19 +222,22 @@ def flatten_module(
def valid_identifier(name): def valid_identifier(name):
"""Return **True** if the name can be a python identifier.""" """Return if the name can be a python identifier."""
return bool(_VALID_IDENTIFIER.match(name)) return bool(_VALID_IDENTIFIER.match(name))
def _is_module(obj): def _is_module(obj):
"""Return if the object is a instance of module."""
return isinstance(obj, Module) return isinstance(obj, Module)
def _is_variable(obj): def _is_variable(obj):
"""Return if the object is a variable."""
return isinstance(obj, variables.VariableMetaclass) return isinstance(obj, variables.VariableMetaclass)
def _is_trainable_variable(obj): def _is_trainable_variable(obj):
"""Return if the object is a trainable variable."""
return _is_variable(obj) and getattr(obj, "trainable", False) return _is_variable(obj) and getattr(obj, "trainable", False)
......
...@@ -49,19 +49,20 @@ class Variable(VariableMetaclass, EagerTensor): ...@@ -49,19 +49,20 @@ class Variable(VariableMetaclass, EagerTensor):
dtype = str(dtype) if dtype else None dtype = str(dtype) if dtype else None
self._name = context.get_name_scope() + name + ':0' self._name = context.get_name_scope() + name + ':0'
# Determine th value. # Determine th value.
if isinstance(initial_value, EagerTensor): if isinstance(initial_value, numpy.ndarray):
initial_value = initial_value.numpy() if dtype is None or initial_value.dtype == dtype:
initial_value = initial_value.copy()
elif isinstance(initial_value, EagerTensor):
initial_value = initial_value.get_value()
if dtype is None or initial_value.dtype == dtype:
initial_value = initial_value.copy()
elif isinstance(initial_value, Tensor): elif isinstance(initial_value, Tensor):
initial_value = initial_value.get_value() initial_value = initial_value.get_value()
# Determine the data type. # Determine the data type and shape.
if not isinstance(initial_value, numpy.ndarray): initial_value = numpy.array(initial_value, dtype, copy=False)
initial_value = numpy.array(initial_value, dtype)
elif dtype is not None:
initial_value = initial_value.astype(dtype)
# Determine the tensor shape.
if shape is not None: if shape is not None:
initial_value = initial_value.reshape(shape) initial_value = initial_value.reshape(shape)
self._from_numpy(initial_value, copy=False) self._from_array(initial_value)
@property @property
def trainable(self): def trainable(self):
......
...@@ -58,17 +58,17 @@ class TestFunction(unittest.TestCase): ...@@ -58,17 +58,17 @@ class TestFunction(unittest.TestCase):
"""Test the graph function.""" """Test the graph function."""
@dragon.function(input_signature=[ @dragon.function(input_signature=[
dragon.Tensor(dtype='int32'), dragon.Tensor((1,), dtype='int32'),
dragon.Tensor(dtype='int32'), dragon.Tensor((1,), dtype='int32'),
dragon.Tensor(dtype='int32'), dragon.Tensor((1,), dtype='int32'),
]) ])
def func1(self, a, b, c=0, **kwargs): def func1(self, a, b, c=0, **kwargs):
_ = kwargs _ = kwargs
return a + b + c return a + b + c
def test_create_function(self): def test_create_function(self):
a = dragon.Tensor(dtype='int32').set_value(1) a = dragon.Tensor((), dtype='int32').set_value(1)
b = dragon.Tensor(dtype='int32').set_value(2) b = dragon.Tensor((), dtype='int32').set_value(2)
y = a + 1 y = a + 1
try: try:
dragon.create_function(outputs=y, optimizer=dragon.optimizers.SGD()) dragon.create_function(outputs=y, optimizer=dragon.optimizers.SGD())
...@@ -85,7 +85,7 @@ class TestFunction(unittest.TestCase): ...@@ -85,7 +85,7 @@ class TestFunction(unittest.TestCase):
self.assertEqual(int(f()), 3) self.assertEqual(int(f()), 3)
def test_def_function(self): def test_def_function(self):
@dragon.function(input_signature=[dragon.Tensor()]) @dragon.function(input_signature=[dragon.Tensor(None)])
def func2(a, b): def func2(a, b):
return a + b return a + b
self.assertEqual(self.func1([1, 2], [3, 4]).get_value().tolist(), [4, 6]) self.assertEqual(self.func1([1, 2], [3, 4]).get_value().tolist(), [4, 6])
...@@ -109,8 +109,8 @@ class TestFunction(unittest.TestCase): ...@@ -109,8 +109,8 @@ class TestFunction(unittest.TestCase):
_ = optimizer.op_type _ = optimizer.op_type
except KeyError: except KeyError:
pass pass
value = dragon.Tensor(dtype='float32').set_value(1.) value = dragon.Tensor((), dtype='float32').set_value(1.)
grad = dragon.Tensor(dtype='float32').set_value(1.) grad = dragon.Tensor((), dtype='float32').set_value(1.)
optimizer.apply_gradients([(value, grad)]) optimizer.apply_gradients([(value, grad)])
dragon.create_function(optimizer=optimizer)() dragon.create_function(optimizer=optimizer)()
......
...@@ -72,17 +72,17 @@ class TestTensor(unittest.TestCase): ...@@ -72,17 +72,17 @@ class TestTensor(unittest.TestCase):
"""Test the tensor class.""" """Test the tensor class."""
def test_properties(self): def test_properties(self):
a, b = dragon.Tensor(), dragon.EagerTensor(0) a, b = dragon.Tensor(()), dragon.EagerTensor(0)
self.assertEqual(dragon.Tensor().ndim, 0) self.assertEqual(dragon.Tensor(()).ndim, 0)
self.assertEqual(dragon.Tensor(shape=(2,)).ndim, 1) self.assertEqual(dragon.Tensor(shape=(2,)).ndim, 1)
self.assertEqual(dragon.Tensor().shape, None) self.assertEqual(dragon.Tensor(None).shape, None)
self.assertEqual(dragon.Tensor(shape=(2,)).shape, (2,)) self.assertEqual(dragon.Tensor(shape=(2,)).shape, (2,))
self.assertEqual(dragon.Tensor().size, 0) self.assertEqual(dragon.Tensor(None).size, 0)
self.assertEqual(dragon.Tensor(()).size, 1) self.assertEqual(dragon.Tensor(()).size, 1)
self.assertEqual(dragon.Tensor(shape=(2, None)).size, math.inf) self.assertEqual(dragon.Tensor(shape=(2, None)).size, math.inf)
self.assertEqual(dragon.Tensor(shape=(2,)).size, 2) self.assertEqual(dragon.Tensor(shape=(2,)).size, 2)
self.assertEqual(dragon.Tensor().dtype, None) self.assertEqual(dragon.Tensor(None, None).dtype, None)
self.assertEqual(dragon.Tensor(dtype='float32').dtype, 'float32') self.assertEqual(dragon.Tensor(None, dtype='float32').dtype, 'float32')
self.assertEqual(dragon.EagerTensor(shape=(2,)).ndim, 1) self.assertEqual(dragon.EagerTensor(shape=(2,)).ndim, 1)
self.assertEqual(dragon.EagerTensor(shape=(2,)).shape, (2,)) self.assertEqual(dragon.EagerTensor(shape=(2,)).shape, (2,))
self.assertEqual(dragon.EagerTensor(shape=(2,)).size, 2) self.assertEqual(dragon.EagerTensor(shape=(2,)).size, 2)
...@@ -92,7 +92,8 @@ class TestTensor(unittest.TestCase): ...@@ -92,7 +92,8 @@ class TestTensor(unittest.TestCase):
self.assertNotEqual(a.__repr__(), b.__repr__()) self.assertNotEqual(a.__repr__(), b.__repr__())
self.assertNotEqual(b.__repr__(), dragon.EagerTensor((2,)).__repr__()) self.assertNotEqual(b.__repr__(), dragon.EagerTensor((2,)).__repr__())
self.assertEqual(int(a.constant().set_value(1)), 1) self.assertEqual(int(a.constant().set_value(1)), 1)
self.assertEqual(float(dragon.Tensor.convert_to(1)), 1.) self.assertEqual(float(dragon.Tensor.from_value(1)), 1.)
self.assertEqual(float(dragon.EagerTensor.from_value(1)), 1.)
self.assertEqual(int(b.set_value(1)), 1) self.assertEqual(int(b.set_value(1)), 1)
self.assertEqual(float(b), 1.) self.assertEqual(float(b), 1.)
self.assertEqual(int(b.get_value()), 1) self.assertEqual(int(b.get_value()), 1)
...@@ -160,7 +161,7 @@ class TestWorkspace(unittest.TestCase): ...@@ -160,7 +161,7 @@ class TestWorkspace(unittest.TestCase):
w = dragon.Workspace() w = dragon.Workspace()
with w.as_default(): with w.as_default():
v1, v2 = dragon.EagerTensor(1), np.array(2) v1, v2 = dragon.EagerTensor(1), np.array(2)
x = dragon.Tensor(name='test_feed_tensor/x') x = dragon.Tensor((), name='test_feed_tensor/x')
w.feed_tensor(x, v1) w.feed_tensor(x, v1)
self.assertEqual(int(x), 1) self.assertEqual(int(x), 1)
w.feed_tensor(x, v2) w.feed_tensor(x, v2)
...@@ -169,7 +170,7 @@ class TestWorkspace(unittest.TestCase): ...@@ -169,7 +170,7 @@ class TestWorkspace(unittest.TestCase):
def test_merge_form(self): def test_merge_form(self):
w1, w2 = dragon.Workspace(), dragon.Workspace() w1, w2 = dragon.Workspace(), dragon.Workspace()
with w1.as_default(): with w1.as_default():
x = dragon.Tensor(name='test_merge_from/x').set_value(0) x = dragon.Tensor((), name='test_merge_from/x').set_value(0)
w2.merge_from(w1) w2.merge_from(w1)
with w2.as_default(): with w2.as_default():
self.assertEqual(int(x), 0) self.assertEqual(int(x), 0)
......
...@@ -77,9 +77,11 @@ class Tensor(object): ...@@ -77,9 +77,11 @@ class Tensor(object):
if len(args) == 1: if len(args) == 1:
if isinstance(args[0], (list, tuple)): if isinstance(args[0], (list, tuple)):
dtype = kwargs.get('dtype', 'float32') dtype = kwargs.get('dtype', 'float32')
self._from_numpy(numpy.array(args[0], dtype=dtype), copy=False) self._from_array(numpy.array(args[0], dtype))
elif isinstance(args[0], numpy.ndarray): elif isinstance(args[0], numpy.ndarray):
self._from_numpy(args[0], copy=kwargs.get('copy', True)) dtype = kwargs.get('dtype', None)
self._from_array(numpy.array(
args[0], dtype, copy=kwargs.get('copy', True)))
else: else:
if not isinstance(args[0], six.integer_types): if not isinstance(args[0], six.integer_types):
raise ValueError('Excepted an integer as size.') raise ValueError('Excepted an integer as size.')
...@@ -2304,16 +2306,15 @@ class Tensor(object): ...@@ -2304,16 +2306,15 @@ class Tensor(object):
""" """
return self.fill_(0) return self.fill_(0)
def _from_numpy(self, array, copy): def _from_array(self, array):
"""Create impl from the numpy array.""" """Create implementation from the array."""
ws = workspace.get_workspace() ws = workspace.get_workspace()
array = array.copy() if copy else array
self._gc, self._is_leaf = ws.collectors.TENSOR, True self._gc, self._is_leaf = ws.collectors.TENSOR, True
self._impl = ws.create_tensor(self._gc.alloc( self._impl = ws.create_tensor(self._gc.alloc(
context.get_eager_scope())).FromNumpy(array) context.get_eager_scope())).FromNumpy(array)
def _from_shape(self, shape, dtype): def _from_shape(self, shape, dtype):
"""Create impl from the shape and data type.""" """Create implementation from the shape."""
ws = workspace.get_workspace() ws = workspace.get_workspace()
self._gc, self._is_leaf = ws.collectors.TENSOR, True self._gc, self._is_leaf = ws.collectors.TENSOR, True
self._impl = ws.create_tensor(self._gc.alloc( self._impl = ws.create_tensor(self._gc.alloc(
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!