Commit aa2ec8c3 by Ting PAN

Make the shape and data type consistent

Summary:
This commit enforces that the shape and data type are inherited from the same metaclass,
which keeps them consistent across the different tensor styles.
1 parent 0ab14f30
Showing with 782 additions and 627 deletions
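A minimal sketch of the unified interface implied by the diff below (the printed values are illustrative assumptions, not output captured from the library): both tensor styles now take ``(shape, dtype, name)`` in that order and report the shape as a tuple.
```python
import dragon

# Illustrative only: the constructor order and the tuple-typed shape follow
# the Tensor changes in this diff; the printed values are assumptions.
x = dragon.Tensor(shape=[2, 3], dtype='float32', name='x')
print(x.shape)  # (2, 3) -- a tuple rather than a list
print(x.dtype)  # float32
```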
...@@ -15,9 +15,6 @@ dragon ...@@ -15,9 +15,6 @@ dragon
`class Tensor <dragon/Tensor.html>`_ `class Tensor <dragon/Tensor.html>`_
: Tensor abstraction for graph executing. : Tensor abstraction for graph executing.
`class TensorSpec <dragon/TensorSpec.html>`_
: Spec to describe properties of a tensor.
`class Workspace <dragon/Workspace.html>`_ `class Workspace <dragon/Workspace.html>`_
: Sandbox to isolate the resources and computations. : Sandbox to isolate the resources and computations.
...@@ -213,7 +210,6 @@ dragon ...@@ -213,7 +210,6 @@ dragon
dragon/stack dragon/stack
dragon/stop_gradient dragon/stop_gradient
dragon/Tensor dragon/Tensor
dragon/TensorSpec
dragon/tile dragon/tile
dragon/transpose dragon/transpose
dragon/where dragon/where
......
TensorSpec
==========
.. autoclass:: dragon.TensorSpec
__init__
--------
.. automethod:: dragon.TensorSpec.__init__
Properties
----------
dtype
#####
.. autoattribute:: dragon.TensorSpec.dtype
name
####
.. autoattribute:: dragon.TensorSpec.name
shape
#####
.. autoattribute:: dragon.TensorSpec.shape
Methods
-------
is_compatible_with
##################
.. automethod:: dragon.TensorSpec.is_compatible_with
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
...@@ -9,6 +9,9 @@ vm.tensorflow ...@@ -9,6 +9,9 @@ vm.tensorflow
`class GradientTape <tensorflow/GradientTape.html>`_ `class GradientTape <tensorflow/GradientTape.html>`_
: Record the operations for auto differentiation. : Record the operations for auto differentiation.
`class TensorShape <tensorflow/TensorShape.html>`_
: Represent a sequence of dimensions.
`class TensorSpec <tensorflow/TensorSpec.html>`_ `class TensorSpec <tensorflow/TensorSpec.html>`_
: Spec to describe properties of a tensor. : Spec to describe properties of a tensor.
...@@ -124,6 +127,7 @@ vm.tensorflow ...@@ -124,6 +127,7 @@ vm.tensorflow
tensorflow/slice tensorflow/slice
tensorflow/split tensorflow/split
tensorflow/squeeze tensorflow/squeeze
tensorflow/TensorShape
tensorflow/TensorSpec tensorflow/TensorSpec
tensorflow/transpose tensorflow/transpose
tensorflow/zeros tensorflow/zeros
......
TensorShape
===========
.. autoclass:: dragon.vm.tensorflow.TensorShape
__init__
--------
.. automethod:: dragon.vm.tensorflow.TensorShape.__init__
Properties
----------
dims
####
.. autoattribute:: dragon.vm.tensorflow.TensorShape.dims
ndims
#####
.. autoattribute:: dragon.vm.tensorflow.TensorShape.ndims
rank
####
.. autoattribute:: dragon.vm.tensorflow.TensorShape.rank
Methods
-------
as_list
#######
.. automethod:: dragon.vm.tensorflow.TensorShape.as_list
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
...@@ -26,7 +26,7 @@ shape ...@@ -26,7 +26,7 @@ shape
<style> <style>
h1:before { h1:before {
content: "dragon."; content: "tf.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -7,6 +7,76 @@ __init__ ...@@ -7,6 +7,76 @@ __init__
-------- --------
.. automethod:: dragon.vm.tensorflow.dtypes.DType.__init__ .. automethod:: dragon.vm.tensorflow.dtypes.DType.__init__
Properties
----------
as_numpy_dtype
##############
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.as_numpy_dtype
as_datatype_enum
################
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.as_datatype_enum
base_dtype
##########
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.base_dtype
is_numpy_compatible
###################
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.is_numpy_compatible
is_bool
#######
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.is_bool
is_complex
##########
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.is_complex
is_floating
###########
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.is_floating
is_integer
##########
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.is_integer
is_quantized
############
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.is_quantized
is_unsigned
###########
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.is_unsigned
limits
######
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.limits
max
###
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.max
min
###
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.min
name
####
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.name
real_dtype
##########
.. autoattribute:: dragon.vm.tensorflow.dtypes.DType.real_dtype
Methods
-------
is_compatible_with
##################
.. automethod:: dragon.vm.tensorflow.dtypes.DType.is_compatible_with
.. raw:: html .. raw:: html
<style> <style>
......
...@@ -7,7 +7,10 @@ vm.torch ...@@ -7,7 +7,10 @@ vm.torch
####### #######
`class device <torch/device.html>`_ `class device <torch/device.html>`_
: Represent the device where tensor will be allocated. : Represent the device spec.
`class dtype <torch/dtype.html>`_
: The basic data type.
`class enable_grad <torch/enable_grad.html>`_ `class enable_grad <torch/enable_grad.html>`_
: Context-manager to enable gradient calculation. : Context-manager to enable gradient calculation.
...@@ -265,6 +268,7 @@ vm.torch ...@@ -265,6 +268,7 @@ vm.torch
torch/cumsum torch/cumsum
torch/device torch/device
torch/div torch/div
torch/dtype
torch/empty torch/empty
torch/enable_grad torch/enable_grad
torch/eq torch/eq
......
dtype
=====
.. autoclass:: dragon.vm.torch.dtype
__init__
--------
.. automethod:: dragon.vm.torch.dtype.__init__
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
...@@ -21,7 +21,7 @@ void ArangeOp<Context>::DoRunWithType() { ...@@ -21,7 +21,7 @@ void ArangeOp<Context>::DoRunWithType() {
// Determine the generating range // Determine the generating range
// Values are in a half-open interval: [start, stop) // Values are in a half-open interval: [start, stop)
auto count = (int64_t)std::round((stop - start) / step); auto count = (int64_t)std::ceil((stop - start) / step);
CHECK_GT(count, 0) << "\nInvalid generating range: " CHECK_GT(count, 0) << "\nInvalid generating range: "
<< "[" << start << ", " << stop << ") with step = " << step << "[" << start << ", " << stop << ") with step = " << step
<< "."; << ".";
......
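The kernel above now uses std::ceil instead of std::round, so the element count matches the half-open interval [start, stop); the Python shape spec later in this diff makes the same change. A small sketch of the difference (the helper name is made up for illustration):
```python
import math

def arange_count(start, stop, step):
    """Element count of the half-open range [start, stop) with a given step."""
    return int(math.ceil((stop - start) / step))

# (1.6 - 0.0) / 0.5 is ~3.2: round() would give 3 and drop the trailing 1.5,
# while ceil() gives 4, matching the elements [0.0, 0.5, 1.0, 1.5].
assert arange_count(0.0, 1.6, 0.5) == 4
```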
...@@ -39,7 +39,6 @@ from dragon import vm ...@@ -39,7 +39,6 @@ from dragon import vm
from dragon.core.autograph.tensor import Tensor from dragon.core.autograph.tensor import Tensor
from dragon.core.eager.tensor import EagerTensor from dragon.core.eager.tensor import EagerTensor
from dragon.core.eager.backprop import GradientTape from dragon.core.eager.backprop import GradientTape
from dragon.core.framework.tensor_spec import TensorSpec
from dragon.core.framework.workspace import Workspace from dragon.core.framework.workspace import Workspace
# Function # Function
......
...@@ -254,7 +254,7 @@ class FunctionGuard(object): ...@@ -254,7 +254,7 @@ class FunctionGuard(object):
) )
shape = input_signature[i].shape shape = input_signature[i].shape
dtype = input_signature[i].dtype dtype = input_signature[i].dtype
inputs.append(Tensor(name, shape, dtype).constant()) inputs.append(Tensor(shape, dtype, name).constant())
with context.name_scope('${%d}' % id(self)), eager_context.graph_mode(): with context.name_scope('${%d}' % id(self)), eager_context.graph_mode():
returns = nest.flatten(self._python_function(*inputs)) returns = nest.flatten(self._python_function(*inputs))
outputs, dummies = [], [] outputs, dummies = [], []
...@@ -328,8 +328,8 @@ def function(func=None, input_signature=None): ...@@ -328,8 +328,8 @@ def function(func=None, input_signature=None):
```python ```python
@dragon.function(input_signature=[ @dragon.function(input_signature=[
dragon.TensorSpec(shape=[], dtype='float32'), dragon.Tensor(shape=[], dtype='float32'),
dragon.TensorSpec(shape=[], dtype='float32') dragon.Tensor(shape=[], dtype='float32'),
]) ])
def foo(x, y): def foo(x, y):
return dragon.math.add([x + y, x]) return dragon.math.add([x + y, x])
...@@ -341,8 +341,8 @@ def function(func=None, input_signature=None): ...@@ -341,8 +341,8 @@ def function(func=None, input_signature=None):
---------- ----------
func : callable, optional func : callable, optional
The function to be compiled. The function to be compiled.
input_signature : Sequence[dragon.TensorSpec], optional input_signature : Sequence[dragon.Tensor], optional
The specs to hint the input info. The tensors to hint the input info.
Returns Returns
------- -------
......
...@@ -20,9 +20,10 @@ import os ...@@ -20,9 +20,10 @@ import os
from dragon.core.autograph import grad_maker from dragon.core.autograph import grad_maker
from dragon.core.autograph.op_def import OpDef from dragon.core.autograph.op_def import OpDef
from dragon.core.autograph.op_def import OpInfo from dragon.core.autograph.op_def import OpInfo
from dragon.core.autograph.tensor import Tensor from dragon.core.autograph.tensor import TensorRef
from dragon.core.framework import config from dragon.core.framework import config
from dragon.core.framework import proto_util from dragon.core.framework import proto_util
from dragon.core.framework import types
from dragon.core.framework import workspace from dragon.core.framework import workspace
from dragon.core.proto import dragon_pb2 from dragon.core.proto import dragon_pb2
from dragon.core.util import logging from dragon.core.util import logging
...@@ -107,9 +108,7 @@ def add_update_defs(graph_def, optimizer): ...@@ -107,9 +108,7 @@ def add_update_defs(graph_def, optimizer):
name=OpDef.get_name(), name=OpDef.get_name(),
operation='MEAN', operation='MEAN',
communication='ALLREDUCE', communication='ALLREDUCE',
**process_group.arguments **process_group.arguments))
)
)
graph_def.op.extend(update_defs) graph_def.op.extend(update_defs)
...@@ -147,14 +146,11 @@ class Function(object): ...@@ -147,14 +146,11 @@ class Function(object):
if givens is not None: if givens is not None:
name_dict = {} name_dict = {}
for k, v in givens.items(): for k, v in givens.items():
if isinstance(v, Tensor): if types.is_symbolic_tensor(v):
name_dict[k.id] = v.id name_dict[k.id] = v.id
op_info.merge_from(v) op_info.merge_from(v)
else: else:
raise ValueError( raise ValueError('Excepted a Tensor, got {}.'.format(type(v).__name__))
'Excepted a Tensor, '
'got {}.'.format(type(v).__name__)
)
# Update the original defs. # Update the original defs.
op_info = copy.deepcopy(op_info) op_info = copy.deepcopy(op_info)
for k in op_info._defs.keys(): for k in op_info._defs.keys():
...@@ -257,8 +253,8 @@ class Function(object): ...@@ -257,8 +253,8 @@ class Function(object):
The self. The self.
""" """
self.outputs = [Tensor(name) for name in graph_def.output] self.outputs = [TensorRef(name) for name in graph_def.output]
self.inputs = [Tensor(name).constant() for name in graph_def.input] self.inputs = [TensorRef(name).constant() for name in graph_def.input]
# Fill with all known graph elements. # Fill with all known graph elements.
add_device_option(graph_def) add_device_option(graph_def)
...@@ -293,7 +289,7 @@ def create_function(inputs=None, outputs=None, givens=None, optimizer=None): ...@@ -293,7 +289,7 @@ def create_function(inputs=None, outputs=None, givens=None, optimizer=None):
Tensors that catch any operators can be used to create a graph: Tensors that catch any operators can be used to create a graph:
```python ```python
x = dragon.Tensor('x', dtype='float32').constant() x = dragon.Tensor(dtype='float32').constant()
y = x * 2 y = x * 2
f = dragon.create_function(outputs=y) f = dragon.create_function(outputs=y)
``` ```
...@@ -315,20 +311,20 @@ def create_function(inputs=None, outputs=None, givens=None, optimizer=None): ...@@ -315,20 +311,20 @@ def create_function(inputs=None, outputs=None, givens=None, optimizer=None):
Specify ``givens`` to substitute tensors before creating: Specify ``givens`` to substitute tensors before creating:
```python ```python
x = dragon.Tensor('x', dtype='float32').constant() x = dragon.Tensor(dtype='float32').constant()
y = x * 2 y = x * 2
foo = dragon.create_function(outputs=y) foo = dragon.create_function(outputs=y)
# "bar" takes "x2" as input, and also writes to "y" # "bar" takes "x2" as input, and also writes to "y"
x2 = dragon.Tensor('x2', dtype='float32').constant() x2 = dragon.Tensor(dtype='float32').constant()
bar = dragon.create_function(outputs=y, givens={x: x2}) bar = dragon.create_function(outputs=y, givens={x: x2})
``` ```
Specify ``optimizer`` to make a graph applying parameter updates: Specify ``optimizer`` to make a graph applying parameter updates:
```python ```python
x = dragon.Tensor('x', dtype='float32').set_value(1) x = dragon.Tensor(dtype='float32').set_value(1)
x_grad = dragon.Tensor('x_grad', dtype='float32').set_value(1) x_grad = dragon.Tensor(dtype='float32').set_value(1)
optimizer = dragon.optimizers.SGD(base_lr=0.01) optimizer = dragon.optimizers.SGD(base_lr=0.01)
optimizer.apply_gradients(values_and_grads=[(x, x_grad)]) optimizer.apply_gradients(values_and_grads=[(x, x_grad)])
......
...@@ -41,7 +41,7 @@ def arange_spec(args, inputs, outputs): ...@@ -41,7 +41,7 @@ def arange_spec(args, inputs, outputs):
else: else:
start, stop, step = slice_args start, stop, step = slice_args
try: try:
outputs[0].shape = [int(round((stop - start) / step))] outputs[0].shape = (int(math.ceil((stop - start) / step)),)
except TypeError: except TypeError:
pass pass
return outputs return outputs
...@@ -53,33 +53,40 @@ def arg_reduce_spec(args, inputs, outputs): ...@@ -53,33 +53,40 @@ def arg_reduce_spec(args, inputs, outputs):
axis, top_k = args['axis'], args['top_k'] axis, top_k = args['axis'], args['top_k']
if args['keep_dims']: if args['keep_dims']:
if axis is None: if axis is None:
outputs[0].shape = [1] outputs[0].shape = (1,)
else: else:
try: try:
outputs[0].shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
if axis < len(outputs[0].shape): out_shape[axis] = top_k
outputs[0].shape[axis] = top_k outputs[0].shape = out_shape
except TypeError: except (TypeError, IndexError):
pass pass
else: else:
if axis is None: if axis is None:
if top_k > 1: if top_k > 1:
outputs[0].shape = [top_k] outputs[0].shape = (top_k,)
else: else:
outputs[0].shape = [] outputs[0].shape = ()
else: else:
try: try:
outputs[0].shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
if axis < len(outputs[0].shape): if axis < len(out_shape):
if top_k > 1: if top_k > 1:
outputs[0].shape[axis] = top_k out_shape[axis] = top_k
else: else:
del outputs[0].shape[axis] del out_shape[axis]
except TypeError: outputs[0].shape = out_shape
except (TypeError, IndexError):
pass pass
return outputs return outputs
@register(['Assign', 'MaskedAssign'])
def assign_spec(args, inputs, outputs):
_ = locals()
return outputs
def binary_shape_spec(inputs, outputs): def binary_shape_spec(inputs, outputs):
if inputs[0].shape is None or inputs[1].shape is None: if inputs[0].shape is None or inputs[1].shape is None:
return outputs return outputs
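The spec functions above now stage the output shape on a mutable list copied from the input and assign it back as a whole, rather than mutating ``outputs[0].shape`` in place, and they fall through silently when the shape or axis is unknown. A hypothetical helper mirroring the pattern (not part of the diff):
```python
def _keepdims_reduce_shape(in_shape, axis, top_k):
    """Illustrative mirror of the updated keep_dims branch in arg_reduce_spec."""
    try:
        out_shape = list(in_shape[:])  # copy; never mutate the input shape
        out_shape[axis] = top_k
        return out_shape
    except (TypeError, IndexError):    # unknown shape or invalid axis
        return None

print(_keepdims_reduce_shape((2, 3, 4), axis=1, top_k=1))  # [2, 1, 4]
print(_keepdims_reduce_shape(None, axis=1, top_k=1))       # None
```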
...@@ -158,7 +165,7 @@ def concat_spec(args, inputs, outputs): ...@@ -158,7 +165,7 @@ def concat_spec(args, inputs, outputs):
out_shape = None out_shape = None
for input in inputs: for input in inputs:
if out_shape is None and input.shape is not None: if out_shape is None and input.shape is not None:
out_shape = input.shape[:] out_shape = list(input.shape[:])
try: try:
for i in range(len(out_shape)): for i in range(len(out_shape)):
for input in inputs: for input in inputs:
...@@ -188,7 +195,7 @@ def conv_spec(args, inputs, outputs): ...@@ -188,7 +195,7 @@ def conv_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
out_shape = None out_shape = None
try: try:
out_shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
num_axes = len(out_shape) - 2 num_axes = len(out_shape) - 2
channel_axis = 1 if args['data_format'] == 'NCHW' else -1 channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1 spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
...@@ -212,7 +219,7 @@ def conv_spec(args, inputs, outputs): ...@@ -212,7 +219,7 @@ def conv_spec(args, inputs, outputs):
out_size = None out_size = None
out_shape[i + spatial_axis] = out_size out_shape[i + spatial_axis] = out_size
except (TypeError, IndexError): except (TypeError, IndexError):
pass out_shape = None
outputs[0].shape = out_shape outputs[0].shape = out_shape
return outputs return outputs
...@@ -222,7 +229,7 @@ def conv_transpose_spec(args, inputs, outputs): ...@@ -222,7 +229,7 @@ def conv_transpose_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
out_shape = None out_shape = None
try: try:
out_shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
num_axes = len(out_shape) - 2 num_axes = len(out_shape) - 2
channel_axis = 1 if args['data_format'] == 'NCHW' else -1 channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1 spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
...@@ -251,7 +258,7 @@ def conv_transpose_spec(args, inputs, outputs): ...@@ -251,7 +258,7 @@ def conv_transpose_spec(args, inputs, outputs):
out_size = None out_size = None
out_shape[i + spatial_axis] = out_size out_shape[i + spatial_axis] = out_size
except (TypeError, IndexError): except (TypeError, IndexError):
pass out_shape = None
outputs[0].shape = out_shape outputs[0].shape = out_shape
return outputs return outputs
...@@ -261,7 +268,7 @@ def depth_to_space_spec(args, inputs, outputs): ...@@ -261,7 +268,7 @@ def depth_to_space_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
try: try:
bs = args['block_size'] bs = args['block_size']
out_shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
num_axes = len(out_shape) - 2 num_axes = len(out_shape) - 2
if len(out_shape) < 3: if len(out_shape) < 3:
return outputs return outputs
...@@ -326,7 +333,7 @@ def expand_spec(args, inputs, outputs): ...@@ -326,7 +333,7 @@ def expand_spec(args, inputs, outputs):
if shape is None: if shape is None:
return outputs return outputs
try: try:
in_shape, out_shape = inputs[0].shape[:], shape[:] in_shape, out_shape = list(inputs[0].shape[:]), list(shape[:])
if len(shape) < len(in_shape): if len(shape) < len(in_shape):
num_keep = len(in_shape) - len(shape) num_keep = len(in_shape) - len(shape)
out_shape = in_shape[:num_keep] + out_shape out_shape = in_shape[:num_keep] + out_shape
...@@ -347,7 +354,7 @@ def expand_dims_spec(args, inputs, outputs): ...@@ -347,7 +354,7 @@ def expand_dims_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
axes = [] if args['axes'] is None else args['axes'] axes = [] if args['axes'] is None else args['axes']
try: try:
out_shape = inputs[0].shape[:] + [0] * len(axes) out_shape = list(inputs[0].shape[:]) + [0] * len(axes)
out_rank = len(out_shape) out_rank = len(out_shape)
for axis in axes: for axis in axes:
while axis < 0: while axis < 0:
...@@ -364,7 +371,7 @@ def expand_dims_spec(args, inputs, outputs): ...@@ -364,7 +371,7 @@ def expand_dims_spec(args, inputs, outputs):
break break
out_shape[i] = inputs[0].shape[j] out_shape[i] = inputs[0].shape[j]
j += 1 j += 1
outputs[0].shape = list(filter(lambda x: x != 0, out_shape)) outputs[0].shape = tuple(filter(lambda x: x != 0, out_shape))
except TypeError: except TypeError:
pass pass
return outputs return outputs
...@@ -402,7 +409,7 @@ def flatten_spec(args, inputs, outputs): ...@@ -402,7 +409,7 @@ def flatten_spec(args, inputs, outputs):
else: else:
out_shape = None out_shape = None
try: try:
in_shape = inputs[0].shape[:] in_shape = list(inputs[0].shape[:])
if keep_axes is not None: if keep_axes is not None:
if len(in_shape) <= keep_axes: if len(in_shape) <= keep_axes:
out_shape[:len(in_shape)] = in_shape out_shape[:len(in_shape)] = in_shape
...@@ -422,10 +429,7 @@ def flatten_spec(args, inputs, outputs): ...@@ -422,10 +429,7 @@ def flatten_spec(args, inputs, outputs):
num_flatten = math_util.prod(in_shape[axis:axis + num_axes]) num_flatten = math_util.prod(in_shape[axis:axis + num_axes])
except TypeError: except TypeError:
num_flatten = None num_flatten = None
out_shape = \ out_shape = in_shape[: axis] + [num_flatten] + in_shape[axis + num_axes:]
in_shape[: axis] + \
[num_flatten] + \
in_shape[axis + num_axes:]
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
outputs[0].shape = out_shape outputs[0].shape = out_shape
...@@ -441,7 +445,7 @@ def fully_connected_spec(args, inputs, outputs): ...@@ -441,7 +445,7 @@ def fully_connected_spec(args, inputs, outputs):
axis += len(inputs[0].shape) axis += len(inputs[0].shape)
except TypeError: except TypeError:
return outputs return outputs
outputs[0].shape = [None] * (axis + 1) out_shape = [None] * (axis + 1)
if out_channels is None: if out_channels is None:
try: try:
if args['transW']: if args['transW']:
...@@ -450,11 +454,12 @@ def fully_connected_spec(args, inputs, outputs): ...@@ -450,11 +454,12 @@ def fully_connected_spec(args, inputs, outputs):
out_channels = inputs[1].shape[1] out_channels = inputs[1].shape[1]
except (TypeError, IndexError): except (TypeError, IndexError):
out_channels = None out_channels = None
outputs[0].shape[axis] = out_channels
try: try:
outputs[0].shape[:axis] = inputs[0].shape[:axis] out_shape[axis] = out_channels
except TypeError: out_shape[:axis] = inputs[0].shape[:axis]
except (TypeError, IndexError):
pass pass
outputs[0].shape = out_shape
return outputs return outputs
...@@ -467,11 +472,12 @@ def channel_normalize_spec(args, inputs, outputs): ...@@ -467,11 +472,12 @@ def channel_normalize_spec(args, inputs, outputs):
try: try:
if perm is None: if perm is None:
perm = list(range((len(inputs[0].shape)))) perm = list(range((len(inputs[0].shape))))
outputs[0].shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
for i, axis in enumerate(perm): for i, axis in enumerate(perm):
outputs[0].shape[i] = inputs[0].shape[axis] out_shape[i] = inputs[0].shape[axis]
except (TypeError, IndexError): except (TypeError, IndexError):
outputs[0].shape = None out_shape = None
outputs[0].shape = out_shape
return outputs return outputs
...@@ -512,7 +518,7 @@ def is_spec(args, inputs, outputs): ...@@ -512,7 +518,7 @@ def is_spec(args, inputs, outputs):
def masked_select_spec(args, inputs, outputs): def masked_select_spec(args, inputs, outputs):
_ = locals() _ = locals()
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
outputs[0].shape = [None] outputs[0].shape = (None,)
return outputs return outputs
...@@ -522,8 +528,8 @@ def matmul_spec(args, inputs, outputs): ...@@ -522,8 +528,8 @@ def matmul_spec(args, inputs, outputs):
ta, tb = args['transA'], args['transB'] ta, tb = args['transA'], args['transB']
out_shape = None out_shape = None
try: try:
b_shape = inputs[1].shape[:] b_shape = list(inputs[1].shape[:])
a_shape = out_shape = inputs[0].shape[:] a_shape = out_shape = list(inputs[0].shape[:])
out_shape[-2] = a_shape[-1] if ta else a_shape[-2] out_shape[-2] = a_shape[-1] if ta else a_shape[-2]
out_shape[-1] = b_shape[-2] if tb else b_shape[-1] out_shape[-1] = b_shape[-2] if tb else b_shape[-1]
except TypeError: except TypeError:
...@@ -538,7 +544,7 @@ def moments_spec(args, inputs, outputs): ...@@ -538,7 +544,7 @@ def moments_spec(args, inputs, outputs):
inputs[0].dtype if inputs[0].dtype == 'float64' else 'float32' inputs[0].dtype if inputs[0].dtype == 'float64' else 'float32'
axes, keep_dims = args['axes'], args['keep_dims'] axes, keep_dims = args['axes'], args['keep_dims']
try: try:
out_shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
for axis in axes: for axis in axes:
if axis < len(out_shape): if axis < len(out_shape):
out_shape[axis] = 1 out_shape[axis] = 1
...@@ -550,7 +556,7 @@ def moments_spec(args, inputs, outputs): ...@@ -550,7 +556,7 @@ def moments_spec(args, inputs, outputs):
out_shape = squeezed_shape out_shape = squeezed_shape
except TypeError: except TypeError:
out_shape = None out_shape = None
outputs[0].shape = outputs[1].shape = out_shape if axes else [] outputs[0].shape = outputs[1].shape = out_shape if axes else ()
return outputs return outputs
...@@ -570,7 +576,7 @@ def non_zero_spec(args, inputs, outputs): ...@@ -570,7 +576,7 @@ def non_zero_spec(args, inputs, outputs):
_ = locals() _ = locals()
outputs[0].dtype = 'int64' outputs[0].dtype = 'int64'
try: try:
outputs[0].shape = [None, len(inputs[0].shape)] outputs[0].shape = (None, len(inputs[0].shape))
except TypeError: except TypeError:
pass pass
return outputs return outputs
...@@ -580,8 +586,7 @@ def non_zero_spec(args, inputs, outputs): ...@@ -580,8 +586,7 @@ def non_zero_spec(args, inputs, outputs):
def one_hot_spec(args, inputs, outputs): def one_hot_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
try: try:
outputs[0].shape = inputs[0].shape[:] outputs[0].shape = inputs[0].shape[:] + (args['depth'],)
outputs[0].shape.append(args['depth'])
except TypeError: except TypeError:
pass pass
return outputs return outputs
...@@ -593,14 +598,14 @@ def pad_spec(args, inputs, outputs): ...@@ -593,14 +598,14 @@ def pad_spec(args, inputs, outputs):
pads, num_dims = args['pads'], len(args['pads']) // 2 pads, num_dims = args['pads'], len(args['pads']) // 2
out_shape = None out_shape = None
try: try:
out_shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
for i in range(num_dims): for i in range(num_dims):
if i < len(out_shape): if i < len(out_shape):
try: try:
out_shape[i] += (pads[i] + pads[i + num_dims]) out_shape[i] += (pads[i] + pads[i + num_dims])
except TypeError: except TypeError:
out_shape[i] = None out_shape[i] = None
except TypeError: except (TypeError, IndexError):
pass pass
outputs[0].shape = out_shape outputs[0].shape = out_shape
return outputs return outputs
...@@ -654,22 +659,22 @@ def reduce_spec(args, inputs, outputs): ...@@ -654,22 +659,22 @@ def reduce_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
axes, keep_dims = args['axes'], args['keep_dims'] axes, keep_dims = args['axes'], args['keep_dims']
if axes is None: if axes is None:
output_shape = [] outputs[0].shape = ()
else: else:
try: try:
output_shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
for axis in axes: for axis in axes:
if axis < len(output_shape): if axis < len(out_shape):
output_shape[axis] = 1 out_shape[axis] = 1
if not keep_dims: if not keep_dims:
squeezed_shape = [] squeezed_shape = []
for d in output_shape: for d in out_shape:
if d != 1: if d != 1:
squeezed_shape.append(d) squeezed_shape.append(d)
output_shape = squeezed_shape out_shape = squeezed_shape
except TypeError: outputs[0].shape = out_shape
output_shape = None except (TypeError, IndexError):
outputs[0].shape = output_shape pass
return outputs return outputs
...@@ -680,19 +685,20 @@ def repeat_spec(args, inputs, outputs): ...@@ -680,19 +685,20 @@ def repeat_spec(args, inputs, outputs):
if axis is None: if axis is None:
try: try:
num_elements = math_util.prod(inputs[0].shape[:]) num_elements = math_util.prod(inputs[0].shape[:])
outputs[0].shape = [num_elements * repeats] outputs[0].shape = (num_elements * repeats,)
except TypeError: except TypeError:
outputs[0].shape = [None] outputs[0].shape = (None,)
else: else:
try: try:
outputs[0].shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
except TypeError: except TypeError:
return outputs return outputs
if axis < len(outputs[0].shape): if axis < len(out_shape):
try: try:
outputs[0].shape[axis] *= repeats out_shape[axis] *= repeats
except TypeError: except TypeError:
outputs[0].shape[axis] = None out_shape[axis] = None
outputs[0].shape = out_shape
return outputs return outputs
...@@ -739,29 +745,29 @@ def resize_spec(args, inputs, outputs): ...@@ -739,29 +745,29 @@ def resize_spec(args, inputs, outputs):
'scales_descs' in args: 'scales_descs' in args:
return outputs return outputs
try: try:
out_dims = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
num_axes = len(out_dims) - 2 num_axes = len(out_shape) - 2
axis = len(out_dims) - 2 if args['data_format'] == 'NCHW' else 1 axis = len(out_shape) - 2 if args['data_format'] == 'NCHW' else 1
try: try:
for i in range(num_axes): for i in range(num_axes):
j = axis + i j = axis + i
if args['sizes'] is not None: if args['sizes'] is not None:
if len(args['sizes']) == 1: if len(args['sizes']) == 1:
out_dims[j] = args['sizes'][0] out_shape[j] = args['sizes'][0]
elif len(args['sizes']) == num_axes: elif len(args['sizes']) == num_axes:
out_dims[j] = args['sizes'][i] out_shape[j] = args['sizes'][i]
else: else:
out_dims[j] = args['sizes'][j] out_shape[j] = args['sizes'][j]
elif args['scales'] is not None: elif args['scales'] is not None:
if len(args['scales']) == 1: if len(args['scales']) == 1:
out_dims[j] = int(out_dims[j] * args['scales'][0]) out_shape[j] = int(out_shape[j] * args['scales'][0])
elif len(args['scales']) == num_axes: elif len(args['scales']) == num_axes:
out_dims[j] = int(out_dims[j] * args['scales'][i]) out_shape[j] = int(out_shape[j] * args['scales'][i])
else: else:
out_dims[j] = int(out_dims[j] * args['sizes'][j]) out_shape[j] = int(out_shape[j] * args['sizes'][j])
except IndexError: except IndexError:
return outputs return outputs
outputs[0].shape = out_dims outputs[0].shape = out_shape
except TypeError: except TypeError:
pass pass
return outputs return outputs
...@@ -773,7 +779,7 @@ def roi_pool_spec(args, inputs, outputs): ...@@ -773,7 +779,7 @@ def roi_pool_spec(args, inputs, outputs):
pool_h, pool_w = args['pooled_h'], args['pooled_w'] pool_h, pool_w = args['pooled_h'], args['pooled_w']
out_shape = None out_shape = None
try: try:
out_shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
out_shape[2:4] = pool_h, pool_w out_shape[2:4] = pool_h, pool_w
try: try:
out_shape[0] = inputs[1].shape[0] out_shape[0] = inputs[1].shape[0]
...@@ -833,11 +839,12 @@ def softmax_loss_spec(args, inputs, outputs): ...@@ -833,11 +839,12 @@ def softmax_loss_spec(args, inputs, outputs):
outputs[0].dtype = 'float32' outputs[0].dtype = 'float32'
axis, reduction = args['axis'], args['reduction'] axis, reduction = args['axis'], args['reduction']
if reduction != 'NONE': if reduction != 'NONE':
outputs[0].shape = [] outputs[0].shape = ()
else: else:
try: try:
outputs[0].shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
outputs[0].shape.pop(axis) out_shape.pop(axis)
outputs[0].shape = out_shape
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
return outputs return outputs
...@@ -848,7 +855,7 @@ def space_to_depth_spec(args, inputs, outputs): ...@@ -848,7 +855,7 @@ def space_to_depth_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
try: try:
bs = args['block_size'] bs = args['block_size']
out_shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
num_axes = len(out_shape) - 2 num_axes = len(out_shape) - 2
if len(out_shape) < 3: if len(out_shape) < 3:
return outputs return outputs
...@@ -885,14 +892,14 @@ def split_spec(args, inputs, outputs): ...@@ -885,14 +892,14 @@ def split_spec(args, inputs, outputs):
try: try:
if axis >= len(inputs[0].shape[:]): if axis >= len(inputs[0].shape[:]):
return outputs return outputs
outputs[i].shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
except TypeError: except TypeError:
return outputs return outputs
if size_splits is not None: if size_splits is not None:
try: try:
outputs[i].shape[axis] = size_splits[i] out_shape[axis] = size_splits[i]
except IndexError: except IndexError:
outputs[i].shape[axis] = None return outputs
elif slice_points is not None: elif slice_points is not None:
try: try:
if i < len(outputs) - 1: if i < len(outputs) - 1:
...@@ -900,14 +907,18 @@ def split_spec(args, inputs, outputs): ...@@ -900,14 +907,18 @@ def split_spec(args, inputs, outputs):
slice_offset += slice_dim slice_offset += slice_dim
else: else:
slice_dim = inputs[0].shape[axis] - slice_offset slice_dim = inputs[0].shape[axis] - slice_offset
out_shape[axis] = slice_dim
except (TypeError, IndexError): except (TypeError, IndexError):
slice_dim = None return outputs
outputs[i].shape[axis] = slice_dim
else: else:
try: try:
outputs[i].shape[axis] //= num_outputs slice_dim = (out_shape[axis] + num_outputs - 1) // num_outputs
except TypeError: if i == num_outputs - 1:
outputs[i].shape[axis] = None slice_dim = out_shape[axis] - slice_dim * (num_outputs - 1)
out_shape[axis] = slice_dim
except (TypeError, IndexError):
return outputs
outputs[i].shape = out_shape
return outputs return outputs
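When neither ``size_splits`` nor ``slice_points`` is given, the updated branch above ceil-divides the dimension and lets the last output absorb the remainder, instead of floor-dividing every chunk. A small illustrative helper (not in the diff):
```python
def _default_split_sizes(dim, num_outputs):
    """Chunk sizes implied by the updated default branch of split_spec."""
    chunk = (dim + num_outputs - 1) // num_outputs   # ceil division
    sizes = [chunk] * (num_outputs - 1)
    sizes.append(dim - chunk * (num_outputs - 1))    # last chunk takes the rest
    return sizes

print(_default_split_sizes(10, 3))  # [4, 4, 2]
print(_default_split_sizes(9, 3))   # [3, 3, 3]
```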
...@@ -941,7 +952,7 @@ def stack_spec(args, inputs, outputs): ...@@ -941,7 +952,7 @@ def stack_spec(args, inputs, outputs):
out_shape = None out_shape = None
for input in inputs: for input in inputs:
if out_shape is None and input.shape is not None: if out_shape is None and input.shape is not None:
out_shape = input.shape[:] out_shape = list(input.shape[:])
try: try:
for i in range(len(out_shape)): for i in range(len(out_shape)):
for input in inputs: for input in inputs:
...@@ -971,7 +982,7 @@ def tile_spec(args, inputs, outputs): ...@@ -971,7 +982,7 @@ def tile_spec(args, inputs, outputs):
repeats = args['repeats'] repeats = args['repeats']
if repeats is not None: if repeats is not None:
try: try:
out_shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
for i, size in enumerate(repeats): for i, size in enumerate(repeats):
if i < len(out_shape): if i < len(out_shape):
try: try:
...@@ -993,9 +1004,10 @@ def transpose_spec(args, inputs, outputs): ...@@ -993,9 +1004,10 @@ def transpose_spec(args, inputs, outputs):
try: try:
if perm is None: if perm is None:
perm = list(range(((len(inputs[0].shape)) - 1), -1, -1)) perm = list(range(((len(inputs[0].shape)) - 1), -1, -1))
outputs[0].shape = inputs[0].shape[:] out_shape = list(inputs[0].shape[:])
for i, axis in enumerate(perm): for i, axis in enumerate(perm):
outputs[0].shape[i] = inputs[0].shape[axis] out_shape[i] = inputs[0].shape[axis]
outputs[0].shape = out_shape
except (TypeError, IndexError): except (TypeError, IndexError):
outputs[0].shape = None outputs[0].shape = None
return outputs return outputs
......
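For reference, the inference that channel_normalize_spec and transpose_spec now stage on a local list amounts to permuting the input dimensions; a hypothetical standalone mirror:
```python
def _transpose_shape(in_shape, perm=None):
    """Illustrative mirror of the inference in transpose_spec above."""
    if perm is None:
        perm = range(len(in_shape) - 1, -1, -1)  # reverse the axes by default
    return [in_shape[axis] for axis in perm]

print(_transpose_shape((2, 3, 4)))             # [4, 3, 2]
print(_transpose_shape((2, 3, 4), (0, 2, 1)))  # [2, 4, 3]
```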
...@@ -27,17 +27,17 @@ from dragon.core.util import nest ...@@ -27,17 +27,17 @@ from dragon.core.util import nest
class Tensor(types.TensorMetaclass): class Tensor(types.TensorMetaclass):
"""Tensor abstraction for graph executing.""" """Tensor abstraction for graph executing."""
def __init__(self, name=None, shape=None, dtype=None): def __init__(self, shape=None, dtype=None, name=None):
"""Create a ``Tensor``. """Create a ``Tensor``.
Parameters Parameters
---------- ----------
name : str, optional shape : Sequence[int], optional
The optional tensor name.
shape : sequence, optional
The optional tensor shape. The optional tensor shape.
dtype : str, optional dtype : str, optional
The optional data type. The optional data type.
name : str, optional
The optional tensor name.
""" """
self._op, self._grad = None, None self._op, self._grad = None, None
...@@ -66,7 +66,7 @@ class Tensor(types.TensorMetaclass): ...@@ -66,7 +66,7 @@ class Tensor(types.TensorMetaclass):
The data type to set. The data type to set.
""" """
self._dtype = value self._dtype = str(value) if value else value
@property @property
def id(self): def id(self):
...@@ -131,8 +131,8 @@ class Tensor(types.TensorMetaclass): ...@@ -131,8 +131,8 @@ class Tensor(types.TensorMetaclass):
Returns Returns
------- -------
Sequence[int] Tuple[int]
The shape. The tensor shape.
""" """
return self._shape return self._shape
...@@ -150,10 +150,9 @@ class Tensor(types.TensorMetaclass): ...@@ -150,10 +150,9 @@ class Tensor(types.TensorMetaclass):
if value is not None: if value is not None:
if not nest.is_sequence(value): if not nest.is_sequence(value):
raise TypeError( raise TypeError(
'The <shape> should be a Sequence. ' 'The <shape> should be a sequence. Got {}.'
'Got {}.'.format(type(value)) .format(type(value).__name__))
) self._shape = tuple(nest.flatten(value))
self._shape = nest.flatten(value)
else: else:
self._shape = value self._shape = value
...@@ -475,7 +474,7 @@ class Tensor(types.TensorMetaclass): ...@@ -475,7 +474,7 @@ class Tensor(types.TensorMetaclass):
Parameters Parameters
---------- ----------
item : Union[int, slice, dragon.Tensor] item : Union[slice, int, dragon.Tensor]
The index. The index.
Returns Returns
...@@ -643,7 +642,7 @@ class Tensor(types.TensorMetaclass): ...@@ -643,7 +642,7 @@ class Tensor(types.TensorMetaclass):
Parameters Parameters
---------- ----------
key : Union[int, slice, dragon.Tensor] key : Union[slice, int, dragon.Tensor]
The index. The index.
value : Union[dragon.Tensor, number] value : Union[dragon.Tensor, number]
The value to set. The value to set.
...@@ -685,6 +684,6 @@ class TensorRef(object): ...@@ -685,6 +684,6 @@ class TensorRef(object):
"""Create a reference not involved with name scope.""" """Create a reference not involved with name scope."""
def __new__(cls, name, shape=None, dtype=None): def __new__(cls, name, shape=None, dtype=None):
tensor = Tensor('', shape=shape, dtype=dtype) tensor_ref = Tensor(shape=shape, dtype=dtype, name='')
tensor._name = name tensor_ref._name = name
return tensor return tensor_ref
...@@ -144,15 +144,15 @@ class EagerTensor(Tensor): ...@@ -144,15 +144,15 @@ class EagerTensor(Tensor):
@property @property
def shape(self): def shape(self):
"""Return the shape of this tensor. """Return tensor shape.
Returns Returns
------- -------
Sequence[int] Tuple[int]
The shape. The tensor shape.
""" """
return self._impl.dims return tuple(self._impl.dims)
@shape.setter @shape.setter
def shape(self, value): def shape(self, value):
...@@ -451,7 +451,7 @@ class EagerTensor(Tensor): ...@@ -451,7 +451,7 @@ class EagerTensor(Tensor):
Parameters Parameters
---------- ----------
item : Union[int, slice, dragon.EagerTensor] item : Union[slice, int, dragon.EagerTensor]
The index. The index.
Returns Returns
...@@ -668,7 +668,7 @@ class EagerTensor(Tensor): ...@@ -668,7 +668,7 @@ class EagerTensor(Tensor):
Parameters Parameters
---------- ----------
key : Union[int, slice, dragon.EagerTensor] key : Union[slice, int, dragon.EagerTensor]
The index. The index.
value : Union[dragon.EagerTensor, number] value : Union[dragon.EagerTensor, number]
The value to set. The value to set.
......
...@@ -36,7 +36,7 @@ if sys.version_info >= (3, 0): ...@@ -36,7 +36,7 @@ if sys.version_info >= (3, 0):
argument.i = value argument.i = value
elif type(value) is bytes: elif type(value) is bytes:
argument.s = value argument.s = value
elif type(value) is str: elif isinstance(value, str):
argument.s = str.encode(value) argument.s = str.encode(value)
elif isinstance(value, Message): elif isinstance(value, Message):
argument.s = value.SerializeToString() argument.s = value.SerializeToString()
...@@ -63,7 +63,7 @@ else: ...@@ -63,7 +63,7 @@ else:
argument.f = value argument.f = value
elif type(value) in (bool, int, long, numpy.int64): elif type(value) in (bool, int, long, numpy.int64):
argument.i = value argument.i = value
elif type(value) is str: elif isinstance(value, str):
argument.s = value argument.s = value
elif type(value) is unicode: elif type(value) is unicode:
argument.s = str(value) argument.s = str(value)
......
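The two hunks above relax ``type(value) is str`` to ``isinstance(value, str)``, presumably so that str-derived values (such as the rewritten DType later in this diff) still serialize as strings. The difference in a nutshell:
```python
class Name(str):
    """A str subclass standing in for the new DType; illustrative only."""

v = Name('float32')
print(type(v) is str)      # False -- the exact-type check rejects subclasses
print(isinstance(v, str))  # True  -- the updated check accepts them
```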
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Structure to represent a tensor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.util import nest
class TensorSpec(object):
"""Spec to describe properties of a tensor."""
def __init__(self, shape, dtype='float32', name=None):
"""Create a TensorSpec.
Parameters
----------
shape : Sequence[int], required
The dimensions.
dtype : str, optional, default='float32'
The optional data type.
name : str, optional
The optional name.
"""
self._shape, self._dtype, self._name = shape, dtype, name
if shape is not None:
self._shape = nest.flatten(shape)
@property
def dtype(self):
"""Return the data type.
Returns
-------
str
The data type.
"""
return self._dtype
@property
def name(self):
"""Return the spec name.
Returns
-------
str
The spec name.
"""
return self._name
@property
def shape(self):
"""Return the dimensions.
Returns
-------
Sequence[int]
The dimensions.
"""
return self._shape
def is_compatible_with(self, spec_or_tensor):
"""Return a bool indicating whether given the spec is compatible.
Returns
-------
bool
**True** if ``shape`` and ``dtype`` are
both compatible otherwise **False**.
"""
def dtype_is_compatible_with(spec_or_tensor):
return self._dtype == spec_or_tensor.dtype
def shape_is_compatible_with(spec_or_tensor):
shape = spec_or_tensor.shape
if self._shape is not None and shape is not None:
if len(self._shape) != len(shape):
return False
for x_dim, y_dim in zip(self._shape, shape):
if x_dim != y_dim:
return False
return True
return \
dtype_is_compatible_with(spec_or_tensor) and \
shape_is_compatible_with(spec_or_tensor)
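A short usage sketch of the compatibility check above, assuming the TensorSpec class shown here is in scope; the stand-in tensor type is hypothetical and only provides the ``shape`` and ``dtype`` attributes the check consults:
```python
import collections

FakeTensor = collections.namedtuple('FakeTensor', ['shape', 'dtype'])

spec = TensorSpec(shape=[2, 3], dtype='float32')
print(spec.is_compatible_with(FakeTensor((2, 3), 'float32')))  # True
print(spec.is_compatible_with(FakeTensor((2, 4), 'float32')))  # False: dim mismatch
print(spec.is_compatible_with(FakeTensor(None, 'float32')))    # True: unknown shape passes
```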
...@@ -180,7 +180,7 @@ class Workspace(backend.Workspace): ...@@ -180,7 +180,7 @@ class Workspace(backend.Workspace):
```python ```python
# Define a named tensor to feed # Define a named tensor to feed
x = dragon.Tensor('x') x = dragon.Tensor(name='x')
dragon.get_workspace().feed_tensor(x, 0) dragon.get_workspace().feed_tensor(x, 0)
# Feed by specifying a tensor name # Feed by specifying a tensor name
......
...@@ -18,7 +18,7 @@ import itertools ...@@ -18,7 +18,7 @@ import itertools
import numpy import numpy
import warnings import warnings
from dragon.core.autograph.tensor import Tensor from dragon.core.autograph.tensor import TensorRef
from dragon.core.eager import context from dragon.core.eager import context
from dragon.core.eager.tensor import EagerTensor from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import config from dragon.core.framework import config
...@@ -191,7 +191,7 @@ class RNNModule(object): ...@@ -191,7 +191,7 @@ class RNNModule(object):
self._weights_count += int(numpy.prod(shape)) self._weights_count += int(numpy.prod(shape))
# Create the flat float32 weights. # Create the flat float32 weights.
self._weights = EagerTensor(shape=[self._weights_count], trainable=True) self._weights = EagerTensor(shape=[self._weights_count], trainable=True)
self._weights_ref = Tensor(self._weights.name) self._weights_ref = TensorRef(self._weights.name)
def _uniform_init(self, shape, dtype='float32'): def _uniform_init(self, shape, dtype='float32'):
stdv = 1. / numpy.sqrt(self.hidden_size) stdv = 1. / numpy.sqrt(self.hidden_size)
......
...@@ -146,7 +146,7 @@ def getitem(self, item): ...@@ -146,7 +146,7 @@ def getitem(self, item):
Parameters Parameters
---------- ----------
item : Union[int, slice, dragon.EagerTensor] item : Union[slice, int, dragon.EagerTensor]
The index. The index.
Returns Returns
...@@ -490,7 +490,7 @@ def setitem(self, key, value): ...@@ -490,7 +490,7 @@ def setitem(self, key, value):
Parameters Parameters
---------- ----------
key : Union[int, slice, dragon.EagerTensor] key : Union[slice, int, dragon.EagerTensor]
The index. The index.
value : Union[dragon.EagerTensor, number] value : Union[dragon.EagerTensor, number]
The value to set. The value to set.
......
...@@ -121,7 +121,7 @@ def getitem(self, item): ...@@ -121,7 +121,7 @@ def getitem(self, item):
Parameters Parameters
---------- ----------
item : Union[int, slice, dragon.Tensor] item : Union[slice, int, dragon.Tensor]
The index. The index.
Returns Returns
...@@ -324,7 +324,7 @@ def setitem(self, key, value): ...@@ -324,7 +324,7 @@ def setitem(self, key, value):
Parameters Parameters
---------- ----------
key : Union[int, slice, dragon.Tensor] key : Union[slice, int, dragon.Tensor]
The index. The index.
value : Union[dragon.Tensor, number] value : Union[dragon.Tensor, number]
The value to set. The value to set.
......
...@@ -58,8 +58,8 @@ from dragon.vm.tensorflow.core.framework.dtypes import variant ...@@ -58,8 +58,8 @@ from dragon.vm.tensorflow.core.framework.dtypes import variant
from dragon.vm.tensorflow.core.framework.ops import convert_to_tensor from dragon.vm.tensorflow.core.framework.ops import convert_to_tensor
from dragon.vm.tensorflow.core.framework.ops import device from dragon.vm.tensorflow.core.framework.ops import device
from dragon.vm.tensorflow.core.framework.ops import name_scope from dragon.vm.tensorflow.core.framework.ops import name_scope
from dragon.vm.tensorflow.core.framework.tensor_spec import TensorSpec
from dragon.vm.tensorflow.core.framework.tensor_shape import TensorShape from dragon.vm.tensorflow.core.framework.tensor_shape import TensorShape
from dragon.vm.tensorflow.core.framework.tensor_spec import TensorSpec
from dragon.vm.tensorflow.core.module.module import Module from dragon.vm.tensorflow.core.module.module import Module
from dragon.vm.tensorflow.core.ops.array_ops import broadcast_to from dragon.vm.tensorflow.core.ops.array_ops import broadcast_to
from dragon.vm.tensorflow.core.ops.array_ops import concat from dragon.vm.tensorflow.core.ops.array_ops import concat
......
...@@ -52,9 +52,9 @@ def constant(value, dtype=None, shape=None, name='Const'): ...@@ -52,9 +52,9 @@ def constant(value, dtype=None, shape=None, name='Const'):
""" """
if dtype is not None: if dtype is not None:
if isinstance(value, numpy.ndarray): if isinstance(value, numpy.ndarray):
value = value.astype(str(dtype)) value = value.astype(dtype)
else: else:
value = numpy.array(value, str(dtype)) value = numpy.array(value, dtype)
else: else:
if not isinstance(value, numpy.ndarray): if not isinstance(value, numpy.ndarray):
value = numpy.array(value) value = numpy.array(value)
......
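The explicit ``str(dtype)`` casts become unnecessary once DType derives from ``str`` (see the dtypes rewrite below), since numpy accepts str-derived names directly; a quick check with a stand-in subclass:
```python
import numpy as np

class FakeDType(str):
    """Stand-in for the str-derived DType; illustrative only."""

x = np.array([1, 2, 3], FakeDType('float32'))
print(x.dtype)  # float32 -- no explicit str() conversion needed
```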
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/dtypes.py> # <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/dtypes.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Data type utilities."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -45,8 +46,34 @@ DT_VARIANT = 21 ...@@ -45,8 +46,34 @@ DT_VARIANT = 21
DT_UINT32 = 22 DT_UINT32 = 22
DT_UINT64 = 23 DT_UINT64 = 23
# Mappings between string names and enum values.
_STRING_TO_ENUM = {
'float16': DT_HALF,
'float32': DT_FLOAT,
'float64': DT_DOUBLE,
'int8': DT_INT8,
'uint8': DT_UINT8,
'int16': DT_INT16,
'uint16': DT_UINT16,
'int32': DT_INT32,
'uint32': DT_UINT32,
'int64': DT_INT64,
'uint64': DT_UINT64,
'string': DT_STRING,
'complex64': DT_COMPLEX64,
'complex128': DT_COMPLEX128,
'bool': DT_BOOL,
'qint8': DT_QINT8,
'quint8': DT_QUINT8,
'qint16': DT_QINT16,
'quint16': DT_QUINT16,
'qint32': DT_QINT32,
'bfloat16': DT_BFLOAT16,
'variant': DT_VARIANT,
}
class DType(object): class DType(str):
"""The basic data type. """The basic data type.
Following data types are defined: Following data types are defined:
...@@ -97,110 +124,251 @@ class DType(object): ...@@ -97,110 +124,251 @@ class DType(object):
""" """
def __init__(self, type_enum): def __init__(self, value):
"""Create a ``DType``. """Create a ``DType``.
Parameters Parameters
---------- ----------
type_enum : DataType value : str
The ``DataType`` value. The string name.
""" """
type_enum = int(type_enum) super(DType, self).__init__()
if type_enum == DT_INVALID: self._name = value
raise TypeError('<type_enum> is not a valid DataType.') self._type_enum = _STRING_TO_ENUM[value]
self._type_enum = type_enum
@property @property
def base_dtype(self): def as_numpy_dtype(self):
return self """Return as the number of numpy data type.
Returns
-------
type
The numpy data type.
"""
return _TF_TO_NP[self._type_enum]
@property @property
def real_dtype(self): def as_datatype_enum(self):
base = self.base_dtype """Return the enum value of the data type.
if base == complex64:
return float32 Returns
elif base == complex128: -------
return float64 int
else: The enum value.
"""
return self._type_enum
@property
def base_dtype(self):
"""Return the non-referenced data type.
Returns
-------
dragon.vm.tensorflow.dtypes.DType
The data type.
"""
return self return self
@property @property
def is_numpy_compatible(self): def is_numpy_compatible(self):
"""Return whether this data type is compatible with numpy.
Returns
-------
bool
**True** if compatible otherwise **False**.
"""
return self._type_enum in _TF_TO_NP return self._type_enum in _TF_TO_NP
@property @property
def as_numpy_dtype(self): def is_bool(self):
return _TF_TO_NP[self._type_enum] """Return whether this is a boolean type.
Returns
-------
bool
**True** if this is a boolean type otherwise **False**.
"""
return self.base_dtype == bool
@property @property
def as_datatype_enum(self): def is_complex(self):
return self._type_enum """Return whether this is a complex type.
Returns
-------
bool
**True** if this is a complex type otherwise **False**.
"""
return self.base_dtype in (complex64, complex128)
@property @property
def is_bool(self): def is_floating(self):
return self.base_dtype == bool """Return whether this is a floating type.
Returns
-------
bool
**True** if this is a floating type otherwise **False**.
"""
return (self.is_numpy_compatible and
issubclass(self.as_numpy_dtype, np.floating))
@property @property
def is_integer(self): def is_integer(self):
"""Return whether this is a integer type.
Returns
-------
bool
**True** if this is an integer type otherwise **False**.
"""
return (self.is_numpy_compatible and return (self.is_numpy_compatible and
not self.is_quantized and not self.is_quantized and
issubclass(self.as_numpy_dtype, np.integer)) issubclass(self.as_numpy_dtype, np.integer))
@property @property
def is_floating(self): def is_quantized(self):
return self.is_numpy_compatible and \ """Return whether this is a quantized type.
issubclass(self.as_numpy_dtype, np.floating)
@property Returns
def is_complex(self): -------
return self.base_dtype in (complex64, complex128) bool
**True** if this is a quantized type otherwise **False**.
@property """
def is_quantized(self):
return self.base_dtype in [qint8, quint8, qint16, quint16, qint32, bfloat16] return self.base_dtype in [qint8, quint8, qint16, quint16, qint32, bfloat16]
@property @property
def is_unsigned(self): def is_unsigned(self):
"""Return whether this is an unsigned type.
Returns
-------
bool
**True** if this is an unsigned type otherwise **False**.
"""
try: try:
return self.min == 0 return self.min == 0
except TypeError: except TypeError:
return False return False
@property @property
def min(self): def limits(self, clip_negative=True):
"""Return the numerical limits.
Parameters
----------
clip_negative : bool, optional, default=True
**True** to return positive limits only.
Returns
-------
Tuple[number, number]
The limits.
"""
min, max = dtype_range[self.as_numpy_dtype]
if clip_negative:
min = 0
return min, max
@property
def max(self):
"""Return the max representable value.
Returns
-------
number
The max representable value.
"""
if (self.is_quantized or self.base_dtype in if (self.is_quantized or self.base_dtype in
(bool, string, complex64, complex128)): (bool, string, complex64, complex128)):
raise TypeError("Cannot find minimum value of %s." % self) raise TypeError('Cannot find maximum value of %s.' % self)
try: try:
return np.finfo(self.as_numpy_dtype()).min return np.finfo(self.as_numpy_dtype()).max
except (TypeError, ValueError): except (TypeError, ValueError):
try: try:
return np.iinfo(self.as_numpy_dtype()).min return np.iinfo(self.as_numpy_dtype()).max
except (TypeError, ValueError): except (TypeError, ValueError):
raise TypeError("Cannot find minimum value of %s." % self) raise TypeError('Cannot find maximum value of %s.' % self)
@property @property
def max(self): def min(self):
"""Return the min representable value.
Returns
-------
number
The min representable value.
"""
if (self.is_quantized or self.base_dtype in if (self.is_quantized or self.base_dtype in
(bool, string, complex64, complex128)): (bool, string, complex64, complex128)):
raise TypeError("Cannot find maximum value of %s." % self) raise TypeError("Cannot find minimum value of %s." % self)
try: try:
return np.finfo(self.as_numpy_dtype()).max return np.finfo(self.as_numpy_dtype()).min
except (TypeError, ValueError): except (TypeError, ValueError):
try: try:
return np.iinfo(self.as_numpy_dtype()).max return np.iinfo(self.as_numpy_dtype()).min
except (TypeError, ValueError): except (TypeError, ValueError):
raise TypeError("Cannot find maximum value of %s." % self) raise TypeError("Cannot find minimum value of %s." % self)
@property @property
def limits(self, clip_negative=True): def name(self):
min, max = dtype_range[self.as_numpy_dtype] """Return the type name.
if clip_negative:
min = 0 Returns
return min, max -------
str
The type name.
"""
return self._name
@property
def real_dtype(self):
"""Return the data type of real part.
Returns
-------
dragon.vm.tensorflow.dtypes.DType
The data type of the real part.
"""
base = self.base_dtype
if base == complex64:
return float32
elif base == complex128:
return float64
else:
return self
def is_compatible_with(self, other): def is_compatible_with(self, other):
"""Return whether this data type can be converted as the other.
Parameters
----------
other : dragon.vm.tensorflow.dtypes.DType
The referring data type.
Returns
-------
bool
**True** if compatible otherwise **False**.
"""
other = as_dtype(other) other = as_dtype(other)
return self._type_enum in ( return self._type_enum in (
other.as_datatype_enum, other.base_dtype.as_datatype_enum) other.as_datatype_enum, other.base_dtype.as_datatype_enum)
...@@ -217,16 +385,11 @@ class DType(object): ...@@ -217,16 +385,11 @@ class DType(object):
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
@property
def name(self):
"""Returns the string name for this `DType`."""
return _TYPE_TO_STRING[self._type_enum]
def __int__(self): def __int__(self):
return self._type_enum return self._type_enum
def __str__(self): def __str__(self):
return self.name return self._name
def __repr__(self): def __repr__(self):
return "tf." + self.name return "tf." + self.name
...@@ -235,8 +398,9 @@ class DType(object): ...@@ -235,8 +398,9 @@ class DType(object):
return self._type_enum return self._type_enum
# Define data type range of numpy dtype # Range of numpy dtype
dtype_range = {np.bool_: (False, True), dtype_range = {
np.bool_: (False, True),
np.bool8: (False, True), np.bool8: (False, True),
np.uint8: (0, 255), np.uint8: (0, 255),
np.uint16: (0, 65535), np.uint16: (0, 65535),
...@@ -247,58 +411,33 @@ dtype_range = {np.bool_: (False, True), ...@@ -247,58 +411,33 @@ dtype_range = {np.bool_: (False, True),
np.int32: (-2 ** 31, 2 ** 31 - 1), np.int32: (-2 ** 31, 2 ** 31 - 1),
np.uint32: (0, 2 ** 32 - 1), np.uint32: (0, 2 ** 32 - 1),
np.float32: (-1, 1), np.float32: (-1, 1),
np.float64: (-1, 1)} np.float64: (-1, 1),
# Define standard wrappers for the DataType enum.
float16 = DType(DT_HALF)
half = float16
float32 = DType(DT_FLOAT)
float64 = DType(DT_DOUBLE)
double = float64
int32 = DType(DT_INT32)
uint8 = DType(DT_UINT8)
uint16 = DType(DT_UINT16)
uint64 = DType(DT_UINT32)
uint32 = DType(DT_UINT64)
int16 = DType(DT_INT16)
int8 = DType(DT_INT8)
string = DType(DT_STRING)
complex64 = DType(DT_COMPLEX64)
complex128 = DType(DT_COMPLEX128)
int64 = DType(DT_INT64)
bool = DType(DT_BOOL)
qint8 = DType(DT_QINT8)
quint8 = DType(DT_QUINT8)
qint16 = DType(DT_QINT16)
quint16 = DType(DT_QUINT16)
qint32 = DType(DT_QINT32)
bfloat16 = DType(DT_BFLOAT16)
variant = DType(DT_VARIANT)
# Standard mappings between DataType values and string names.
_TYPE_TO_STRING = {
DT_HALF: "float16",
DT_FLOAT: "float32",
DT_DOUBLE: "float64",
DT_INT32: "int32",
DT_UINT8: "uint8",
DT_UINT16: "uint16",
DT_INT16: "int16",
DT_INT8: "int8",
DT_STRING: "string",
DT_COMPLEX64: "complex64",
DT_COMPLEX128: "complex128",
DT_INT64: "int64",
DT_BOOL: "bool",
DT_QINT8: "qint8",
DT_QUINT8: "quint8",
DT_QINT16: "qint16",
DT_QUINT16: "quint16",
DT_QINT32: "qint32",
DT_BFLOAT16: "bfloat16",
} }
# Define standard wrappers for the string name.
float16 = half = DType('float16')
float32 = DType('float32')
float64 = double = DType('float64')
int32 = DType('int32')
uint8 = DType('uint8')
uint16 = DType('uint16')
uint64 = DType('uint32')
uint32 = DType('uint64')
int16 = DType('int16')
int8 = DType('int8')
string = DType('string')
complex64 = DType('complex64')
complex128 = DType('complex128')
int64 = DType('int64')
bool = DType('bool')
qint8 = DType('qint8')
quint8 = DType('quint8')
qint16 = DType('qint16')
quint16 = DType('quint16')
qint32 = DType('qint32')
bfloat16 = DType('bfloat16')
variant = DType('variant')
# Numpy representation for quantized dtypes. # Numpy representation for quantized dtypes.
_np_qint8 = np.dtype([("qint8", np.int8)]) _np_qint8 = np.dtype([("qint8", np.int8)])
_np_quint8 = np.dtype([("quint8", np.uint8)]) _np_quint8 = np.dtype([("quint8", np.uint8)])
...@@ -379,7 +518,7 @@ _INTERN_TABLE = { ...@@ -379,7 +518,7 @@ _INTERN_TABLE = {
} }
_STRING_TO_TF = { _STRING_TO_TF = {
value: _INTERN_TABLE[key] for key, value in _TYPE_TO_STRING.items() key: _INTERN_TABLE[value] for key, value in _STRING_TO_ENUM.items()
} }
_ANY_TO_TF = {} _ANY_TO_TF = {}
...@@ -404,16 +543,13 @@ def as_dtype(type_value): ...@@ -404,16 +543,13 @@ def as_dtype(type_value):
""" """
if isinstance(type_value, DType): if isinstance(type_value, DType):
return type_value return type_value
if isinstance(type_value, np.dtype): if isinstance(type_value, np.dtype):
try: try:
return _NP_TO_TF[type_value.type] return _NP_TO_TF[type_value.type]
except KeyError: except KeyError:
pass pass
try: try:
return _ANY_TO_TF[type_value] return _ANY_TO_TF[type_value]
except KeyError: except KeyError:
pass pass
raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value) raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value)
...@@ -18,5 +18,4 @@ from dragon.vm.tensorflow.core.framework.constant_op import * ...@@ -18,5 +18,4 @@ from dragon.vm.tensorflow.core.framework.constant_op import *
from dragon.vm.tensorflow.core.framework.dtypes import * from dragon.vm.tensorflow.core.framework.dtypes import *
from dragon.vm.tensorflow.core.framework.ops import device from dragon.vm.tensorflow.core.framework.ops import device
from dragon.vm.tensorflow.core.framework.ops import convert_to_tensor from dragon.vm.tensorflow.core.framework.ops import convert_to_tensor
from dragon.vm.tensorflow.core.framework.tensor_shape import Dimension
from dragon.vm.tensorflow.core.framework.tensor_shape import TensorShape from dragon.vm.tensorflow.core.framework.tensor_shape import TensorShape
...@@ -12,127 +12,87 @@ ...@@ -12,127 +12,87 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/tensor_shape.py> # <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/tensor_shape.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Tensor shape utilities."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.autograph.tensor import Tensor
from dragon.core.eager.tensor import EagerTensor
class TensorShape(tuple):
"""Represent the a sequence of dimensions."""
class Dimension(object): def __init__(self, dims):
def __init__(self, value): """Create a ``TensorShape``.
if value is None:
self._value = None
else:
self._value = int(value)
if self._value < 0:
raise ValueError("Dimension %d must be >= 0" % self._value)
@property Parameters
def value(self): ----------
return self._value dims : Sequence[int]
The dimensions.
def __repr__(self): """
return "Dimension(%s)" % repr(self._value) super(TensorShape, self).__init__()
def __str__(self): @property
value = self._value def dims(self):
return "?" if value is None else str(value) """Return the list of dimensions.
def __eq__(self, other):
try:
other = as_dimension(other)
except (TypeError, ValueError):
return NotImplemented
if self._value is None or other.value is None:
return None
return self._value == other.value
def __ne__(self, other):
try:
other = as_dimension(other)
except (TypeError, ValueError):
return NotImplemented
if self._value is None or other.value is None:
return None
return self._value != other.value
def __int__(self):
return self._value
def as_dimension(value):
if isinstance(value, Dimension):
return value
else:
return Dimension(value)
Returns
-------
List[int]
The dimensions.
class TensorShape(object): """
def __init__(self, dims): return list(self)
if dims is None:
self._dims = None
elif isinstance(dims, TensorShape):
self._dims = dims.dims
else:
try:
dims_iter = iter(dims)
except TypeError:
self._dims = [as_dimension(dims)]
else:
self._dims = [as_dimension(d) for d in dims_iter]
@property @property
def dims(self): def ndims(self):
return self._dims """Return the number of dimensions.
Deprecated. See ``TensorShape.rank``.
Returns
-------
int
The number of dimensions.
"""
return len(self)
@property @property
def ndims(self): def rank(self):
if self._dims is None: """Return the rank of shape.
return None
else: Returns
return len(self._dims) -------
int
The rank.
"""
return len(self)
def as_list(self): def as_list(self):
if self._dims is None: """Return the list of dimensions.
raise ValueError("as_list() is not defined on an unknown TensorShape.")
return [dim.value for dim in self._dims] Returns
-------
List[int]
The dimensions.
"""
return list(self)
def __repr__(self): def __repr__(self):
return "TensorShape(%r)" % self._dims return "TensorShape({})".format(list(self))
def __str__(self): def __str__(self):
if self.ndims is None: if self.ndims == 1:
return "<unknown>" return "(%s,)" % self.dims[0]
elif self.ndims == 1:
return "(%s,)" % self._dims[0]
else: else:
return "(%s)" % ", ".join(str(d) for d in self._dims) return "(%s)" % ", ".join(str(d) for d in self.dims)
def __getitem__(self, key): def __getitem__(self, key):
if self._dims is not None:
if isinstance(key, slice): if isinstance(key, slice):
return TensorShape(self._dims[key]) return TensorShape(self.dims[key])
else: else:
return self._dims[key] return self.dims[key]
else:
return Dimension(None)
def dimension_value(dimension):
"""Return the value of specified dimension."""
if isinstance(dimension, Dimension):
return dimension.value
return dimension
def get_shape(self):
"""Construct the shape descriptor."""
return TensorShape(self.shape)
# The Monkey Patching.
EagerTensor.get_shape = get_shape
Tensor.get_shape = get_shape
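Because the rewritten `TensorShape` is a plain `tuple` subclass, indexing, slicing, and the `dims`/`rank` accessors reduce to tuple behavior. A minimal sketch, assuming the import path used by the framework exports above:

```python
# Minimal sketch of the tuple-backed TensorShape defined above (import path assumed).
from dragon.vm.tensorflow.core.framework.tensor_shape import TensorShape

shape = TensorShape([2, 3, 4])
print(shape.rank)            # 3
print(shape.ndims)           # 3 (kept for compatibility with the old API)
print(shape.as_list())       # [2, 3, 4]
print(repr(shape[1:]))       # TensorShape([3, 4]) -- slicing returns a new shape
print(shape)                 # (2, 3, 4)
```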
...@@ -17,16 +17,26 @@ from __future__ import absolute_import ...@@ -17,16 +17,26 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.framework import tensor_spec
from dragon.vm.tensorflow.core.framework import dtypes from dragon.vm.tensorflow.core.framework import dtypes
from dragon.vm.tensorflow.core.framework import tensor_shape from dragon.vm.tensorflow.core.framework import tensor_shape
class TensorSpec(tensor_spec.TensorSpec): class TensorSpec(object):
"""Spec to describe properties of a tensor.""" """Spec to describe properties of a tensor."""
def __init__(self, shape, dtype=dtypes.float32, name=None): def __init__(self, shape, dtype='float32', name=None):
"""Create a ``TensorSpec``.""" """Create a TensorSpec.
Parameters
----------
shape : Sequence[int], required
The dimensions.
dtype : str, optional, default='float32'
The optional data type.
name : str, optional
The optional name.
"""
self._shape = tensor_shape.TensorShape(shape) self._shape = tensor_shape.TensorShape(shape)
try: try:
self._shape_tuple = tuple(self._shape.as_list()) self._shape_tuple = tuple(self._shape.as_list())
...@@ -45,7 +55,7 @@ class TensorSpec(tensor_spec.TensorSpec): ...@@ -45,7 +55,7 @@ class TensorSpec(tensor_spec.TensorSpec):
The data type. The data type.
""" """
return self._dtype.name return str(self._dtype)
@property @property
def name(self): def name(self):
...@@ -70,3 +80,29 @@ class TensorSpec(tensor_spec.TensorSpec): ...@@ -70,3 +80,29 @@ class TensorSpec(tensor_spec.TensorSpec):
""" """
return self._shape.as_list() return self._shape.as_list()
def is_compatible_with(self, spec_or_tensor):
"""Return a bool whether given the spec is compatible.
Returns
-------
bool
**True** if compatible otherwise **False**.
"""
def dtype_is_compatible_with(spec_or_tensor):
return self.dtype == spec_or_tensor.dtype
def shape_is_compatible_with(spec_or_tensor):
shape = spec_or_tensor.shape
if self._shape is not None and shape is not None:
if len(self.shape) != len(shape):
return False
for x_dim, y_dim in zip(self.shape, shape):
if x_dim != y_dim:
return False
return True
return \
dtype_is_compatible_with(spec_or_tensor) and \
shape_is_compatible_with(spec_or_tensor)
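A hedged sketch of the compatibility check above. The module path is an assumption; note that with this element-wise comparison a `None` dimension only matches another `None`, it is not a wildcard.

```python
# Hedged sketch (module path assumed from the surrounding file).
from dragon.vm.tensorflow.core.framework.tensor_spec import TensorSpec

a = TensorSpec(shape=[4, 8], dtype='float32')
b = TensorSpec(shape=[4, 8], dtype='float32')
c = TensorSpec(shape=[4, 8], dtype='int64')
print(a.is_compatible_with(b))  # True: same dtype and dimensions
print(a.is_compatible_with(c))  # False: dtype differs
```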
...@@ -43,8 +43,8 @@ def Input( ...@@ -43,8 +43,8 @@ def Input(
x = tf.keras.Input(shape=(8,), batch_size=8, dtype='float32') x = tf.keras.Input(shape=(8,), batch_size=8, dtype='float32')
# Create a placeholder aliasing an existing tensor # Create a placeholder aliasing an existing tensor
x = dragon.Tensor('x', shape=(8,), dtype='float32').constant() x = dragon.Tensor(shape=(8,), dtype='float32').constant()
xx = tf.keras.Input(tensor=x) y = tf.keras.Input(tensor=x)
``` ```
Parameters Parameters
...@@ -69,38 +69,29 @@ def Input( ...@@ -69,38 +69,29 @@ def Input(
if 'batch_shape' in kwargs: if 'batch_shape' in kwargs:
batch_shape = kwargs.pop('batch_shape') batch_shape = kwargs.pop('batch_shape')
if shape and batch_shape: if shape and batch_shape:
raise ValueError( raise ValueError('Specify <shape> or <batch_shape>, not both.')
'Specify <shape> or '
'<batch_shape>, not both.'
)
shape = batch_shape shape = batch_shape
else: else:
if shape is not None: if shape is not None:
shape = (batch_size,) + tuple(shape) shape = (batch_size,) + tuple(shape)
if kwargs: if kwargs:
raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
if dtype is None: if dtype is None:
if tensor is not None: if tensor is not None:
dtype = tensor.dtype dtype = tensor.dtype
else: else:
dtype = 'float32' dtype = 'float32'
if shape is None: if shape is None:
if tensor is None: if tensor is None:
raise ValueError('Specify either <shape> or <tensor>.') raise ValueError('Specify either <shape> or <tensor>.')
else: else:
shape = tensor.shape shape = tensor.shape
if isinstance(shape, tensor_shape.TensorShape): if isinstance(shape, tensor_shape.TensorShape):
shape = tuple(shape.as_list()) shape = tuple(shape.as_list())
elif isinstance(shape, six.integer_types): elif isinstance(shape, six.integer_types):
shape = (shape,) shape = (shape,)
placeholder = array_ops.placeholder( placeholder = array_ops.placeholder(
dtype=dtype, shape=shape, name=name if name else 'input') dtype=dtype, shape=shape, name=name if name else 'input')
if tensor is not None: if tensor is not None:
workspace.get_workspace().register_alias(tensor, placeholder.id) workspace.get_workspace().register_alias(tensor, placeholder.id)
return placeholder return placeholder
...@@ -78,8 +78,7 @@ def assert_input_compatibility(input_spec, inputs, layer_name): ...@@ -78,8 +78,7 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
raise ValueError( raise ValueError(
'Layer ' + layer_name + ' expects ' + 'Layer ' + layer_name + ' expects ' +
str(len(input_spec)) + ' inputs, ' str(len(input_spec)) + ' inputs, '
'but it received ' + str(len(inputs)) + ' input tensors.' 'but it received ' + str(len(inputs)) + ' input tensors.')
)
# For each pair of input and spec. # For each pair of input and spec.
for input_index, (x, spec) in enumerate(zip(inputs, input_spec)): for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
if spec is None: if spec is None:
......
...@@ -72,13 +72,11 @@ class Conv(Layer): ...@@ -72,13 +72,11 @@ class Conv(Layer):
def build(self, input_shape): def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape) input_shape = tensor_shape.TensorShape(input_shape)
channel_axis = self._get_channel_axis() channel_axis = self._get_channel_axis()
if input_shape.dims[channel_axis].value is None: if input_shape.dims[channel_axis] is None:
raise ValueError( raise ValueError(
'The channel dimension of the input ' 'The channel dimension of the input '
'should be determined, got None.' 'should be determined, got None.')
)
input_dim = int(input_shape[channel_axis]) input_dim = int(input_shape[channel_axis])
# Assume that kernel is packed into NCHW format # Assume that kernel is packed into NCHW format
# for computing the fans correctly # for computing the fans correctly
if self.filters > 0: if self.filters > 0:
...@@ -86,7 +84,6 @@ class Conv(Layer): ...@@ -86,7 +84,6 @@ class Conv(Layer):
else: else:
self.filters = input_dim self.filters = input_dim
kernel_shape = (input_dim, 1) + self.kernel_size kernel_shape = (input_dim, 1) + self.kernel_size
self.kernel = self.add_weight( self.kernel = self.add_weight(
name='kernel', name='kernel',
shape=kernel_shape, shape=kernel_shape,
...@@ -106,7 +103,6 @@ class Conv(Layer): ...@@ -106,7 +103,6 @@ class Conv(Layer):
) )
else: else:
self.bias = None self.bias = None
self.built = True self.built = True
def call(self, inputs): def call(self, inputs):
...@@ -280,17 +276,15 @@ class Conv2DTranspose(Conv2D): ...@@ -280,17 +276,15 @@ class Conv2DTranspose(Conv2D):
def build(self, input_shape): def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape) input_shape = tensor_shape.TensorShape(input_shape)
channel_axis = self._get_channel_axis() channel_axis = self._get_channel_axis()
if input_shape.dims[channel_axis].value is None: if input_shape.dims[channel_axis] is None:
raise ValueError( raise ValueError(
'The channel dimension of the inputs ' 'The channel dimension of the inputs '
'should be determined, got None.' 'should be determined, got None.'
) )
input_dim = int(input_shape[channel_axis]) input_dim = int(input_shape[channel_axis])
# Assume that kernel is packed into NCHW format, # Assume that kernel is packed into NCHW format,
# for computing the fans correctly. # for computing the fans correctly.
kernel_shape = (input_dim, self.filters) + self.kernel_size kernel_shape = (input_dim, self.filters) + self.kernel_size
self.kernel = self.add_weight( self.kernel = self.add_weight(
name='kernel', name='kernel',
shape=kernel_shape, shape=kernel_shape,
...@@ -310,7 +304,6 @@ class Conv2DTranspose(Conv2D): ...@@ -310,7 +304,6 @@ class Conv2DTranspose(Conv2D):
) )
else: else:
self.bias = None self.bias = None
self.built = True self.built = True
def call(self, inputs): def call(self, inputs):
...@@ -320,11 +313,9 @@ class Conv2DTranspose(Conv2D): ...@@ -320,11 +313,9 @@ class Conv2DTranspose(Conv2D):
h_axis, w_axis = 2, 3 h_axis, w_axis = 2, 3
else: else:
h_axis, w_axis = 1, 2 h_axis, w_axis = 1, 2
height, width = inputs_shape[h_axis], inputs_shape[w_axis] height, width = inputs_shape[h_axis], inputs_shape[w_axis]
kernel_h, kernel_w = self.kernel_size kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides stride_h, stride_w = self.strides
if self.output_padding is None: if self.output_padding is None:
out_pad_h = out_pad_w = None out_pad_h = out_pad_w = None
else: else:
...@@ -349,7 +340,6 @@ class Conv2DTranspose(Conv2D): ...@@ -349,7 +340,6 @@ class Conv2DTranspose(Conv2D):
output_shape = (batch_size, self.filters, out_height, out_width) output_shape = (batch_size, self.filters, out_height, out_width)
else: else:
output_shape = (batch_size, out_height, out_width, self.filters) output_shape = (batch_size, out_height, out_width, self.filters)
outputs = nn_ops.conv_transpose( outputs = nn_ops.conv_transpose(
input=inputs, input=inputs,
filters=self.kernel, filters=self.kernel,
......
...@@ -83,20 +83,17 @@ class Dense(Layer): ...@@ -83,20 +83,17 @@ class Dense(Layer):
if not (dtype.is_floating or dtype.is_complex): if not (dtype.is_floating or dtype.is_complex):
raise TypeError( raise TypeError(
'Unable to build `Dense` layer with non-floating point ' 'Unable to build `Dense` layer with non-floating point '
'dtype %s' % (dtype,) 'dtype %s' % (dtype,))
)
if self.input_dim is None: if self.input_dim is None:
input_shape = tensor_shape.TensorShape(input_shape) input_shape = tensor_shape.TensorShape(input_shape)
if tensor_shape.dimension_value(input_shape[-1]) is None: if input_shape[-1] is None:
raise ValueError( raise ValueError(
'The last dimension of the inputs should be defined.\n' 'The last dimension of the inputs should be defined.\n'
'Or you should specify <input_dim> in the constructor.' 'Or you should specify <input_dim> in the constructor.')
) last_dim = input_shape[-1]
last_dim = tensor_shape.dimension_value(input_shape[-1])
else: else:
last_dim = self.input_dim last_dim = self.input_dim
self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim}) self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})
self.kernel = self.add_weight( self.kernel = self.add_weight(
'kernel', 'kernel',
shape=[last_dim, self.units], shape=[last_dim, self.units],
......
...@@ -97,7 +97,7 @@ class BatchNormalization(Layer): ...@@ -97,7 +97,7 @@ class BatchNormalization(Layer):
input_shape = tensor_shape.TensorShape(input_shape) input_shape = tensor_shape.TensorShape(input_shape)
if not input_shape.ndims: if not input_shape.ndims:
raise ValueError('Input has undefined rank:', input_shape) raise ValueError('Input has undefined rank:', input_shape)
param_shape = [input_shape.dims[self.axis].value] param_shape = [input_shape.dims[self.axis]]
self.input_spec = InputSpec( self.input_spec = InputSpec(
# Each layer should adapt to the: # Each layer should adapt to the:
# 1) The number of dimensions. # 1) The number of dimensions.
......
...@@ -217,12 +217,7 @@ def fill(dims, value=0, dtype=None, name=None): ...@@ -217,12 +217,7 @@ def fill(dims, value=0, dtype=None, name=None):
dtype = 'int32' dtype = 'int32'
elif dtype == numpy.float64: elif dtype == numpy.float64:
dtype = 'float32' dtype = 'float32'
return init_ops.fill( return init_ops.fill(shape=dims, value=value, dtype=dtype, name=name)
shape=dims,
value=value,
dtype=str(dtype),
name=name,
)
def gather(params, indices, axis=0, name=None): def gather(params, indices, axis=0, name=None):
...@@ -315,7 +310,7 @@ def ones(shape, dtype='float32', name=None): ...@@ -315,7 +310,7 @@ def ones(shape, dtype='float32', name=None):
An optional name for the operation. An optional name for the operation.
""" """
return init_ops.fill(shape, value=1, dtype=str(dtype), name=name) return init_ops.fill(shape, value=1, dtype=dtype, name=name)
def ones_like(input, dtype='float32', name=None): def ones_like(input, dtype='float32', name=None):
...@@ -340,7 +335,7 @@ def ones_like(input, dtype='float32', name=None): ...@@ -340,7 +335,7 @@ def ones_like(input, dtype='float32', name=None):
An optional name for the operation. An optional name for the operation.
""" """
return init_ops.ones_like(input, dtype=str(dtype), name=name) return init_ops.ones_like(input, dtype=dtype, name=name)
def one_hot( def one_hot(
...@@ -488,13 +483,13 @@ def placeholder(dtype=None, shape=None, name=None): ...@@ -488,13 +483,13 @@ def placeholder(dtype=None, shape=None, name=None):
workspace.get_workspace().unique_name( workspace.get_workspace().unique_name(
context.get_name_scope() + name if name else 'Placeholder', context.get_name_scope() + name if name else 'Placeholder',
suffix=':0', namespace='Tensor'), suffix=':0', namespace='Tensor'),
dtype=str(dtype) if dtype else dtype, dtype=dtype,
shape=shape, shape=shape,
).constant() ).constant()
def reshape(tensor, shape, name=None): def reshape(tensor, shape, name=None):
r"""Change the dimensions of input. """Change the dimensions of input.
Examples: Examples:
...@@ -527,7 +522,7 @@ def reshape(tensor, shape, name=None): ...@@ -527,7 +522,7 @@ def reshape(tensor, shape, name=None):
def shape(input, name=None): def shape(input, name=None):
r"""Return the shape of input. """Return the shape of input.
Examples: Examples:
...@@ -641,7 +636,7 @@ def split( ...@@ -641,7 +636,7 @@ def split(
axis=0, axis=0,
name=None, name=None,
): ):
r"""Split input into chunks along the given axis. """Split input into chunks along the given axis.
Either number or size of splits will be accepted: Either number or size of splits will be accepted:
...@@ -775,7 +770,7 @@ def zeros(shape, dtype='float32', name=None): ...@@ -775,7 +770,7 @@ def zeros(shape, dtype='float32', name=None):
An optional name for the operation. An optional name for the operation.
""" """
return init_ops.fill(shape, value=0., dtype=str(dtype), name=name) return init_ops.fill(shape, value=0., dtype=dtype, name=name)
def zeros_like(input, dtype='float32', name=None): def zeros_like(input, dtype='float32', name=None):
...@@ -800,4 +795,4 @@ def zeros_like(input, dtype='float32', name=None): ...@@ -800,4 +795,4 @@ def zeros_like(input, dtype='float32', name=None):
An optional name for the operation. An optional name for the operation.
""" """
return init_ops.zeros_like(input, dtype=str(dtype), name=name) return init_ops.zeros_like(input, dtype=dtype, name=name)
...@@ -79,7 +79,7 @@ class Constant(Initializer): ...@@ -79,7 +79,7 @@ class Constant(Initializer):
The output tensor. The output tensor.
""" """
dtype = str(self.dtype) if dtype is None else str(dtype) dtype = str(self.dtype) if dtype is None else dtype
return init_ops.fill(shape, value=self.value, dtype=dtype) return init_ops.fill(shape, value=self.value, dtype=dtype)
...@@ -125,7 +125,7 @@ class RandomNormal(Initializer): ...@@ -125,7 +125,7 @@ class RandomNormal(Initializer):
shape=shape, shape=shape,
mean=self.mean, mean=self.mean,
std=self.stddev, std=self.stddev,
dtype=str(self.dtype) if dtype is None else str(dtype), dtype=str(self.dtype) if dtype is None else dtype,
) )
...@@ -167,12 +167,12 @@ class RandomUniform(Initializer): ...@@ -167,12 +167,12 @@ class RandomUniform(Initializer):
The output tensor. The output tensor.
""" """
dtype = str(self.dtype) if dtype is None else str(dtype) dtype = str(self.dtype) if dtype is None else dtype
return init_ops.random_uniform( return init_ops.random_uniform(
shape=shape, shape=shape,
low=self.minval, low=self.minval,
high=self.maxval, high=self.maxval,
dtype=str(self.dtype) if dtype is None else str(dtype), dtype=str(self.dtype) if dtype is None else dtype,
) )
...@@ -218,7 +218,7 @@ class TruncatedNormal(Initializer): ...@@ -218,7 +218,7 @@ class TruncatedNormal(Initializer):
shape=shape, shape=shape,
mean=self.mean, mean=self.mean,
std=self.stddev, std=self.stddev,
dtype=str(self.dtype) if dtype is None else str(dtype), dtype=str(self.dtype) if dtype is None else dtype,
) )
...@@ -280,14 +280,14 @@ class VarianceScaling(Initializer): ...@@ -280,14 +280,14 @@ class VarianceScaling(Initializer):
shape=shape, shape=shape,
mode=self.mode, mode=self.mode,
scale=self.scale * 2.0, scale=self.scale * 2.0,
dtype=str(self.dtype) if dtype is None else str(dtype) dtype=str(self.dtype) if dtype is None else dtype
) )
else: else:
return init_ops.glorot_uniform( return init_ops.glorot_uniform(
shape=shape, shape=shape,
mode=self.mode, mode=self.mode,
scale=self.scale * 3.0, scale=self.scale * 3.0,
dtype=str(self.dtype) if dtype is None else str(dtype) dtype=str(self.dtype) if dtype is None else dtype
) )
...@@ -374,8 +374,8 @@ class Ones(Initializer): ...@@ -374,8 +374,8 @@ class Ones(Initializer):
The output tensor. The output tensor.
""" """
dtype = str(self.dtype) if dtype is None else str(dtype) dtype = str(self.dtype) if dtype is None else dtype
return init_ops.fill(shape, value=1, dtype=str(dtype)) return init_ops.fill(shape, value=1, dtype=dtype)
class Zeros(Initializer): class Zeros(Initializer):
...@@ -412,7 +412,7 @@ class Zeros(Initializer): ...@@ -412,7 +412,7 @@ class Zeros(Initializer):
The output tensor. The output tensor.
""" """
dtype = str(self.dtype) if dtype is None else str(dtype) dtype = str(self.dtype) if dtype is None else dtype
return init_ops.fill(shape, value=0, dtype=dtype) return init_ops.fill(shape, value=0, dtype=dtype)
......
...@@ -48,4 +48,4 @@ def eye(num_rows, num_columns=None, dtype='float32', name=None): ...@@ -48,4 +48,4 @@ def eye(num_rows, num_columns=None, dtype='float32', name=None):
The output tensor. The output tensor.
""" """
return init_ops.eye(num_rows, num_columns, dtype=str(dtype), name=name) return init_ops.eye(num_rows, num_columns, dtype=dtype, name=name)
...@@ -209,7 +209,7 @@ def cast(x, dtype, name=None): ...@@ -209,7 +209,7 @@ def cast(x, dtype, name=None):
The output tensor. The output tensor.
""" """
return array_ops.cast(x, dtype=str(dtype), name=name) return array_ops.cast(x, dtype=dtype, name=name)
def ceil(x, name=None): def ceil(x, name=None):
...@@ -890,7 +890,7 @@ def range(start, limit=None, delta=1, dtype='int64', name=None): ...@@ -890,7 +890,7 @@ def range(start, limit=None, delta=1, dtype='int64', name=None):
start=start, start=start,
stop=limit, stop=limit,
step=delta, step=delta,
dtype=str(dtype), dtype=dtype,
name=name, name=name,
) )
......
...@@ -56,12 +56,11 @@ def avg_pool( ...@@ -56,12 +56,11 @@ def avg_pool(
The output tensor. The output tensor.
""" """
num_total_dims = input.get_shape().ndims if input.shape is not None:
if num_total_dims is None: num_total_dims = len(input.shape)
else:
num_total_dims = len(ksize) num_total_dims = len(ksize)
num_spatial_dims = num_total_dims - 2 num_spatial_dims = num_total_dims - 2
# Make default parameters
data_format = data_format if data_format else 'NHWC' data_format = data_format if data_format else 'NHWC'
start_axis = 2 if data_format.startswith('NC') else 1 start_axis = 2 if data_format.startswith('NC') else 1
normalize_spatial_args = \ normalize_spatial_args = \
...@@ -74,7 +73,6 @@ def avg_pool( ...@@ -74,7 +73,6 @@ def avg_pool(
ksize = normalize_spatial_args('ksize', ksize) ksize = normalize_spatial_args('ksize', ksize)
strides = normalize_spatial_args('strides', strides) strides = normalize_spatial_args('strides', strides)
padding, pads = normalize_spatial_args('padding', padding) padding, pads = normalize_spatial_args('padding', padding)
return getattr(vision_ops, 'pool{}d'.format(num_spatial_dims))( return getattr(vision_ops, 'pool{}d'.format(num_spatial_dims))(
[input], [input],
kernel_shape=ksize.shape[start_axis:start_axis + num_spatial_dims], kernel_shape=ksize.shape[start_axis:start_axis + num_spatial_dims],
...@@ -173,12 +171,11 @@ def convolution( ...@@ -173,12 +171,11 @@ def convolution(
The output tensor. The output tensor.
""" """
num_total_dims = filters.get_shape().ndims if filters.shape is not None:
if num_total_dims is None: num_total_dims = len(filters.shape)
raise ValueError('Rank of `filters` must be determined.') else:
raise ValueError('Rank of <filters> must be determined.')
num_spatial_dims = num_total_dims - 2 num_spatial_dims = num_total_dims - 2
# Make default parameters
data_format = data_format if data_format else 'NHWC' data_format = data_format if data_format else 'NHWC'
start_axis = 2 if data_format.startswith('NC') else 1 start_axis = 2 if data_format.startswith('NC') else 1
normalize_spatial_args = \ normalize_spatial_args = \
...@@ -191,7 +188,6 @@ def convolution( ...@@ -191,7 +188,6 @@ def convolution(
strides = normalize_spatial_args('strides', strides) strides = normalize_spatial_args('strides', strides)
dilations = normalize_spatial_args('dilations', dilations) dilations = normalize_spatial_args('dilations', dilations)
padding, pads = normalize_spatial_args('padding', padding) padding, pads = normalize_spatial_args('padding', padding)
return getattr(vision_ops, '{}{}d'.format( return getattr(vision_ops, '{}{}d'.format(
kwargs.get('conv_type', 'conv'), num_spatial_dims))( kwargs.get('conv_type', 'conv'), num_spatial_dims))(
[input, filters], [input, filters],
...@@ -241,14 +237,13 @@ def conv_transpose( ...@@ -241,14 +237,13 @@ def conv_transpose(
The output tensor. The output tensor.
""" """
num_total_dims = filters.get_shape().ndims if filters.shape is not None:
if num_total_dims is None: num_total_dims = len(filters.shape)
num_total_dims = input.get_shape().ndims elif input.shape is not None:
if num_total_dims is None: num_total_dims = len(input.shape)
raise ValueError("rank of input or filters must be known.") else:
raise ValueError('Rank of <input> or <filters> must be known.')
num_spatial_dims = num_total_dims - 2 num_spatial_dims = num_total_dims - 2
# Make default parameters
data_format = data_format if data_format else 'NHWC' data_format = data_format if data_format else 'NHWC'
start_axis = 2 if data_format.startswith('NC') else 1 start_axis = 2 if data_format.startswith('NC') else 1
normalize_spatial_args = \ normalize_spatial_args = \
...@@ -264,7 +259,6 @@ def conv_transpose( ...@@ -264,7 +259,6 @@ def conv_transpose(
if padding == 'SAME' and output_shape is None: if padding == 'SAME' and output_shape is None:
raise ValueError('Excepted <output_shape> for same padding.') raise ValueError('Excepted <output_shape> for same padding.')
output_shape = normalize_spatial_args('output_shape', output_shape) output_shape = normalize_spatial_args('output_shape', output_shape)
return getattr(vision_ops, 'conv{}d_transpose'.format(num_spatial_dims))( return getattr(vision_ops, 'conv{}d_transpose'.format(num_spatial_dims))(
[input, filters], [input, filters],
kernel_shape=filters.shape[2:], kernel_shape=filters.shape[2:],
...@@ -608,12 +602,11 @@ def max_pool( ...@@ -608,12 +602,11 @@ def max_pool(
The output tensor. The output tensor.
""" """
num_total_dims = input.get_shape().ndims if input.shape is not None:
if num_total_dims is None: num_total_dims = len(input.shape)
else:
num_total_dims = len(ksize) num_total_dims = len(ksize)
num_spatial_dims = num_total_dims - 2 num_spatial_dims = num_total_dims - 2
# Make default parameters
data_format = data_format if data_format else 'NHWC' data_format = data_format if data_format else 'NHWC'
start_axis = 2 if data_format.startswith('NC') else 1 start_axis = 2 if data_format.startswith('NC') else 1
normalize_spatial_args = \ normalize_spatial_args = \
...@@ -626,7 +619,6 @@ def max_pool( ...@@ -626,7 +619,6 @@ def max_pool(
ksize = normalize_spatial_args('ksize', ksize) ksize = normalize_spatial_args('ksize', ksize)
strides = normalize_spatial_args('strides', strides) strides = normalize_spatial_args('strides', strides)
padding, pads = normalize_spatial_args('padding', padding) padding, pads = normalize_spatial_args('padding', padding)
return getattr(vision_ops, 'pool{}d'.format(num_spatial_dims))( return getattr(vision_ops, 'pool{}d'.format(num_spatial_dims))(
[input], [input],
kernel_shape=ksize[start_axis:start_axis + num_spatial_dims], kernel_shape=ksize[start_axis:start_axis + num_spatial_dims],
......
...@@ -50,7 +50,7 @@ def random_normal( ...@@ -50,7 +50,7 @@ def random_normal(
The output tensor. The output tensor.
""" """
_, dtype, init_fn = seed, str(dtype), init_ops.random_normal _, dtype, init_fn = seed, dtype, init_ops.random_normal
return init_fn(shape, mean, stddev, dtype=dtype, name=name) return init_fn(shape, mean, stddev, dtype=dtype, name=name)
...@@ -87,7 +87,7 @@ def random_uniform( ...@@ -87,7 +87,7 @@ def random_uniform(
The output tensor. The output tensor.
""" """
_, dtype, init_fn = seed, str(dtype), init_ops.random_uniform _, dtype, init_fn = seed, dtype, init_ops.random_uniform
return init_fn(shape, minval, maxval, dtype=dtype, name=name) return init_fn(shape, minval, maxval, dtype=dtype, name=name)
...@@ -125,5 +125,5 @@ def truncated_normal( ...@@ -125,5 +125,5 @@ def truncated_normal(
The output tensor. The output tensor.
""" """
_, dtype, init_fn = seed, str(dtype), init_ops.truncated_normal _, dtype, init_fn = seed, dtype, init_ops.truncated_normal
return init_fn(shape, mean, stddev, dtype=dtype, name=name) return init_fn(shape, mean, stddev, dtype=dtype, name=name)
...@@ -43,27 +43,21 @@ class Variable(VariableMetaclass, EagerTensor): ...@@ -43,27 +43,21 @@ class Variable(VariableMetaclass, EagerTensor):
): ):
"""Create a ``Variable``.""" """Create a ``Variable``."""
super(Variable, self).__init__(trainable=trainable) super(Variable, self).__init__(trainable=trainable)
name = name if name else 'Variable' name = name if name else 'Variable'
self._name = context.get_name_scope() + name + ':0' self._name = context.get_name_scope() + name + ':0'
# Determine the value. # Determine the value.
if isinstance(initial_value, EagerTensor): if isinstance(initial_value, EagerTensor):
initial_value = initial_value.numpy() initial_value = initial_value.numpy()
elif isinstance(initial_value, Tensor): elif isinstance(initial_value, Tensor):
initial_value = initial_value.get_value() initial_value = initial_value.get_value()
# Determine the data type. # Determine the data type.
if not isinstance(initial_value, numpy.ndarray): if not isinstance(initial_value, numpy.ndarray):
initial_value = numpy.array( initial_value = numpy.array(initial_value, dtype)
initial_value, str(dtype) if dtype else dtype)
elif dtype is not None: elif dtype is not None:
initial_value = initial_value.astype(str(dtype)) initial_value = initial_value.astype(dtype)
# Determine the tensor shape. # Determine the tensor shape.
if shape is not None: if shape is not None:
initial_value = initial_value.reshape(shape) initial_value = initial_value.reshape(shape)
self._from_numpy(initial_value, copy=False) self._from_numpy(initial_value, copy=False)
@property @property
...@@ -96,7 +90,6 @@ def get_default_initializer(name, shape=None, dtype=dtypes.float32): ...@@ -96,7 +90,6 @@ def get_default_initializer(name, shape=None, dtype=dtypes.float32):
# Defaults: float32. # Defaults: float32.
if dtype is None: if dtype is None:
dtype = dtypes.float32 dtype = dtypes.float32
# Xavier for float16, float32, float64. # Xavier for float16, float32, float64.
if dtype.is_floating: if dtype.is_floating:
initializer = init_ops.glorot_uniform_initializer() initializer = init_ops.glorot_uniform_initializer()
......
...@@ -67,20 +67,20 @@ class TestTensor(unittest.TestCase): ...@@ -67,20 +67,20 @@ class TestTensor(unittest.TestCase):
self.assertEqual(dragon.Tensor().ndim, 0) self.assertEqual(dragon.Tensor().ndim, 0)
self.assertEqual(dragon.Tensor(shape=(2,)).ndim, 1) self.assertEqual(dragon.Tensor(shape=(2,)).ndim, 1)
self.assertEqual(dragon.Tensor().shape, None) self.assertEqual(dragon.Tensor().shape, None)
self.assertEqual(dragon.Tensor(shape=(2,)).shape, [2]) self.assertEqual(dragon.Tensor(shape=(2,)).shape, (2,))
self.assertEqual(dragon.Tensor().size, 0) self.assertEqual(dragon.Tensor().size, 0)
self.assertEqual(dragon.Tensor(shape=(2, None)).size, math.inf) self.assertEqual(dragon.Tensor(shape=(2, None)).size, math.inf)
self.assertEqual(dragon.Tensor(shape=(2,)).size, 2) self.assertEqual(dragon.Tensor(shape=(2,)).size, 2)
self.assertEqual(dragon.Tensor().dtype, None) self.assertEqual(dragon.Tensor().dtype, None)
self.assertEqual(dragon.Tensor(dtype='float32').dtype, 'float32') self.assertEqual(dragon.Tensor(dtype='float32').dtype, 'float32')
self.assertEqual(dragon.EagerTensor(shape=(2,)).ndim, 1) self.assertEqual(dragon.EagerTensor(shape=(2,)).ndim, 1)
self.assertEqual(dragon.EagerTensor(shape=(2,)).shape, [2]) self.assertEqual(dragon.EagerTensor(shape=(2,)).shape, (2,))
self.assertEqual(dragon.EagerTensor(shape=(2,)).size, 2) self.assertEqual(dragon.EagerTensor(shape=(2,)).size, 2)
self.assertEqual(dragon.EagerTensor(shape=(2,), dtype='float32').dtype, 'float32') self.assertEqual(dragon.EagerTensor(shape=(2,), dtype='float32').dtype, 'float32')
self.assertEqual(dragon.EagerTensor().device, dragon.EagerTensor().device) self.assertEqual(dragon.EagerTensor().device, dragon.EagerTensor().device)
self.assertNotEqual(a.__hash__(), b.__hash__()) self.assertNotEqual(a.__hash__(), b.__hash__())
self.assertNotEqual(a.__repr__(), b.__repr__()) self.assertNotEqual(a.__repr__(), b.__repr__())
self.assertNotEqual(b.__repr__(), dragon.EagerTensor([2]).__repr__()) self.assertNotEqual(b.__repr__(), dragon.EagerTensor((2,)).__repr__())
self.assertEqual(int(a.constant().set_value(1)), 1) self.assertEqual(int(a.constant().set_value(1)), 1)
self.assertEqual(float(dragon.Tensor.convert_to(1)), 1.) self.assertEqual(float(dragon.Tensor.convert_to(1)), 1.)
self.assertEqual(int(b.set_value(1)), 1) self.assertEqual(int(b.set_value(1)), 1)
...@@ -117,7 +117,7 @@ class TestTensor(unittest.TestCase): ...@@ -117,7 +117,7 @@ class TestTensor(unittest.TestCase):
x = dragon.EagerTensor(data, copy=True) x = dragon.EagerTensor(data, copy=True)
x_to_dlpack = dragon.dlpack.to_dlpack(x) x_to_dlpack = dragon.dlpack.to_dlpack(x)
x_from_dlpack = dragon.dlpack.from_dlpack(x_to_dlpack) x_from_dlpack = dragon.dlpack.from_dlpack(x_to_dlpack)
self.assertEqual(x_from_dlpack.shape, list(data.shape)) self.assertEqual(x_from_dlpack.shape, data.shape)
self.assertEqual(x_from_dlpack.dtype, str(data.dtype)) self.assertEqual(x_from_dlpack.dtype, str(data.dtype))
self.assertLessEqual(np.abs(x_from_dlpack.numpy() - data).max(), 1e-5) self.assertLessEqual(np.abs(x_from_dlpack.numpy() - data).max(), 1e-5)
...@@ -130,7 +130,7 @@ class TestTensor(unittest.TestCase): ...@@ -130,7 +130,7 @@ class TestTensor(unittest.TestCase):
x_from_dlpack = dragon.dlpack.from_dlpack(x_to_dlpack) x_from_dlpack = dragon.dlpack.from_dlpack(x_to_dlpack)
self.assertEqual(x_from_dlpack.device.type, 'cuda') self.assertEqual(x_from_dlpack.device.type, 'cuda')
self.assertEqual(x_from_dlpack.device.index, 0) self.assertEqual(x_from_dlpack.device.index, 0)
self.assertEqual(x_from_dlpack.shape, list(data.shape)) self.assertEqual(x_from_dlpack.shape, data.shape)
self.assertEqual(x_from_dlpack.dtype, str(data.dtype)) self.assertEqual(x_from_dlpack.dtype, str(data.dtype))
self.assertLessEqual(np.abs(x_from_dlpack.numpy() - data).max(), 1e-5) self.assertLessEqual(np.abs(x_from_dlpack.numpy() - data).max(), 1e-5)
...@@ -150,7 +150,7 @@ class TestWorkspace(unittest.TestCase): ...@@ -150,7 +150,7 @@ class TestWorkspace(unittest.TestCase):
w = dragon.Workspace() w = dragon.Workspace()
with w.as_default(): with w.as_default():
v1, v2 = dragon.EagerTensor(1), np.array(2) v1, v2 = dragon.EagerTensor(1), np.array(2)
x = dragon.Tensor('test_feed_tensor/x') x = dragon.Tensor(name='test_feed_tensor/x')
w.feed_tensor(x, v1) w.feed_tensor(x, v1)
self.assertEqual(int(x), 1) self.assertEqual(int(x), 1)
w.feed_tensor(x, v2) w.feed_tensor(x, v2)
...@@ -159,7 +159,7 @@ class TestWorkspace(unittest.TestCase): ...@@ -159,7 +159,7 @@ class TestWorkspace(unittest.TestCase):
def test_merge_form(self): def test_merge_form(self):
w1, w2 = dragon.Workspace(), dragon.Workspace() w1, w2 = dragon.Workspace(), dragon.Workspace()
with w1.as_default(): with w1.as_default():
x = dragon.Tensor('test_merge_from/x').set_value(0) x = dragon.Tensor(name='test_merge_from/x').set_value(0)
w2.merge_from(w1) w2.merge_from(w1)
with w2.as_default(): with w2.as_default():
self.assertEqual(int(x), 0) self.assertEqual(int(x), 0)
......
...@@ -67,7 +67,7 @@ class OpTestCase(unittest.TestCase): ...@@ -67,7 +67,7 @@ class OpTestCase(unittest.TestCase):
dtype = symbols[i][1].dtype dtype = symbols[i][1].dtype
shape = symbols[i][1].shape shape = symbols[i][1].shape
super(OpTestCase, self).assertEqual(dtype, str(values[i].dtype)) super(OpTestCase, self).assertEqual(dtype, str(values[i].dtype))
super(OpTestCase, self).assertEqual(shape, list(shape)) super(OpTestCase, self).assertEqual(shape, values[i].shape)
inputs[symbols[i][0]] = values[i] inputs[symbols[i][0]] = values[i]
first = inputs[:num_first] if num_first > 1 else inputs[0] first = inputs[:num_first] if num_first > 1 else inputs[0]
second = inputs[num_first:len(inputs)] if num_second > 1 else inputs[num_first] second = inputs[num_first:len(inputs)] if num_second > 1 else inputs[num_first]
...@@ -239,7 +239,7 @@ class TestActivationOps(OpTestCase): ...@@ -239,7 +239,7 @@ class TestActivationOps(OpTestCase):
for x_shape, w_shape, data_format in entries: for x_shape, w_shape, data_format in entries:
data1 = uniform(x_shape) data1 = uniform(x_shape)
data2 = np.ones(w_shape, 'float32') * 0.25 data2 = np.ones(w_shape, 'float32') * 0.25
x, w = new_tensor(data1), new_tensor(data2) x, w = new_tensor(data1), new_tensor(data2.flatten())
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch([x, w]) tape.watch([x, w])
y = dragon.nn.prelu([x, w], data_format=data_format) y = dragon.nn.prelu([x, w], data_format=data_format)
...@@ -632,7 +632,7 @@ class TestArrayOps(OpTestCase): ...@@ -632,7 +632,7 @@ class TestArrayOps(OpTestCase):
tape.watch(x) tape.watch(x)
y = dragon.masked_select([x, x > 2]) y = dragon.masked_select([x, x > 2])
dx = tape.gradient(y, [x], output_gradients=[y])[0] dx = tape.gradient(y, [x], output_gradients=[y])[0]
self.assertEqual([y, dx], [data[data > 2], grad]) self.assertEqual([y, dx], [data[data > 2], grad], test_symbols=False)
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable') @unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_masked_select_cuda(self): def test_masked_select_cuda(self):
...@@ -645,7 +645,8 @@ class TestArrayOps(OpTestCase): ...@@ -645,7 +645,8 @@ class TestArrayOps(OpTestCase):
data = arange((2, 3)) data = arange((2, 3))
x = new_tensor(data) x = new_tensor(data)
y = dragon.nonzero(x > 2) y = dragon.nonzero(x > 2)
self.assertEqual(y, np.stack(np.nonzero(data > 2), axis=1)) self.assertEqual(
y, np.stack(np.nonzero(data > 2), axis=1), test_symbols=False)
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable') @unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_non_zero_cuda(self): def test_non_zero_cuda(self):
...@@ -699,7 +700,7 @@ class TestArrayOps(OpTestCase): ...@@ -699,7 +700,7 @@ class TestArrayOps(OpTestCase):
tape.watch(x) tape.watch(x)
y = dragon.repeat(x, axis, repeats) y = dragon.repeat(x, axis, repeats)
grad = arange(y.shape) grad = arange(y.shape)
grad_shape = y.shape[:-1] + [y.shape[-1] // 2, 2] grad_shape = y.shape[:-1] + (y.shape[-1] // 2, 2)
dy = new_tensor(grad) dy = new_tensor(grad)
dx = tape.gradient(y, [x], output_gradients=[dy])[0] dx = tape.gradient(y, [x], output_gradients=[dy])[0]
self.assertEqual( self.assertEqual(
...@@ -2271,8 +2272,8 @@ class TestNormalizationOps(OpTestCase): ...@@ -2271,8 +2272,8 @@ class TestNormalizationOps(OpTestCase):
data4, data5 = arange(w_shape) * .1, arange(w_shape, 1) * .1 data4, data5 = arange(w_shape) * .1, arange(w_shape, 1) * .1
data6 = uniform(x_shape) data6 = uniform(x_shape)
x = new_tensor(data1) x = new_tensor(data1)
w, b = new_tensor(data2), new_tensor(data3) w, b = new_tensor(data2.flatten()), new_tensor(data3.flatten())
rm, rv = new_tensor(data4), new_tensor(data5) rm, rv = new_tensor(data4.flatten()), new_tensor(data5.flatten())
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch([x, w, b]) tape.watch([x, w, b])
y = dragon.nn.batch_norm( y = dragon.nn.batch_norm(
...@@ -2330,7 +2331,7 @@ class TestNormalizationOps(OpTestCase): ...@@ -2330,7 +2331,7 @@ class TestNormalizationOps(OpTestCase):
data2, data3 = arange(w_shape, 1) * .1, arange(w_shape) * .1 data2, data3 = arange(w_shape, 1) * .1, arange(w_shape) * .1
data6 = arange(x_shape) * .1 data6 = arange(x_shape) * .1
x = new_tensor(data1) x = new_tensor(data1)
w, b = new_tensor(data2), new_tensor(data3) w, b = new_tensor(data2.flatten()), new_tensor(data3.flatten())
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch([x, w, b]) tape.watch([x, w, b])
y = dragon.nn.group_norm( y = dragon.nn.group_norm(
...@@ -2374,7 +2375,7 @@ class TestNormalizationOps(OpTestCase): ...@@ -2374,7 +2375,7 @@ class TestNormalizationOps(OpTestCase):
data2, data3 = arange(w_shape, 1) * .1, arange(w_shape) * .1 data2, data3 = arange(w_shape, 1) * .1, arange(w_shape) * .1
data6 = arange(x_shape) * 10. data6 = arange(x_shape) * 10.
x = new_tensor(data1) x = new_tensor(data1)
w, b = new_tensor(data2), new_tensor(data3) w, b = new_tensor(data2.flatten()), new_tensor(data3.flatten())
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch([x, w, b]) tape.watch([x, w, b])
y = dragon.nn.instance_norm([x, w, b], axis=axis, eps=eps) y = dragon.nn.instance_norm([x, w, b], axis=axis, eps=eps)
...@@ -2417,7 +2418,7 @@ class TestNormalizationOps(OpTestCase): ...@@ -2417,7 +2418,7 @@ class TestNormalizationOps(OpTestCase):
data2, data3 = arange(w_shape, 1) * .1, arange(w_shape) * .1 data2, data3 = arange(w_shape, 1) * .1, arange(w_shape) * .1
data6 = arange(x_shape) * 10. data6 = arange(x_shape) * 10.
x = new_tensor(data1) x = new_tensor(data1)
w, b = new_tensor(data2), new_tensor(data3) w, b = new_tensor(data2.flatten()), new_tensor(data3.flatten())
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch([x, w, b]) tape.watch([x, w, b])
y = dragon.nn.layer_norm([x, w, b], axis=axis, eps=eps) y = dragon.nn.layer_norm([x, w, b], axis=axis, eps=eps)
...@@ -2587,7 +2588,7 @@ class TestTensorOps(OpTestCase): ...@@ -2587,7 +2588,7 @@ class TestTensorOps(OpTestCase):
grad[data > 2] = 1 grad[data > 2] = 1
grad *= data grad *= data
x = new_tensor(data) x = new_tensor(data)
self.assertEqual(x[x > 2], data[data > 2]) self.assertEqual(x[x > 2], data[data > 2], test_symbols=False)
entries = [0, entries = [0,
slice(None, None, None), slice(None, None, None),
slice(0, None, None), slice(0, None, None),
...@@ -2885,13 +2886,13 @@ class TestVisionOps(OpTestCase): ...@@ -2885,13 +2886,13 @@ class TestVisionOps(OpTestCase):
with execution_context().mode(execution): with execution_context().mode(execution):
for x_shape, b_shape, data_format in entries: for x_shape, b_shape, data_format in entries:
data1, data2 = arange(x_shape), arange(b_shape) data1, data2 = arange(x_shape), arange(b_shape)
x, w = new_tensor(data1), new_tensor(data2) x, b = new_tensor(data1), new_tensor(data2.flatten())
with dragon.GradientTape() as tape: with dragon.GradientTape() as tape:
tape.watch([x, w]) tape.watch([x, b])
y = dragon.nn.bias_add([x, w], data_format) y = dragon.nn.bias_add([x, b], data_format)
dx, dw = tape.gradient(y, [x, w], output_gradients=[x]) dx, db = tape.gradient(y, [x, b], output_gradients=[x])
self.assertEqual( self.assertEqual(
[y, dx, dw], [y, dx, db],
[data1 + data2, data1, reduce_like(data1, data2).flatten()]) [data1 + data2, data1, reduce_like(data1, data2).flatten()])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable') @unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
...@@ -3422,7 +3423,7 @@ def new_tensor(data): ...@@ -3422,7 +3423,7 @@ def new_tensor(data):
"""Create a new tensor for current execution.""" """Create a new tensor for current execution."""
if execution_context().executing_eagerly(): if execution_context().executing_eagerly():
return dragon.EagerTensor(data, copy=True) return dragon.EagerTensor(data, copy=True)
return dragon.Tensor(None, data.shape, str(data.dtype)).set_value(data) return dragon.Tensor(data.shape, str(data.dtype)).set_value(data)
def process_indices(item): def process_indices(item):
......
...@@ -24,6 +24,7 @@ from dragon.vm.torch import vision ...@@ -24,6 +24,7 @@ from dragon.vm.torch import vision
# Classes # Classes
from dragon.vm.torch.autograd import Variable from dragon.vm.torch.autograd import Variable
from dragon.vm.torch.cpp import device from dragon.vm.torch.cpp import device
from dragon.vm.torch.cpp import dtype
from dragon.vm.torch.cpp import Size from dragon.vm.torch.cpp import Size
from dragon.vm.torch.tensor import ByteTensor from dragon.vm.torch.tensor import ByteTensor
from dragon.vm.torch.tensor import CharTensor from dragon.vm.torch.tensor import CharTensor
...@@ -116,3 +117,21 @@ from dragon.vm.torch.serialization import load ...@@ -116,3 +117,21 @@ from dragon.vm.torch.serialization import load
from dragon.vm.torch.serialization import save from dragon.vm.torch.serialization import save
from dragon.vm.torch.tensor import empty from dragon.vm.torch.tensor import empty
from dragon.vm.torch.tensor import tensor from dragon.vm.torch.tensor import tensor
# Aliases
bool = dtype('bool')
int8 = dtype('int8')
uint8 = dtype('uint8')
int16 = short = dtype('int16')
int32 = int = dtype('int32')
int64 = long = dtype('int64')
qint8 = dtype('qint8')
quint8 = dtype('quint8')
qint32 = dtype('qint32')
bfloat16 = dtype('bfloat16')
float16 = half = dtype('float16')
float32 = float = dtype('float32')
float64 = double = dtype('float64')
complex32 = dtype('complex32')
complex64 = dtype('complex64')
complex128 = dtype('complex128')
...@@ -58,7 +58,7 @@ class Size(tuple): ...@@ -58,7 +58,7 @@ class Size(tuple):
class device(object): class device(object):
"""Represent the device where tensor will be allocated.""" """Represent the device spec."""
def __init__(self, type='cpu', index=0): def __init__(self, type='cpu', index=0):
"""Create a ``device``. """Create a ``device``.
...@@ -100,6 +100,57 @@ class device(object): ...@@ -100,6 +100,57 @@ class device(object):
return 'device(type={}, index={})'.format(self.type, self.index) return 'device(type={}, index={})'.format(self.type, self.index)
class dtype(str):
"""The basic data type.
The following data types are defined:
* ``torch.float16`` or ``torch.half``: 16-bit half-precision floating-point.
* ``torch.float32`` or ``torch.float``: 32-bit single-precision floating-point.
* ``torch.float64`` or ``torch.double``: 64-bit double-precision floating-point.
* ``torch.bfloat16``: 16-bit truncated floating-point.
* ``torch.complex32``: 32-bit half-precision complex.
* ``torch.complex64``: 64-bit single-precision complex.
* ``torch.complex128``: 128-bit double-precision complex.
* ``torch.int8``: 8-bit signed integer.
* ``torch.uint8``: 8-bit unsigned integer.
* ``torch.int16`` or ``torch.short``: 16-bit signed integer.
* ``torch.int32`` or ``torch.int``: 32-bit signed integer.
* ``torch.int64`` or ``torch.long``: 64-bit signed integer.
* ``torch.bool``: Boolean.
* ``torch.qint8``: Quantized 8-bit signed integer.
* ``torch.quint8``: Quantized 8-bit unsigned integer.
* ``torch.qint32``: Quantized 32-bit signed integer.
"""
def __init__(self, s):
"""Create a ``dtype``.
Parameters
----------
s : str
The data type descriptor.
"""
super(dtype, self).__init__()
def from_numpy(array): def from_numpy(array):
"""Create a tensor from the given numpy array. """Create a tensor from the given numpy array.
......
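Because ``dtype`` subclasses ``str``, the aliases registered in the package root above compare equal to their string names and can be passed anywhere a string data type is accepted. A small sketch, assuming the ``dragon.vm.torch`` package path used elsewhere in this commit:

```python
# Hedged sketch; relies only on dtype being a str subclass and on the aliases above.
from dragon.vm import torch

print(isinstance(torch.float32, str))  # True: dtype subclasses str
print(torch.float32 == 'float32')      # True: compares as its string name
print(torch.int64 is torch.long)       # True: both names bind the same dtype object
```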
...@@ -636,7 +636,7 @@ def getitem(self, item): ...@@ -636,7 +636,7 @@ def getitem(self, item):
Parameters Parameters
---------- ----------
item : Union[int, slice, dragon.vm.torch.Tensor] item : Union[slice, int, dragon.vm.torch.Tensor]
The index. The index.
Returns Returns
...@@ -1318,7 +1318,7 @@ def setitem(self, key, value): ...@@ -1318,7 +1318,7 @@ def setitem(self, key, value):
Parameters Parameters
---------- ----------
key : Union[int, slice, dragon.vm.torch.Tensor] key : Union[slice, int, dragon.vm.torch.Tensor]
The index. The index.
value : Union[dragon.vm.torch.Tensor, number] value : Union[dragon.vm.torch.Tensor, number]
The value to set. The value to set.
......
...@@ -1940,7 +1940,7 @@ class Tensor(object): ...@@ -1940,7 +1940,7 @@ class Tensor(object):
Parameters Parameters
---------- ----------
item : Union[int, slice, dragon.vm.torch.Tensor] item : Union[slice, int, dragon.vm.torch.Tensor]
The index. The index.
Returns Returns
...@@ -2198,7 +2198,7 @@ class Tensor(object): ...@@ -2198,7 +2198,7 @@ class Tensor(object):
Parameters Parameters
---------- ----------
key : Union[int, slice, dragon.vm.torch.Tensor] key : Union[slice, int, dragon.vm.torch.Tensor]
The index. The index.
value : Union[dragon.vm.torch.Tensor, number] value : Union[dragon.vm.torch.Tensor, number]
The value to set. The value to set.
......