Commit 9fc5249b by Ting PAN

Simplify the feeding to repeated tensor arguments

Summary:
This commit feeds the repeated tensor arguments with
the entire array instead of the piecewise scalars.
1 parent b93bde0d
...@@ -26,7 +26,7 @@ run ...@@ -26,7 +26,7 @@ run
<style> <style>
h1:before { h1:before {
content: "tensorrt."; content: "tensorrt.onnx.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -57,6 +57,7 @@ class Operator(object): ...@@ -57,6 +57,7 @@ class Operator(object):
The attribute dict. The attribute dict.
""" """
return {'op_type': self.__class__.__name__, 'arguments': {}}
@classmethod @classmethod
def blend(cls, op_type=None, **kwargs): def blend(cls, op_type=None, **kwargs):
...@@ -76,7 +77,7 @@ class Operator(object): ...@@ -76,7 +77,7 @@ class Operator(object):
pre_callback=callback, pre_callback=callback,
) )
def feed_arg(self, ws, name, value, dtype='int64'): def feed_arg(self, ws, name, value, dtype):
"""Set the value of tensor argument.""" """Set the value of tensor argument."""
ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device) ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device)
......
...@@ -35,23 +35,24 @@ class Activation(Operator): ...@@ -35,23 +35,24 @@ class Activation(Operator):
class Dropout(Operator): class Dropout(Operator):
"""Dropout operator.""" """Dropout operator."""
def __init__(self, key, dev, **kwargs):
super(Dropout, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'Dropout', 'op_type': 'Dropout',
'arguments': {'ratio_desc': '${HANDLE}/ratio'}, 'arguments': {
'ratio_desc': '${HANDLE}/ratio',
},
} }
def feed(self, ws, handle, ratio): def setup(self, ws, handle, ratio):
self.feed_arg(ws, '{}/ratio'.format(handle), ratio, 'float32') self.feed_arg(ws, '%s/ratio' % handle, ratio, 'float32')
def forward(self, inputs, ratio, inplace=False): def forward(self, inputs, ratio, inplace=False):
outputs = [self.alloc(inputs[0]) if inplace else self.alloc()] outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs, return self.dispatch(
callback=lambda ws, handle: inputs, outputs,
self.feed(ws, handle, ratio)) callback=lambda ws, handle:
self.setup(ws, handle, ratio),
)
class DropBlock2d(Dropout): class DropBlock2d(Dropout):
...@@ -66,9 +67,9 @@ class DropBlock2d(Dropout): ...@@ -66,9 +67,9 @@ class DropBlock2d(Dropout):
return { return {
'op_type': 'DropBlock2d', 'op_type': 'DropBlock2d',
'arguments': { 'arguments': {
'ratio_desc': '${HANDLE}/ratio',
'block_size': self.block_size, 'block_size': self.block_size,
'data_format': self.data_format, 'data_format': self.data_format,
'ratio_desc': '${HANDLE}/ratio',
}, },
} }
...@@ -76,13 +77,12 @@ class DropBlock2d(Dropout): ...@@ -76,13 +77,12 @@ class DropBlock2d(Dropout):
class DropPath(Dropout): class DropPath(Dropout):
"""DropPath operator.""" """DropPath operator."""
def __init__(self, key, dev, **kwargs):
super(DropPath, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'DropPath', 'op_type': 'DropPath',
'arguments': {'ratio_desc': '${HANDLE}/ratio'}, 'arguments': {
'ratio_desc': '${HANDLE}/ratio',
},
} }
...@@ -96,7 +96,9 @@ class Elu(Activation): ...@@ -96,7 +96,9 @@ class Elu(Activation):
def attributes(self): def attributes(self):
return { return {
'op_type': 'Elu', 'op_type': 'Elu',
'arguments': {'alpha': float(self.alpha)}, 'arguments': {
'alpha': float(self.alpha),
},
} }
...@@ -146,7 +148,9 @@ class PRelu(Operator): ...@@ -146,7 +148,9 @@ class PRelu(Operator):
def attributes(self): def attributes(self):
return { return {
'op_type': 'PRelu', 'op_type': 'PRelu',
'arguments': {'data_format': self.data_format}, 'arguments': {
'data_format': self.data_format,
},
} }
def forward(self, inputs): def forward(self, inputs):
...@@ -163,20 +167,21 @@ class Relu(Activation): ...@@ -163,20 +167,21 @@ class Relu(Activation):
def attributes(self): def attributes(self):
return { return {
'op_type': 'Relu', 'op_type': 'Relu',
'arguments': {'alpha': float(self.alpha)}, 'arguments': {
'alpha': float(self.alpha),
},
} }
class Relu6(Activation): class Relu6(Activation):
"""Relu6 operator.""" """Relu6 operator."""
def __init__(self, key, dev, **kwargs):
super(Relu6, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'Relu', 'op_type': 'Relu',
'arguments': {'max_value': 6.}, 'arguments': {
'max_value': 6.,
},
} }
...@@ -208,5 +213,7 @@ class Softmax(Activation): ...@@ -208,5 +213,7 @@ class Softmax(Activation):
def attributes(self): def attributes(self):
return { return {
'op_type': 'Softmax', 'op_type': 'Softmax',
'arguments': {'axis': self.axis}, 'arguments': {
'axis': self.axis,
},
} }
...@@ -20,38 +20,25 @@ from dragon.core.framework.ops import Operator ...@@ -20,38 +20,25 @@ from dragon.core.framework.ops import Operator
class Assign(Operator): class Assign(Operator):
"""Assign operator.""" """Assign operator."""
def __init__(self, key, dev, **kwargs):
super(Assign, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self): def attributes(self):
return { return {
'op_type': 'Assign', 'op_type': 'Assign',
'arguments': { 'arguments': {
'starts_descs': [ 'starts_desc': '${HANDLE}/starts',
'${{HANDLE}}/starts[{}]' 'sizes_desc': '${HANDLE}/sizes',
.format(n) for n in range(self.ndim)],
'sizes_descs': [
'${{HANDLE}}/sizes[{}]'
.format(n) for n in range(self.ndim)],
}, },
} }
def feed(self, ws, handle, starts, sizes): def setup(self, ws, handle, starts, sizes):
for i in range(self.ndim): self.feed_arg(ws, '%s/starts' % handle, starts, 'int64')
self.feed_arg( self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
ws, '{}/starts[{}]'.format(handle, i),
starts[i], 'int64')
self.feed_arg(
ws, '{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64')
def forward(self, inputs, starts, sizes, inplace=False): def forward(self, inputs, starts, sizes, inplace=False):
outputs = [self.alloc(inputs[0]) if inplace else self.alloc()] outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch( return self.dispatch(
inputs, outputs, inputs, outputs,
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes), self.setup(ws, handle, starts, sizes),
no_grad=True, no_grad=True,
) )
...@@ -59,12 +46,6 @@ class Assign(Operator): ...@@ -59,12 +46,6 @@ class Assign(Operator):
class MaskedAssign(Operator): class MaskedAssign(Operator):
"""MaskedAssign operator.""" """MaskedAssign operator."""
def __init__(self, key, dev, **kwargs):
super(MaskedAssign, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'MaskedAssign', 'arguments': {}}
def forward(self, inputs, inplace=False): def forward(self, inputs, inplace=False):
outputs = [self.alloc(inputs[0]) if inplace else self.alloc()] outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs, no_grad=True) return self.dispatch(inputs, outputs, no_grad=True)
...@@ -25,18 +25,13 @@ class Initializer(Operator): ...@@ -25,18 +25,13 @@ class Initializer(Operator):
self.ndim = kwargs.get('ndim', 0) self.ndim = kwargs.get('ndim', 0)
self.dtype = kwargs.get('dtype', 'float32') self.dtype = kwargs.get('dtype', 'float32')
def feed(self, ws, handle, shape): def setup(self, ws, handle, shape):
for i, dim in enumerate(shape): self.feed_arg(ws, '%s/dims' % handle, shape, 'int64')
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
dim, 'int64')
def forward(self, shape, shape_as=None, out=None, trainable=None): def forward(self, shape, shape_as=None, out=None, trainable=None):
out = self.dispatch( out = self.dispatch(
[] if shape_as is None else [shape_as], [] if shape_as is None else [shape_as], [self.alloc(out)],
[self.alloc(out)], callback=lambda ws, handle: self.setup(ws, handle, shape),
callback=lambda ws, handle:
self.feed(ws, handle, shape),
no_grad=True, no_grad=True,
) )
if trainable is not None: if trainable is not None:
...@@ -57,9 +52,7 @@ class Eye(Initializer): ...@@ -57,9 +52,7 @@ class Eye(Initializer):
'arguments': { 'arguments': {
'k': self.k, 'k': self.k,
'dtype': self.dtype, 'dtype': self.dtype,
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -75,9 +68,7 @@ class Fill(Initializer): ...@@ -75,9 +68,7 @@ class Fill(Initializer):
'arguments': { 'arguments': {
'dtype': self.dtype, 'dtype': self.dtype,
'value': float(self.value), 'value': float(self.value),
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -97,9 +88,7 @@ class GlorotNormal(Initializer): ...@@ -97,9 +88,7 @@ class GlorotNormal(Initializer):
'dtype': self.dtype, 'dtype': self.dtype,
'scale': float(self.scale), 'scale': float(self.scale),
'mode': self.mode.lower(), 'mode': self.mode.lower(),
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -119,9 +108,7 @@ class GlorotUniform(Initializer): ...@@ -119,9 +108,7 @@ class GlorotUniform(Initializer):
'dtype': self.dtype, 'dtype': self.dtype,
'scale': float(self.scale), 'scale': float(self.scale),
'mode': self.mode.lower(), 'mode': self.mode.lower(),
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -141,9 +128,7 @@ class RandomNormal(Initializer): ...@@ -141,9 +128,7 @@ class RandomNormal(Initializer):
'dtype': self.dtype, 'dtype': self.dtype,
'mean': float(self.mean), 'mean': float(self.mean),
'std': float(self.std), 'std': float(self.std),
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -163,9 +148,7 @@ class RandomUniform(Initializer): ...@@ -163,9 +148,7 @@ class RandomUniform(Initializer):
'dtype': self.dtype, 'dtype': self.dtype,
'low': float(self.low), 'low': float(self.low),
'high': float(self.high), 'high': float(self.high),
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -185,8 +168,6 @@ class TruncatedNormal(Initializer): ...@@ -185,8 +168,6 @@ class TruncatedNormal(Initializer):
'dtype': self.dtype, 'dtype': self.dtype,
'mean': float(self.mean), 'mean': float(self.mean),
'std': float(self.std), 'std': float(self.std),
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -29,7 +29,7 @@ class Loss(Operator): ...@@ -29,7 +29,7 @@ class Loss(Operator):
'op_type': self.__class__.__name__, 'op_type': self.__class__.__name__,
'arguments': { 'arguments': {
'reduction': self.reduction, 'reduction': self.reduction,
} },
} }
def forward(self, inputs): def forward(self, inputs):
...@@ -39,16 +39,10 @@ class Loss(Operator): ...@@ -39,16 +39,10 @@ class Loss(Operator):
class L1Loss(Loss): class L1Loss(Loss):
"""L1Loss operator.""" """L1Loss operator."""
def __init__(self, key, dev, **kwargs):
super(L1Loss, self).__init__(key, dev, **kwargs)
class L2Loss(Loss): class L2Loss(Loss):
"""L2Loss operator.""" """L2Loss operator."""
def __init__(self, key, dev, **kwargs):
super(L2Loss, self).__init__(key, dev, **kwargs)
class NLLLoss(Loss): class NLLLoss(Loss):
"""NLLLoss operator.""" """NLLLoss operator."""
...@@ -72,9 +66,6 @@ class NLLLoss(Loss): ...@@ -72,9 +66,6 @@ class NLLLoss(Loss):
class SigmoidCrossEntropy(Loss): class SigmoidCrossEntropy(Loss):
"""SigmoidCrossEntropy operator.""" """SigmoidCrossEntropy operator."""
def __init__(self, key, dev, **kwargs):
super(SigmoidCrossEntropy, self).__init__(key, dev, **kwargs)
class SmoothL1Loss(Loss): class SmoothL1Loss(Loss):
"""SmoothL1Loss operator.""" """SmoothL1Loss operator."""
......
...@@ -37,7 +37,7 @@ class BatchNorm(Operator): ...@@ -37,7 +37,7 @@ class BatchNorm(Operator):
'momentum': self.momentum, 'momentum': self.momentum,
'epsilon': self.epsilon, 'epsilon': self.epsilon,
'use_stats': self.use_stats, 'use_stats': self.use_stats,
} },
} }
def forward(self, inputs): def forward(self, inputs):
...@@ -60,7 +60,7 @@ class GroupNorm(Operator): ...@@ -60,7 +60,7 @@ class GroupNorm(Operator):
'axis': self.axis, 'axis': self.axis,
'group': self.group, 'group': self.group,
'epsilon': self.epsilon, 'epsilon': self.epsilon,
} },
} }
def forward(self, inputs): def forward(self, inputs):
...@@ -87,7 +87,7 @@ class LpNormalize(Operator): ...@@ -87,7 +87,7 @@ class LpNormalize(Operator):
'num_axes': self.num_axes, 'num_axes': self.num_axes,
'epsilon': self.epsilon, 'epsilon': self.epsilon,
'reduction': self.reduction, 'reduction': self.reduction,
} },
} }
def forward(self, inputs): def forward(self, inputs):
...@@ -114,7 +114,7 @@ class LocalResponseNorm(Operator): ...@@ -114,7 +114,7 @@ class LocalResponseNorm(Operator):
'beta': self.beta, 'beta': self.beta,
'bias': self.bias, 'bias': self.bias,
'data_format': self.data_format, 'data_format': self.data_format,
} },
} }
def forward(self, inputs): def forward(self, inputs):
......
...@@ -20,12 +20,6 @@ from dragon.core.framework.ops import Operator ...@@ -20,12 +20,6 @@ from dragon.core.framework.ops import Operator
class LSTMCell(Operator): class LSTMCell(Operator):
"""LSTMCell operator.""" """LSTMCell operator."""
def __init__(self, key, dev, **kwargs):
super(LSTMCell, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'LSTMCell', 'arguments': {}}
def forward(self, inputs): def forward(self, inputs):
outputs = [self.alloc() for _ in range(2)] outputs = [self.alloc() for _ in range(2)]
return self.dispatch(inputs, outputs) return self.dispatch(inputs, outputs)
...@@ -54,7 +48,7 @@ class Recurrent(Operator): ...@@ -54,7 +48,7 @@ class Recurrent(Operator):
'rnn_input_mode': 'linear', 'rnn_input_mode': 'linear',
'dropout_ratio': self.dropout_ratio, 'dropout_ratio': self.dropout_ratio,
'phase': 'TRAIN' if self.is_training else 'TEST' 'phase': 'TRAIN' if self.is_training else 'TEST'
} },
} }
def forward(self, inputs): def forward(self, inputs):
...@@ -87,7 +81,7 @@ class RNNParamSet(Operator): ...@@ -87,7 +81,7 @@ class RNNParamSet(Operator):
'layer_id': self.layer_id, 'layer_id': self.layer_id,
'param_id': self.param_id, 'param_id': self.param_id,
'rnn_mode': self.mode, 'rnn_mode': self.mode,
} },
} }
def forward(self, inputs): def forward(self, inputs):
......
...@@ -91,7 +91,9 @@ class BiasAdd(Operator): ...@@ -91,7 +91,9 @@ class BiasAdd(Operator):
def attributes(self): def attributes(self):
return { return {
'op_type': 'BiasAdd', 'op_type': 'BiasAdd',
'arguments': {'data_format': self.data_format}, 'arguments': {
'data_format': self.data_format,
},
} }
def forward(self, inputs, inplace=False): def forward(self, inputs, inplace=False):
...@@ -102,9 +104,6 @@ class BiasAdd(Operator): ...@@ -102,9 +104,6 @@ class BiasAdd(Operator):
class Conv2d(ConvNd): class Conv2d(ConvNd):
"""Conv2d operator.""" """Conv2d operator."""
def __init__(self, key, dev, **kwargs):
super(Conv2d, self).__init__(key, dev, **kwargs)
class ConvTranspose2d(ConvNd): class ConvTranspose2d(ConvNd):
"""ConvTranspose2d operator.""" """ConvTranspose2d operator."""
...@@ -154,16 +153,10 @@ class DepthToSpace(Operator): ...@@ -154,16 +153,10 @@ class DepthToSpace(Operator):
class DepthwiseConv2d(ConvNd): class DepthwiseConv2d(ConvNd):
"""DepthwiseConv2d operator.""" """DepthwiseConv2d operator."""
def __init__(self, key, dev, **kwargs):
super(DepthwiseConv2d, self).__init__(key, dev, **kwargs)
class Pool2d(PoolNd): class Pool2d(PoolNd):
"""Pool2d operator.""" """Pool2d operator."""
def __init__(self, key, dev, **kwargs):
super(Pool2d, self).__init__(key, dev, **kwargs)
class Resize(Operator): class Resize(Operator):
"""Resize operator.""" """Resize operator."""
...@@ -182,31 +175,25 @@ class Resize(Operator): ...@@ -182,31 +175,25 @@ class Resize(Operator):
'arguments': { 'arguments': {
'mode': self.mode, 'mode': self.mode,
'align_corners': self.align_corners, 'align_corners': self.align_corners,
'sizes_descs': [
'${{HANDLE}}/sizes[{}]'
.format(n) for n in range(self.num_sizes)],
'scales_descs': [
'${{HANDLE}}/scales[{}]'
.format(n) for n in range(self.num_scales)],
'data_format': self.data_format, 'data_format': self.data_format,
} 'sizes_desc': '${HANDLE}/sizes'
if self.num_sizes > 0 else None,
'scales_desc': '${HANDLE}/scales'
if self.num_scales > 0 else None,
},
} }
def feed(self, ws, handle, sizes, scales): def setup(self, ws, handle, sizes, scales):
for i in range(self.num_sizes): if sizes is not None:
self.feed_arg( self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
ws, '{}/sizes[{}]'.format(handle, i), if scales is not None:
sizes[i], 'int64') self.feed_arg(ws, '%s/scales' % handle, scales, 'float32')
for i in range(self.num_scales):
self.feed_arg(
ws, '{}/scales[{}]'.format(handle, i),
scales[i], 'float32')
def forward(self, inputs, sizes=None, scales=None): def forward(self, inputs, sizes=None, scales=None):
return self.dispatch( return self.dispatch(
inputs, [self.alloc()], inputs, [self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, sizes, scales), self.setup(ws, handle, sizes, scales),
) )
......
...@@ -102,6 +102,7 @@ class Function(object): ...@@ -102,6 +102,7 @@ class Function(object):
The attribute dict. The attribute dict.
""" """
return {'op_type': self.__class__.__name__, 'arguments': {}}
def dispatch( def dispatch(
self, self,
...@@ -124,7 +125,7 @@ class Function(object): ...@@ -124,7 +125,7 @@ class Function(object):
pre_callback=callback, pre_callback=callback,
) )
def feed_arg(self, ws, name, value, dtype='int64'): def feed_arg(self, ws, name, value, dtype):
"""Set the value of tensor argument.""" """Set the value of tensor argument."""
ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device) ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device)
......
...@@ -134,16 +134,10 @@ class BatchNorm(function.Function): ...@@ -134,16 +134,10 @@ class BatchNorm(function.Function):
class Conv2d(_ConvNd): class Conv2d(_ConvNd):
"""Conv2d function.""" """Conv2d function."""
def __init__(self, key, dev, **kwargs):
super(Conv2d, self).__init__(key, dev, **kwargs)
class ConvTranspose2d(_ConvNd): class ConvTranspose2d(_ConvNd):
"""ConvTranspose2d function.""" """ConvTranspose2d function."""
def __init__(self, key, dev, **kwargs):
super(ConvTranspose2d, self).__init__(key, dev, **kwargs)
class CTCLoss(_Loss): class CTCLoss(_Loss):
"""CTCLoss function.""" """CTCLoss function."""
...@@ -165,30 +159,28 @@ class CTCLoss(_Loss): ...@@ -165,30 +159,28 @@ class CTCLoss(_Loss):
class DepthwiseConv2d(_ConvNd): class DepthwiseConv2d(_ConvNd):
"""DepthwiseConv2d function.""" """DepthwiseConv2d function."""
def __init__(self, key, dev, **kwargs):
super(DepthwiseConv2d, self).__init__(key, dev, **kwargs)
class Dropout(function.Function): class Dropout(function.Function):
"""Dropout function.""" """Dropout function."""
def __init__(self, key, dev, **kwargs):
super(Dropout, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'Dropout', 'op_type': 'Dropout',
'arguments': {'ratio_desc': '${HANDLE}/ratio'}, 'arguments': {
'ratio_desc': '${HANDLE}/ratio',
},
} }
def feed(self, ws, handle, ratio): def setup(self, ws, handle, ratio):
self.feed_arg(ws, '{}/ratio'.format(handle), ratio, 'float32') self.feed_arg(ws, '{}/ratio'.format(handle), ratio, 'float32')
def forward(self, input, ratio, inplace=False): def forward(self, input, ratio, inplace=False):
out = input if inplace else self.alloc() out = input if inplace else self.alloc()
return self.dispatch([input], [out], return self.dispatch(
callback=lambda ws, handle: [input], [out],
self.feed(ws, handle, ratio)) callback=lambda ws, handle:
self.setup(ws, handle, ratio),
)
class DropBlock2d(Dropout): class DropBlock2d(Dropout):
...@@ -202,9 +194,9 @@ class DropBlock2d(Dropout): ...@@ -202,9 +194,9 @@ class DropBlock2d(Dropout):
return { return {
'op_type': 'DropBlock2d', 'op_type': 'DropBlock2d',
'arguments': { 'arguments': {
'ratio_desc': '${HANDLE}/ratio',
'block_size': self.block_size,
'data_format': 'NCHW', 'data_format': 'NCHW',
'block_size': self.block_size,
'ratio_desc': '${HANDLE}/ratio',
} }
} }
...@@ -212,13 +204,12 @@ class DropBlock2d(Dropout): ...@@ -212,13 +204,12 @@ class DropBlock2d(Dropout):
class DropPath(Dropout): class DropPath(Dropout):
"""DropPath function.""" """DropPath function."""
def __init__(self, key, dev, **kwargs):
super(DropPath, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'DropPath', 'op_type': 'DropPath',
'arguments': {'ratio_desc': '${HANDLE}/ratio'}, 'arguments': {
'ratio_desc': '${HANDLE}/ratio',
},
} }
...@@ -253,7 +244,7 @@ class GroupNorm(function.Function): ...@@ -253,7 +244,7 @@ class GroupNorm(function.Function):
'axis': 1, 'axis': 1,
'group': self.group, 'group': self.group,
'epsilon': self.epsilon, 'epsilon': self.epsilon,
} },
} }
def forward(self, input, weight, bias): def forward(self, input, weight, bias):
...@@ -299,41 +290,32 @@ class HardSwish(_Activation): ...@@ -299,41 +290,32 @@ class HardSwish(_Activation):
class L1Loss(_Loss): class L1Loss(_Loss):
"""L1Loss function.""" """L1Loss function."""
def __init__(self, key, dev, **kwargs):
super(L1Loss, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'L1Loss', 'op_type': 'L1Loss',
'arguments': { 'arguments': {
'scale': 1., 'scale': 1.,
'reduction': self.reduction, 'reduction': self.reduction,
} },
} }
class L2Loss(_Loss): class L2Loss(_Loss):
"""L2Loss function.""" """L2Loss function."""
def __init__(self, key, dev, **kwargs):
super(L2Loss, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'L2Loss', 'op_type': 'L2Loss',
'arguments': { 'arguments': {
'scale': 2., 'scale': 2.,
'reduction': self.reduction, 'reduction': self.reduction,
} },
} }
class Linear(function.Function): class Linear(function.Function):
"""Linear function.""" """Linear function."""
def __init__(self, key, dev, **kwargs):
super(Linear, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'FullyConnected', 'op_type': 'FullyConnected',
...@@ -368,7 +350,7 @@ class LocalResponseNorm(function.Function): ...@@ -368,7 +350,7 @@ class LocalResponseNorm(function.Function):
'beta': self.beta, 'beta': self.beta,
'bias': self.bias, 'bias': self.bias,
'data_format': 'NCHW', 'data_format': 'NCHW',
} },
} }
def forward(self, input): def forward(self, input):
...@@ -393,7 +375,7 @@ class LpNormalize(function.Function): ...@@ -393,7 +375,7 @@ class LpNormalize(function.Function):
'epsilon': self.epsilon, 'epsilon': self.epsilon,
'num_axes': 1, 'num_axes': 1,
'reduction': 'SUM', 'reduction': 'SUM',
} },
} }
def forward(self, input, out=None): def forward(self, input, out=None):
...@@ -403,12 +385,6 @@ class LpNormalize(function.Function): ...@@ -403,12 +385,6 @@ class LpNormalize(function.Function):
class LSTMCell(function.Function): class LSTMCell(function.Function):
"""LSTMCell function.""" """LSTMCell function."""
def __init__(self, key, dev, **kwargs):
super(LSTMCell, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'LSTMCell', 'arguments': {}}
def forward(self, input, cx): def forward(self, input, cx):
outputs = [self.alloc() for _ in range(2)] outputs = [self.alloc() for _ in range(2)]
return self.dispatch([input, cx], outputs) return self.dispatch([input, cx], outputs)
...@@ -428,7 +404,7 @@ class NLLLoss(_Loss): ...@@ -428,7 +404,7 @@ class NLLLoss(_Loss):
'axis': 1, 'axis': 1,
'reduction': self.reduction, 'reduction': self.reduction,
'ignore_index': self.ignore_index, 'ignore_index': self.ignore_index,
} },
} }
...@@ -437,7 +413,6 @@ class Pad(function.Function): ...@@ -437,7 +413,6 @@ class Pad(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Pad, self).__init__(key, dev, **kwargs) super(Pad, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
self.value = kwargs.get('value', 0.) self.value = kwargs.get('value', 0.)
self.mode = kwargs.get('mode', 'CONSTANT') self.mode = kwargs.get('mode', 'CONSTANT')
...@@ -447,42 +422,28 @@ class Pad(function.Function): ...@@ -447,42 +422,28 @@ class Pad(function.Function):
'arguments': { 'arguments': {
'mode': self.mode, 'mode': self.mode,
'value': self.value, 'value': self.value,
'pads_descs': [ 'pads_desc': '${HANDLE}/pads',
'${{HANDLE}}/pads[{}]' },
.format(n) for n in range(self.ndim * 2)
],
}
} }
def feed(self, ws, handle, pads): def setup(self, ws, handle, pads):
for i, e in enumerate(pads): self.feed_arg(ws, '%s/pads' % handle, pads, 'int64')
self.feed_arg(
ws,
'{}/pads[{}]'.format(handle, i),
e, 'int64'
)
def forward(self, input, pads): def forward(self, input, pads):
return self.dispatch( return self.dispatch(
[input], [self.alloc()], [input], [self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, pads), self.setup(ws, handle, pads),
) )
class Pool2d(_PoolNd): class Pool2d(_PoolNd):
"""Pool2d function.""" """Pool2d function."""
def __init__(self, key, dev, **kwargs):
super(Pool2d, self).__init__(key, dev, **kwargs)
class PRelu(function.Function): class PRelu(function.Function):
"""PRelu function.""" """PRelu function."""
def __init__(self, key, dev, **kwargs):
super(PRelu, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'PRelu', 'op_type': 'PRelu',
...@@ -552,9 +513,6 @@ class Relu(_Activation): ...@@ -552,9 +513,6 @@ class Relu(_Activation):
class Relu6(_Activation): class Relu6(_Activation):
"""Relu6 function.""" """Relu6 function."""
def __init__(self, key, dev, **kwargs):
super(Relu6, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'Relu', 'op_type': 'Relu',
...@@ -583,30 +541,24 @@ class Resize(function.Function): ...@@ -583,30 +541,24 @@ class Resize(function.Function):
'mode': self.mode, 'mode': self.mode,
'align_corners': self.align_corners, 'align_corners': self.align_corners,
'data_format': 'NCHW', 'data_format': 'NCHW',
'sizes_descs': [ 'sizes_desc': '${HANDLE}/sizes'
'${{HANDLE}}/sizes[{}]' if self.num_sizes > 0 else None,
.format(n) for n in range(self.num_sizes)], 'scales_desc': '${HANDLE}/scales'
'scales_descs': [ if self.num_scales > 0 else None,
'${{HANDLE}}/scales[{}]' },
.format(n) for n in range(self.num_scales)],
}
} }
def feed(self, ws, handle, sizes, scales): def setup(self, ws, handle, sizes, scales):
for i in range(self.num_sizes): if sizes is not None:
self.feed_arg( self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
ws, '{}/sizes[{}]'.format(handle, i), if scales is not None:
sizes[i], 'int64') self.feed_arg(ws, '%s/scales' % handle, scales, 'float32')
for i in range(self.num_scales):
self.feed_arg(
ws, '{}/scales[{}]'.format(handle, i),
scales[i], 'float32')
def forward(self, input, sizes=None, scales=None): def forward(self, input, sizes=None, scales=None):
return self.dispatch( return self.dispatch(
[input], [self.alloc()], [input], [self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, sizes, scales), self.setup(ws, handle, sizes, scales),
) )
...@@ -646,15 +598,12 @@ class RNNParamSet(function.Function): ...@@ -646,15 +598,12 @@ class RNNParamSet(function.Function):
class SigmoidCrossEntropy(_Loss): class SigmoidCrossEntropy(_Loss):
"""SigmoidCrossEntropy function.""" """SigmoidCrossEntropy function."""
def __init__(self, key, dev, **kwargs):
super(SigmoidCrossEntropy, self).__init__(key, dev, **kwargs)
def attributes(self): def attributes(self):
return { return {
'op_type': 'SigmoidCrossEntropy', 'op_type': 'SigmoidCrossEntropy',
'arguments': { 'arguments': {
'reduction': self.reduction, 'reduction': self.reduction,
} },
} }
...@@ -693,7 +642,7 @@ class SmoothL1Loss(_Loss): ...@@ -693,7 +642,7 @@ class SmoothL1Loss(_Loss):
'arguments': { 'arguments': {
'beta': float(self.beta), 'beta': float(self.beta),
'reduction': self.reduction, 'reduction': self.reduction,
} },
} }
...@@ -727,7 +676,7 @@ class SparseSoftmaxCrossEntropy(_Loss): ...@@ -727,7 +676,7 @@ class SparseSoftmaxCrossEntropy(_Loss):
'axis': 1, 'axis': 1,
'reduction': self.reduction, 'reduction': self.reduction,
'ignore_index': self.ignore_index, 'ignore_index': self.ignore_index,
} },
} }
......
...@@ -42,38 +42,25 @@ class ArgReduce(function.Function): ...@@ -42,38 +42,25 @@ class ArgReduce(function.Function):
class Assign(function.Function): class Assign(function.Function):
"""Assign function.""" """Assign function."""
def __init__(self, key, dev, **kwargs):
super(Assign, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self): def attributes(self):
return { return {
'op_type': 'Assign', 'op_type': 'Assign',
'arguments': { 'arguments': {
'starts_descs': [ 'starts_desc': '${HANDLE}/starts',
'${{HANDLE}}/starts[{}]' 'sizes_desc': '${HANDLE}/sizes',
.format(n) for n in range(self.ndim)],
'sizes_descs': [
'${{HANDLE}}/sizes[{}]'
.format(n) for n in range(self.ndim)],
}, },
} }
def feed(self, ws, handle, starts, sizes): def setup(self, ws, handle, starts, sizes):
for i in range(self.ndim): self.feed_arg(ws, '%s/starts' % handle, starts, 'int64')
self.feed_arg( self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
ws, '{}/starts[{}]'.format(handle, i),
starts[i], 'int64')
self.feed_arg(
ws, '{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64')
def forward(self, out, starts, sizes, input): def forward(self, out, starts, sizes, input):
self._check_device([input, out]) self._check_device([input, out])
return self.dispatch( return self.dispatch(
[out, input], [self.alloc(out)], [out, input], [self.alloc(out)],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes), self.setup(ws, handle, starts, sizes),
no_grad=True, no_grad=True,
check_device=False, check_device=False,
) )
...@@ -128,7 +115,6 @@ class ChannelNormalize(function.Function): ...@@ -128,7 +115,6 @@ class ChannelNormalize(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(ChannelNormalize, self).__init__(key, dev, **kwargs) super(ChannelNormalize, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1) self.axis = kwargs.get('axis', -1)
self.ndim = kwargs.get('ndim', 0)
self.mean = kwargs.get('mean', None) self.mean = kwargs.get('mean', None)
self.std = kwargs.get('std', None) self.std = kwargs.get('std', None)
self.dtype = kwargs.get('dtype', 'float32') self.dtype = kwargs.get('dtype', 'float32')
...@@ -141,23 +127,18 @@ class ChannelNormalize(function.Function): ...@@ -141,23 +127,18 @@ class ChannelNormalize(function.Function):
'mean': self.mean, 'mean': self.mean,
'std': self.std, 'std': self.std,
'dtype': self.dtype, 'dtype': self.dtype,
'perm_descs': [ 'perm_desc': '${HANDLE}/perm',
'${{HANDLE}}/perm[{}]'
.format(n) for n in range(self.ndim)],
} }
} }
def feed(self, ws, handle, perm): def setup(self, ws, handle, perm):
for i in range(self.ndim): self.feed_arg(ws, '%s/perm' % handle, perm, 'int64')
self.feed_arg(
ws, '{}/perm[{}]'.format(handle, i),
perm[i], 'int64')
def forward(self, input, perm): def forward(self, input, perm):
return self.dispatch( return self.dispatch(
[input], [self.alloc()], [input], [self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, perm), self.setup(ws, handle, perm),
) )
...@@ -226,31 +207,20 @@ class Cumulative(function.Function): ...@@ -226,31 +207,20 @@ class Cumulative(function.Function):
class Expand(function.Function): class Expand(function.Function):
"""Expand function.""" """Expand function."""
def __init__(self, key, dev, **kwargs):
super(Expand, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self): def attributes(self):
return { return {
'op_type': 'Expand', 'op_type': 'Expand',
'arguments': { 'arguments': {'dims_desc': '${HANDLE}/dims'},
'dims_descs': [
'${{HANDLE}}/dims[{}]'
.format(n) for n in range(self.ndim)],
},
} }
def feed(self, ws, handle, times): def setup(self, ws, handle, times):
for i in range(self.ndim): self.feed_arg(ws, '%s/dims' % handle, times, 'int64')
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
times[i], 'int64')
def forward(self, input, dims): def forward(self, input, dims):
return self.dispatch( return self.dispatch(
[input], [self.alloc()], [input], [self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, dims), self.setup(ws, handle, dims),
) )
...@@ -299,12 +269,6 @@ class IndexSelect(function.Function): ...@@ -299,12 +269,6 @@ class IndexSelect(function.Function):
class MaskedAssign(function.Function): class MaskedAssign(function.Function):
"""MaskedAssign function.""" """MaskedAssign function."""
def __init__(self, key, dev, **kwargs):
super(MaskedAssign, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'MaskedAssign', 'arguments': {}}
def forward(self, out, mask, input): def forward(self, out, mask, input):
return self.dispatch([out, input, mask], [self.alloc(out)]) return self.dispatch([out, input, mask], [self.alloc(out)])
...@@ -312,12 +276,6 @@ class MaskedAssign(function.Function): ...@@ -312,12 +276,6 @@ class MaskedAssign(function.Function):
class MaskedSelect(function.Function): class MaskedSelect(function.Function):
"""MaskedSelect function.""" """MaskedSelect function."""
def __init__(self, key, dev, **kwargs):
super(MaskedSelect, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'MaskedSelect', 'arguments': {}}
def forward(self, input, mask, out=None): def forward(self, input, mask, out=None):
return self.dispatch([input, mask], [self.alloc(out)]) return self.dispatch([input, mask], [self.alloc(out)])
...@@ -347,12 +305,6 @@ class Multinomial(function.Function): ...@@ -347,12 +305,6 @@ class Multinomial(function.Function):
class NonZero(function.Function): class NonZero(function.Function):
"""NonZero function.""" """NonZero function."""
def __init__(self, key, dev, **kwargs):
super(NonZero, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'NonZero', 'arguments': {}}
def forward(self, input, out=None): def forward(self, input, out=None):
return self.dispatch([input], [self.alloc(out)], no_grad=True) return self.dispatch([input], [self.alloc(out)], no_grad=True)
...@@ -401,68 +353,46 @@ class Reduce(function.Function): ...@@ -401,68 +353,46 @@ class Reduce(function.Function):
class Reshape(function.Function): class Reshape(function.Function):
"""Reshape function.""" """Reshape function."""
def __init__(self, key, dev, **kwargs):
super(Reshape, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self): def attributes(self):
return { return {
'op_type': 'Reshape', 'op_type': 'Reshape',
'arguments': { 'arguments': {
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'
.format(n) for n in range(self.ndim)],
}, },
} }
def feed(self, ws, handle, shape): def setup(self, ws, handle, shape):
for i in range(self.ndim): self.feed_arg(ws, '%s/dims' % handle, shape, 'int64')
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
shape[i], 'int64')
def forward(self, input, shape, out=None): def forward(self, input, shape, out=None):
return self.dispatch( return self.dispatch(
[input], [self.alloc(out)], [input], [self.alloc(out)],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, shape), self.setup(ws, handle, shape),
) )
class Slice(function.Function): class Slice(function.Function):
"""Slice function.""" """Slice function."""
def __init__(self, key, dev, **kwargs):
super(Slice, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self): def attributes(self):
return { return {
'op_type': 'Slice', 'op_type': 'Slice',
'arguments': { 'arguments': {
'starts_descs': [ 'starts_desc': '${HANDLE}/starts',
'${{HANDLE}}/starts[{}]' 'sizes_desc': '${HANDLE}/sizes',
.format(n) for n in range(self.ndim)],
'sizes_descs': [
'${{HANDLE}}/sizes[{}]'
.format(n) for n in range(self.ndim)],
}, },
} }
def feed(self, ws, handle, starts, sizes): def setup(self, ws, handle, starts, sizes):
for i in range(self.ndim): self.feed_arg(ws, '%s/starts' % handle, starts, 'int64')
self.feed_arg( self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
ws, '{}/starts[{}]'.format(handle, i),
starts[i], 'int64')
self.feed_arg(
ws, '{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64')
def forward(self, input, starts, sizes): def forward(self, input, starts, sizes):
return self.dispatch( return self.dispatch(
[input], [self.alloc()], [input], [self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes) self.setup(ws, handle, starts, sizes)
) )
...@@ -551,33 +481,22 @@ class Squeeze(function.Function): ...@@ -551,33 +481,22 @@ class Squeeze(function.Function):
class Tile(function.Function): class Tile(function.Function):
"""Tile function.""" """Tile function."""
def __init__(self, key, dev, **kwargs):
super(Tile, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self): def attributes(self):
return { return {
'op_type': 'Tile', 'op_type': 'Tile',
'arguments': { 'arguments': {
'repeats_descs': [ 'repeats_desc': '${HANDLE}/repeats',
'${{HANDLE}}/repeats[{}]'
.format(n) for n in range(self.ndim)],
}, },
} }
def feed(self, ws, handle, repeats): def setup(self, ws, handle, repeats):
for i in range(self.ndim): self.feed_arg(ws, '%s/repeats' % handle, repeats, 'int64')
self.feed_arg(
ws,
'{}/repeats[{}]'.format(handle, i),
repeats[i], 'int64',
)
def forward(self, input, times): def forward(self, input, times):
return self.dispatch( return self.dispatch(
[input], [self.alloc()], [input], [self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, times), self.setup(ws, handle, times),
) )
...@@ -592,23 +511,19 @@ class Transpose(function.Function): ...@@ -592,23 +511,19 @@ class Transpose(function.Function):
return { return {
'op_type': 'Transpose', 'op_type': 'Transpose',
'arguments': { 'arguments': {
'perm_descs': [ 'perm_desc': '${HANDLE}/perm'
'${{HANDLE}}/perm[{}]' if self.ndim > 0 else None,
.format(n) for n in range(self.ndim)],
}, },
} }
def feed(self, ws, handle, perm): def setup(self, ws, handle, perm):
for i in range(self.ndim): self.feed_arg(ws, '%s/perm' % handle, perm, 'int64')
self.feed_arg(
ws, '{}/perm[{}]'.format(handle, i),
perm[i], 'int64')
def forward(self, input, perm): def forward(self, input, perm):
return self.dispatch( return self.dispatch(
[input], [self.alloc()], [input], [self.alloc()],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, perm), self.setup(ws, handle, perm),
) )
...@@ -683,11 +598,5 @@ class UnSqueeze(function.Function): ...@@ -683,11 +598,5 @@ class UnSqueeze(function.Function):
class Where(function.Function): class Where(function.Function):
"""Where function.""" """Where function."""
def __init__(self, key, dev, **kwargs):
super(Where, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'Where', 'arguments': {}}
def forward(self, condition, x, y): def forward(self, condition, x, y):
return self.dispatch([x, y, condition], [self.alloc()]) return self.dispatch([x, y, condition], [self.alloc()])
...@@ -22,20 +22,16 @@ class _Initializer(function.Function): ...@@ -22,20 +22,16 @@ class _Initializer(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(_Initializer, self).__init__(key, dev, **kwargs) super(_Initializer, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
self.dtype = kwargs.get('dtype', 'float32') self.dtype = kwargs.get('dtype', 'float32')
def feed(self, ws, handle, shape): def setup(self, ws, handle, shape):
for i in range(self.ndim): self.feed_arg(ws, '%s/dims' % handle, shape, 'int64')
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
shape[i], 'int64')
def forward(self, out, shape, shape_like=None): def forward(self, out, shape, shape_like=None):
return self.dispatch( return self.dispatch(
[] if shape_like is None else [shape_like], [out], [] if shape_like is None else [shape_like], [out],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, shape), self.setup(ws, handle, shape),
no_grad=True, no_grad=True,
) )
...@@ -53,9 +49,7 @@ class Eye(_Initializer): ...@@ -53,9 +49,7 @@ class Eye(_Initializer):
'arguments': { 'arguments': {
'k': self.k, 'k': self.k,
'dtype': self.dtype, 'dtype': self.dtype,
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -73,9 +67,7 @@ class Fill(_Initializer): ...@@ -73,9 +67,7 @@ class Fill(_Initializer):
'arguments': { 'arguments': {
'dtype': self.dtype, 'dtype': self.dtype,
'value': float(self.value), 'value': float(self.value),
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -85,7 +77,6 @@ class LinSpace(function.Function): ...@@ -85,7 +77,6 @@ class LinSpace(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(LinSpace, self).__init__(key, dev, **kwargs) super(LinSpace, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
self.num_intervals = kwargs.get('num_intervals', 1) self.num_intervals = kwargs.get('num_intervals', 1)
self.dtype = kwargs.get('dtype', 'int64') self.dtype = kwargs.get('dtype', 'int64')
self.axis = kwargs.get('axis', 0) self.axis = kwargs.get('axis', 0)
...@@ -96,36 +87,22 @@ class LinSpace(function.Function): ...@@ -96,36 +87,22 @@ class LinSpace(function.Function):
'arguments': { 'arguments': {
'axis': self.axis, 'axis': self.axis,
'dtype': self.dtype, 'dtype': self.dtype,
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n) 'start_desc': '${HANDLE}/start',
for n in range(self.ndim)], 'stop_desc': '${HANDLE}/stop',
'start_descs': [
'${{HANDLE}}/start[{}]'
.format(n) for n in range(self.num_intervals)],
'stop_descs': [
'${{HANDLE}}/stop[{}]'
.format(n) for n in range(self.num_intervals)],
} }
} }
def feed(self, ws, handle, shape, starts, stops): def setup(self, ws, handle, shape, starts, stops):
for i, dim in enumerate(shape): self.feed_arg(ws, '%s/dims' % handle, shape, 'int64')
self.feed_arg( self.feed_arg(ws, '%s/start' % handle, starts, 'float64')
ws, '{}/dims[{}]'.format(handle, i), self.feed_arg(ws, '%s/stop' % handle, stops, 'float64')
dim, 'int64')
for i in range(len(starts)):
self.feed_arg(
ws, '{}/start[{}]'.format(handle, i),
starts[i], 'float64')
self.feed_arg(
ws, '{}/stop[{}]'.format(handle, i),
stops[i], 'float64')
def forward(self, shape, starts, stops, out=None): def forward(self, shape, starts, stops, out=None):
return self.dispatch( return self.dispatch(
[], [self.alloc(out)], [], [self.alloc(out)],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, shape, starts, stops), self.setup(ws, handle, shape, starts, stops),
no_grad=True, no_grad=True,
) )
...@@ -146,14 +123,14 @@ class Permutation(function.Function): ...@@ -146,14 +123,14 @@ class Permutation(function.Function):
} }
} }
def feed(self, ws, handle, limit): def setup(self, ws, handle, limit):
self.feed_arg(ws, '{}/limit'.format(handle), limit, 'int64') self.feed_arg(ws, '{}/limit'.format(handle), limit, 'int64')
def forward(self, limit, out=None): def forward(self, limit, out=None):
return self.dispatch( return self.dispatch(
[], [self.alloc(out)], [], [self.alloc(out)],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, limit), self.setup(ws, handle, limit),
no_grad=True, no_grad=True,
) )
...@@ -173,9 +150,7 @@ class RandomNormal(_Initializer): ...@@ -173,9 +150,7 @@ class RandomNormal(_Initializer):
'dtype': self.dtype, 'dtype': self.dtype,
'mean': float(self.mean), 'mean': float(self.mean),
'std': float(self.std), 'std': float(self.std),
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -195,9 +170,7 @@ class RandomUniform(_Initializer): ...@@ -195,9 +170,7 @@ class RandomUniform(_Initializer):
'dtype': self.dtype, 'dtype': self.dtype,
'low': float(self.low), 'low': float(self.low),
'high': float(self.high), 'high': float(self.high),
'dims_descs': [ 'dims_desc': '${HANDLE}/dims',
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
}, },
} }
...@@ -207,7 +180,6 @@ class Range(function.Function): ...@@ -207,7 +180,6 @@ class Range(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Range, self).__init__(key, dev, **kwargs) super(Range, self).__init__(key, dev, **kwargs)
self.num_args = kwargs.get('num_args', 3)
self.dtype = kwargs.get('dtype', 'int64') self.dtype = kwargs.get('dtype', 'int64')
def attributes(self): def attributes(self):
...@@ -215,22 +187,17 @@ class Range(function.Function): ...@@ -215,22 +187,17 @@ class Range(function.Function):
'op_type': 'Range', 'op_type': 'Range',
'arguments': { 'arguments': {
'dtype': self.dtype, 'dtype': self.dtype,
'slice_descs': [ 'slice_desc': '${HANDLE}/slice',
'${{HANDLE}}/slice[{}]'
.format(n) for n in range(self.num_args)],
} }
} }
def feed(self, ws, handle, slice_args): def setup(self, ws, handle, slice_args):
for i in range(len(slice_args)): self.feed_arg(ws, '%s/slice' % handle, slice_args, 'float64')
self.feed_arg(
ws, '{}/slice[{}]'.format(handle, i),
slice_args[i], 'float64')
def forward(self, slice_args, out=None): def forward(self, slice_args, out=None):
return self.dispatch( return self.dispatch(
[], [self.alloc(out)], [], [self.alloc(out)],
callback=lambda ws, handle: callback=lambda ws, handle:
self.feed(ws, handle, slice_args), self.setup(ws, handle, slice_args),
no_grad=True, no_grad=True,
) )
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!