Commit 9fc5249b by Ting PAN

Simplify the feeding to repeated tensor arguments

Summary:
This commit feeds the repeated tensor arguments with
the entire array instead of the piecewise scalars.
1 parent b93bde0d
......@@ -26,7 +26,7 @@ run
<style>
h1:before {
content: "tensorrt.";
content: "tensorrt.onnx.";
color: #103d3e;
}
</style>
......@@ -57,6 +57,7 @@ class Operator(object):
The attribute dict.
"""
return {'op_type': self.__class__.__name__, 'arguments': {}}
@classmethod
def blend(cls, op_type=None, **kwargs):
......@@ -76,7 +77,7 @@ class Operator(object):
pre_callback=callback,
)
def feed_arg(self, ws, name, value, dtype='int64'):
def feed_arg(self, ws, name, value, dtype):
"""Set the value of tensor argument."""
ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device)
......
......@@ -35,23 +35,24 @@ class Activation(Operator):
class Dropout(Operator):
"""Dropout operator."""
def __init__(self, key, dev, **kwargs):
super(Dropout, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'Dropout',
'arguments': {'ratio_desc': '${HANDLE}/ratio'},
'arguments': {
'ratio_desc': '${HANDLE}/ratio',
},
}
def feed(self, ws, handle, ratio):
self.feed_arg(ws, '{}/ratio'.format(handle), ratio, 'float32')
def setup(self, ws, handle, ratio):
self.feed_arg(ws, '%s/ratio' % handle, ratio, 'float32')
def forward(self, inputs, ratio, inplace=False):
outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs,
return self.dispatch(
inputs, outputs,
callback=lambda ws, handle:
self.feed(ws, handle, ratio))
self.setup(ws, handle, ratio),
)
class DropBlock2d(Dropout):
......@@ -66,9 +67,9 @@ class DropBlock2d(Dropout):
return {
'op_type': 'DropBlock2d',
'arguments': {
'ratio_desc': '${HANDLE}/ratio',
'block_size': self.block_size,
'data_format': self.data_format,
'ratio_desc': '${HANDLE}/ratio',
},
}
......@@ -76,13 +77,12 @@ class DropBlock2d(Dropout):
class DropPath(Dropout):
"""DropPath operator."""
def __init__(self, key, dev, **kwargs):
super(DropPath, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'DropPath',
'arguments': {'ratio_desc': '${HANDLE}/ratio'},
'arguments': {
'ratio_desc': '${HANDLE}/ratio',
},
}
......@@ -96,7 +96,9 @@ class Elu(Activation):
def attributes(self):
return {
'op_type': 'Elu',
'arguments': {'alpha': float(self.alpha)},
'arguments': {
'alpha': float(self.alpha),
},
}
......@@ -146,7 +148,9 @@ class PRelu(Operator):
def attributes(self):
return {
'op_type': 'PRelu',
'arguments': {'data_format': self.data_format},
'arguments': {
'data_format': self.data_format,
},
}
def forward(self, inputs):
......@@ -163,20 +167,21 @@ class Relu(Activation):
def attributes(self):
return {
'op_type': 'Relu',
'arguments': {'alpha': float(self.alpha)},
'arguments': {
'alpha': float(self.alpha),
},
}
class Relu6(Activation):
"""Relu6 operator."""
def __init__(self, key, dev, **kwargs):
super(Relu6, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'Relu',
'arguments': {'max_value': 6.},
'arguments': {
'max_value': 6.,
},
}
......@@ -208,5 +213,7 @@ class Softmax(Activation):
def attributes(self):
return {
'op_type': 'Softmax',
'arguments': {'axis': self.axis},
'arguments': {
'axis': self.axis,
},
}
......@@ -50,7 +50,9 @@ class Cast(Operator):
def attributes(self):
return {
'op_type': 'Cast',
'arguments': {'dtype': self.dtype},
'arguments': {
'dtype': self.dtype,
},
}
def forward(self, inputs, inplace=False):
......@@ -102,23 +104,18 @@ class ChannelNormalize(Operator):
'mean': self.mean,
'std': self.std,
'dtype': self.dtype,
'perm_descs': [
'${{HANDLE}}/perm[{}]'
.format(n) for n in range(self.ndim)],
'perm_desc': '${HANDLE}/perm',
}
}
def feed(self, ws, handle, perm):
for i in range(self.ndim):
self.feed_arg(
ws, '{}/perm[{}]'.format(handle, i),
perm[i], 'int64')
def setup(self, ws, handle, perm):
self.feed_arg(ws, '%s/perm' % handle, perm, 'int64')
def forward(self, inputs, perm):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, perm),
self.setup(ws, handle, perm),
)
......@@ -153,7 +150,9 @@ class Concat(Operator):
def attributes(self):
return {
'op_type': 'Concat',
'arguments': {'axis': self.axis},
'arguments': {
'axis': self.axis,
},
}
def forward(self, inputs):
......@@ -195,23 +194,18 @@ class Expand(Operator):
return {
'op_type': 'Expand',
'arguments': {
'dims_descs': [
'${{HANDLE}}/dims[{}]'
.format(n) for n in range(self.ndim)],
}
'dims_desc': '${HANDLE}/dims',
},
}
def feed(self, ws, handle, dims):
for i, dim in enumerate(dims):
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
dim, 'int64')
def setup(self, ws, handle, dims):
self.feed_arg(ws, '%s/dims' % handle, dims, 'int64')
def forward(self, inputs, dims):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, dims),
self.setup(ws, handle, dims),
)
......@@ -262,12 +256,6 @@ class Flatten(Operator):
class Identity(Operator):
"""Identity operator."""
def __init__(self, key, dev, **kwargs):
super(Identity, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'Identity', 'arguments': {}}
def forward(self, inputs, inplace=False):
outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs)
......@@ -310,36 +298,22 @@ class LinSpace(Operator):
'arguments': {
'axis': self.axis,
'dtype': self.dtype,
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'start_descs': [
'${{HANDLE}}/start[{}]'
.format(n) for n in range(self.num_intervals)],
'stop_descs': [
'${{HANDLE}}/stop[{}]'
.format(n) for n in range(self.num_intervals)],
}
}
def feed(self, ws, handle, shape, starts, stops):
for i, dim in enumerate(shape):
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
dim, 'int64')
for i in range(len(starts)):
self.feed_arg(
ws, '{}/start[{}]'.format(handle, i),
starts[i], 'float64')
self.feed_arg(
ws, '{}/stop[{}]'.format(handle, i),
stops[i], 'float64')
'dims_desc': '${HANDLE}/dims',
'start_desc': '${HANDLE}/start',
'stop_desc': '${HANDLE}/stop',
}
}
def setup(self, ws, handle, shape, starts, stops):
self.feed_arg(ws, '%s/dims' % handle, shape, 'int64')
self.feed_arg(ws, '%s/start' % handle, starts, 'float64')
self.feed_arg(ws, '%s/stop' % handle, stops, 'float64')
def forward(self, shape, starts, stops, trainable=False):
out = self.dispatch(
[], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, shape, starts, stops),
self.setup(ws, handle, shape, starts, stops),
no_grad=True,
)
out._requires_grad = trainable
......@@ -349,12 +323,6 @@ class LinSpace(Operator):
class MaskedSelect(Operator):
"""MaskedSelect operator."""
def __init__(self, key, dev, **kwargs):
super(MaskedSelect, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'MaskedSelect', 'arguments': {}}
def forward(self, inputs):
return self.dispatch(inputs, [self.alloc()])
......@@ -406,12 +374,6 @@ class Multinomial(Operator):
class NonZero(Operator):
"""NonZero operator."""
def __init__(self, key, dev, **kwargs):
super(NonZero, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'NonZero', 'arguments': {}}
def forward(self, inputs):
return self.dispatch(inputs, [self.alloc()])
......@@ -454,23 +416,18 @@ class Pad(Operator):
'arguments': {
'mode': self.mode,
'value': self.value,
'pads_descs': [
'${{HANDLE}}/pads[{}]'
.format(n) for n in range(self.ndim * 2)],
'pads_desc': '${HANDLE}/pads',
}
}
def feed(self, ws, handle, pads):
for i, e in enumerate(pads):
self.feed_arg(
ws, '{}/pads[{}]'.format(handle, i),
e, 'int64')
def setup(self, ws, handle, pads):
self.feed_arg(ws, '%s/pads' % handle, pads, 'int64')
def forward(self, inputs, pads):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, pads),
self.setup(ws, handle, pads),
)
......@@ -490,14 +447,14 @@ class Permutation(Operator):
}
}
def feed(self, ws, handle, limit):
self.feed_arg(ws, '{}/limit'.format(handle), limit, 'int64')
def setup(self, ws, handle, limit):
self.feed_arg(ws, '%s/limit' % handle, limit, 'int64')
def forward(self, limit, trainable=False):
out = self.dispatch(
[], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, limit),
self.setup(ws, handle, limit),
no_grad=True,
)
out._requires_grad = trainable
......@@ -517,23 +474,18 @@ class Range(Operator):
'op_type': 'Range',
'arguments': {
'dtype': self.dtype,
'slice_descs': [
'${{HANDLE}}/slice[{}]'
.format(n) for n in range(self.num_args)],
'slice_desc': '${HANDLE}/slice',
}
}
def feed(self, ws, handle, slice_args):
for i in range(len(slice_args)):
self.feed_arg(
ws, '{}/slice[{}]'.format(handle, i),
slice_args[i], 'float64')
def setup(self, ws, handle, slice_args):
self.feed_arg(ws, '%s/slice' % handle, slice_args, 'float64')
def forward(self, slice_args, trainable=False):
out = self.dispatch(
[], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, slice_args),
self.setup(ws, handle, slice_args),
no_grad=True,
)
out._requires_grad = trainable
......@@ -586,68 +538,46 @@ class Repeat(Operator):
class Reshape(Operator):
"""Reshape operator."""
def __init__(self, key, dev, **kwargs):
super(Reshape, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self):
return {
'op_type': 'Reshape',
'arguments': {
'dims_descs': [
'${{HANDLE}}/dims[{}]'
.format(n) for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
}
}
def feed(self, ws, handle, shape):
for i, e in enumerate(shape):
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
e, 'int64')
def setup(self, ws, handle, shape):
self.feed_arg(ws, '%s/dims' % handle, shape, 'int64')
def forward(self, inputs, shape, inplace=False):
return self.dispatch(
inputs, [self.alloc(inputs[0]) if inplace else self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, shape),
self.setup(ws, handle, shape),
)
class Slice(Operator):
"""Slice operator."""
def __init__(self, key, dev, **kwargs):
super(Slice, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self):
return {
'op_type': 'Slice',
'arguments': {
'starts_descs': [
'${{HANDLE}}/starts[{}]'
.format(n) for n in range(self.ndim)],
'sizes_descs': [
'${{HANDLE}}/sizes[{}]'
.format(n) for n in range(self.ndim)],
'starts_desc': '${HANDLE}/starts',
'sizes_desc': '${HANDLE}/sizes',
}
}
def feed(self, ws, handle, starts, sizes):
for i in range(len(starts)):
self.feed_arg(
ws, '{}/starts[{}]'.format(handle, i),
starts[i], 'int64')
self.feed_arg(
ws, '{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64')
def setup(self, ws, handle, starts, sizes):
self.feed_arg(ws, '%s/starts' % handle, starts, 'int64')
self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
def forward(self, inputs, starts, sizes):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes),
self.setup(ws, handle, starts, sizes),
)
......@@ -658,9 +588,6 @@ class Shape(Operator):
super(Shape, self).__init__(key, dev, **kwargs)
self._device = device_spec.DeviceSpec()
def attributes(self):
return {'op_type': 'Shape', 'arguments': {}}
def forward(self, inputs):
return self.dispatch(inputs, [self.alloc()], no_grad=True)
......@@ -729,6 +656,8 @@ class Squeeze(Operator):
class Stack(Operator):
"""Stack Operator."""
def __init__(self, key, dev, **kwargs):
super(Stack, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
......@@ -746,31 +675,22 @@ class Stack(Operator):
class Tile(Operator):
"""Tile operator."""
def __init__(self, key, dev, **kwargs):
super(Tile, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self):
return {
'op_type': 'Tile',
'arguments': {
'repeats_descs': [
'${{HANDLE}}/repeats[{}]'
.format(n) for n in range(self.ndim)],
'repeats_desc': '${HANDLE}/repeats',
}
}
def feed(self, ws, handle, repeats):
for i, size in enumerate(repeats):
self.feed_arg(
ws, '{}/repeats[{}]'.format(handle, i),
size, 'int64')
def setup(self, ws, handle, repeats):
self.feed_arg(ws, '%s/repeats' % handle, repeats, 'int64')
def forward(self, inputs, repeats):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, repeats),
self.setup(ws, handle, repeats),
)
......@@ -785,23 +705,20 @@ class Transpose(Operator):
return {
'op_type': 'Transpose',
'arguments': {
'perm_descs': [
'${{HANDLE}}/perm[{}]'
.format(n) for n in range(self.ndim)],
'perm_desc': '${HANDLE}/perm'
if self.ndim > 0 else None,
}
}
def feed(self, ws, handle, perm):
for i in range(self.ndim):
self.feed_arg(
ws, '{}/perm[{}]'.format(handle, i),
perm[i], 'int64')
def setup(self, ws, handle, perm):
if perm is not None:
self.feed_arg(ws, '%s/perm' % handle, perm, 'int64')
def forward(self, inputs, perm):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, perm),
self.setup(ws, handle, perm) if perm else None,
)
......@@ -827,7 +744,8 @@ class TopK(Operator):
}
def forward(self, inputs):
return self.dispatch(inputs, [self.alloc(), self.alloc()], no_grad=True)
return self.dispatch(
inputs, [self.alloc(), self.alloc()], no_grad=True)
class Unique(Operator):
......@@ -856,11 +774,5 @@ class Unique(Operator):
class Where(Operator):
"""Where operator."""
def __init__(self, key, dev, **kwargs):
super(Where, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'Where', 'arguments': {}}
def forward(self, inputs):
return self.dispatch(inputs, [self.alloc()])
......@@ -20,38 +20,25 @@ from dragon.core.framework.ops import Operator
class Assign(Operator):
"""Assign operator."""
def __init__(self, key, dev, **kwargs):
super(Assign, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self):
return {
'op_type': 'Assign',
'arguments': {
'starts_descs': [
'${{HANDLE}}/starts[{}]'
.format(n) for n in range(self.ndim)],
'sizes_descs': [
'${{HANDLE}}/sizes[{}]'
.format(n) for n in range(self.ndim)],
'starts_desc': '${HANDLE}/starts',
'sizes_desc': '${HANDLE}/sizes',
},
}
def feed(self, ws, handle, starts, sizes):
for i in range(self.ndim):
self.feed_arg(
ws, '{}/starts[{}]'.format(handle, i),
starts[i], 'int64')
self.feed_arg(
ws, '{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64')
def setup(self, ws, handle, starts, sizes):
self.feed_arg(ws, '%s/starts' % handle, starts, 'int64')
self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
def forward(self, inputs, starts, sizes, inplace=False):
outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(
inputs, outputs,
callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes),
self.setup(ws, handle, starts, sizes),
no_grad=True,
)
......@@ -59,12 +46,6 @@ class Assign(Operator):
class MaskedAssign(Operator):
"""MaskedAssign operator."""
def __init__(self, key, dev, **kwargs):
super(MaskedAssign, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'MaskedAssign', 'arguments': {}}
def forward(self, inputs, inplace=False):
outputs = [self.alloc(inputs[0]) if inplace else self.alloc()]
return self.dispatch(inputs, outputs, no_grad=True)
......@@ -25,18 +25,13 @@ class Initializer(Operator):
self.ndim = kwargs.get('ndim', 0)
self.dtype = kwargs.get('dtype', 'float32')
def feed(self, ws, handle, shape):
for i, dim in enumerate(shape):
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
dim, 'int64')
def setup(self, ws, handle, shape):
self.feed_arg(ws, '%s/dims' % handle, shape, 'int64')
def forward(self, shape, shape_as=None, out=None, trainable=None):
out = self.dispatch(
[] if shape_as is None else [shape_as],
[self.alloc(out)],
callback=lambda ws, handle:
self.feed(ws, handle, shape),
[] if shape_as is None else [shape_as], [self.alloc(out)],
callback=lambda ws, handle: self.setup(ws, handle, shape),
no_grad=True,
)
if trainable is not None:
......@@ -57,9 +52,7 @@ class Eye(Initializer):
'arguments': {
'k': self.k,
'dtype': self.dtype,
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -75,9 +68,7 @@ class Fill(Initializer):
'arguments': {
'dtype': self.dtype,
'value': float(self.value),
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -97,9 +88,7 @@ class GlorotNormal(Initializer):
'dtype': self.dtype,
'scale': float(self.scale),
'mode': self.mode.lower(),
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -119,9 +108,7 @@ class GlorotUniform(Initializer):
'dtype': self.dtype,
'scale': float(self.scale),
'mode': self.mode.lower(),
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -141,9 +128,7 @@ class RandomNormal(Initializer):
'dtype': self.dtype,
'mean': float(self.mean),
'std': float(self.std),
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -163,9 +148,7 @@ class RandomUniform(Initializer):
'dtype': self.dtype,
'low': float(self.low),
'high': float(self.high),
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -185,8 +168,6 @@ class TruncatedNormal(Initializer):
'dtype': self.dtype,
'mean': float(self.mean),
'std': float(self.std),
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -29,7 +29,7 @@ class Loss(Operator):
'op_type': self.__class__.__name__,
'arguments': {
'reduction': self.reduction,
}
},
}
def forward(self, inputs):
......@@ -39,16 +39,10 @@ class Loss(Operator):
class L1Loss(Loss):
"""L1Loss operator."""
def __init__(self, key, dev, **kwargs):
super(L1Loss, self).__init__(key, dev, **kwargs)
class L2Loss(Loss):
"""L2Loss operator."""
def __init__(self, key, dev, **kwargs):
super(L2Loss, self).__init__(key, dev, **kwargs)
class NLLLoss(Loss):
"""NLLLoss operator."""
......@@ -72,9 +66,6 @@ class NLLLoss(Loss):
class SigmoidCrossEntropy(Loss):
"""SigmoidCrossEntropy operator."""
def __init__(self, key, dev, **kwargs):
super(SigmoidCrossEntropy, self).__init__(key, dev, **kwargs)
class SmoothL1Loss(Loss):
"""SmoothL1Loss operator."""
......
......@@ -37,7 +37,7 @@ class BatchNorm(Operator):
'momentum': self.momentum,
'epsilon': self.epsilon,
'use_stats': self.use_stats,
}
},
}
def forward(self, inputs):
......@@ -60,7 +60,7 @@ class GroupNorm(Operator):
'axis': self.axis,
'group': self.group,
'epsilon': self.epsilon,
}
},
}
def forward(self, inputs):
......@@ -87,7 +87,7 @@ class LpNormalize(Operator):
'num_axes': self.num_axes,
'epsilon': self.epsilon,
'reduction': self.reduction,
}
},
}
def forward(self, inputs):
......@@ -114,7 +114,7 @@ class LocalResponseNorm(Operator):
'beta': self.beta,
'bias': self.bias,
'data_format': self.data_format,
}
},
}
def forward(self, inputs):
......
......@@ -20,12 +20,6 @@ from dragon.core.framework.ops import Operator
class LSTMCell(Operator):
"""LSTMCell operator."""
def __init__(self, key, dev, **kwargs):
super(LSTMCell, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'LSTMCell', 'arguments': {}}
def forward(self, inputs):
outputs = [self.alloc() for _ in range(2)]
return self.dispatch(inputs, outputs)
......@@ -54,7 +48,7 @@ class Recurrent(Operator):
'rnn_input_mode': 'linear',
'dropout_ratio': self.dropout_ratio,
'phase': 'TRAIN' if self.is_training else 'TEST'
}
},
}
def forward(self, inputs):
......@@ -87,7 +81,7 @@ class RNNParamSet(Operator):
'layer_id': self.layer_id,
'param_id': self.param_id,
'rnn_mode': self.mode,
}
},
}
def forward(self, inputs):
......
......@@ -91,7 +91,9 @@ class BiasAdd(Operator):
def attributes(self):
return {
'op_type': 'BiasAdd',
'arguments': {'data_format': self.data_format},
'arguments': {
'data_format': self.data_format,
},
}
def forward(self, inputs, inplace=False):
......@@ -102,9 +104,6 @@ class BiasAdd(Operator):
class Conv2d(ConvNd):
"""Conv2d operator."""
def __init__(self, key, dev, **kwargs):
super(Conv2d, self).__init__(key, dev, **kwargs)
class ConvTranspose2d(ConvNd):
"""ConvTranspose2d operator."""
......@@ -154,16 +153,10 @@ class DepthToSpace(Operator):
class DepthwiseConv2d(ConvNd):
"""DepthwiseConv2d operator."""
def __init__(self, key, dev, **kwargs):
super(DepthwiseConv2d, self).__init__(key, dev, **kwargs)
class Pool2d(PoolNd):
"""Pool2d operator."""
def __init__(self, key, dev, **kwargs):
super(Pool2d, self).__init__(key, dev, **kwargs)
class Resize(Operator):
"""Resize operator."""
......@@ -182,31 +175,25 @@ class Resize(Operator):
'arguments': {
'mode': self.mode,
'align_corners': self.align_corners,
'sizes_descs': [
'${{HANDLE}}/sizes[{}]'
.format(n) for n in range(self.num_sizes)],
'scales_descs': [
'${{HANDLE}}/scales[{}]'
.format(n) for n in range(self.num_scales)],
'data_format': self.data_format,
}
'sizes_desc': '${HANDLE}/sizes'
if self.num_sizes > 0 else None,
'scales_desc': '${HANDLE}/scales'
if self.num_scales > 0 else None,
},
}
def feed(self, ws, handle, sizes, scales):
for i in range(self.num_sizes):
self.feed_arg(
ws, '{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64')
for i in range(self.num_scales):
self.feed_arg(
ws, '{}/scales[{}]'.format(handle, i),
scales[i], 'float32')
def setup(self, ws, handle, sizes, scales):
if sizes is not None:
self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
if scales is not None:
self.feed_arg(ws, '%s/scales' % handle, scales, 'float32')
def forward(self, inputs, sizes=None, scales=None):
return self.dispatch(
inputs, [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, sizes, scales),
self.setup(ws, handle, sizes, scales),
)
......
......@@ -102,6 +102,7 @@ class Function(object):
The attribute dict.
"""
return {'op_type': self.__class__.__name__, 'arguments': {}}
def dispatch(
self,
......@@ -124,7 +125,7 @@ class Function(object):
pre_callback=callback,
)
def feed_arg(self, ws, name, value, dtype='int64'):
def feed_arg(self, ws, name, value, dtype):
"""Set the value of tensor argument."""
ws.FeedTensor(name, numpy.array(value, dtype), self._arg_device)
......
......@@ -134,16 +134,10 @@ class BatchNorm(function.Function):
class Conv2d(_ConvNd):
"""Conv2d function."""
def __init__(self, key, dev, **kwargs):
super(Conv2d, self).__init__(key, dev, **kwargs)
class ConvTranspose2d(_ConvNd):
"""ConvTranspose2d function."""
def __init__(self, key, dev, **kwargs):
super(ConvTranspose2d, self).__init__(key, dev, **kwargs)
class CTCLoss(_Loss):
"""CTCLoss function."""
......@@ -165,30 +159,28 @@ class CTCLoss(_Loss):
class DepthwiseConv2d(_ConvNd):
"""DepthwiseConv2d function."""
def __init__(self, key, dev, **kwargs):
super(DepthwiseConv2d, self).__init__(key, dev, **kwargs)
class Dropout(function.Function):
"""Dropout function."""
def __init__(self, key, dev, **kwargs):
super(Dropout, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'Dropout',
'arguments': {'ratio_desc': '${HANDLE}/ratio'},
'arguments': {
'ratio_desc': '${HANDLE}/ratio',
},
}
def feed(self, ws, handle, ratio):
def setup(self, ws, handle, ratio):
self.feed_arg(ws, '{}/ratio'.format(handle), ratio, 'float32')
def forward(self, input, ratio, inplace=False):
out = input if inplace else self.alloc()
return self.dispatch([input], [out],
return self.dispatch(
[input], [out],
callback=lambda ws, handle:
self.feed(ws, handle, ratio))
self.setup(ws, handle, ratio),
)
class DropBlock2d(Dropout):
......@@ -202,9 +194,9 @@ class DropBlock2d(Dropout):
return {
'op_type': 'DropBlock2d',
'arguments': {
'ratio_desc': '${HANDLE}/ratio',
'block_size': self.block_size,
'data_format': 'NCHW',
'block_size': self.block_size,
'ratio_desc': '${HANDLE}/ratio',
}
}
......@@ -212,13 +204,12 @@ class DropBlock2d(Dropout):
class DropPath(Dropout):
"""DropPath function."""
def __init__(self, key, dev, **kwargs):
super(DropPath, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'DropPath',
'arguments': {'ratio_desc': '${HANDLE}/ratio'},
'arguments': {
'ratio_desc': '${HANDLE}/ratio',
},
}
......@@ -253,7 +244,7 @@ class GroupNorm(function.Function):
'axis': 1,
'group': self.group,
'epsilon': self.epsilon,
}
},
}
def forward(self, input, weight, bias):
......@@ -299,41 +290,32 @@ class HardSwish(_Activation):
class L1Loss(_Loss):
"""L1Loss function."""
def __init__(self, key, dev, **kwargs):
super(L1Loss, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'L1Loss',
'arguments': {
'scale': 1.,
'reduction': self.reduction,
}
},
}
class L2Loss(_Loss):
"""L2Loss function."""
def __init__(self, key, dev, **kwargs):
super(L2Loss, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'L2Loss',
'arguments': {
'scale': 2.,
'reduction': self.reduction,
}
},
}
class Linear(function.Function):
"""Linear function."""
def __init__(self, key, dev, **kwargs):
super(Linear, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'FullyConnected',
......@@ -368,7 +350,7 @@ class LocalResponseNorm(function.Function):
'beta': self.beta,
'bias': self.bias,
'data_format': 'NCHW',
}
},
}
def forward(self, input):
......@@ -393,7 +375,7 @@ class LpNormalize(function.Function):
'epsilon': self.epsilon,
'num_axes': 1,
'reduction': 'SUM',
}
},
}
def forward(self, input, out=None):
......@@ -403,12 +385,6 @@ class LpNormalize(function.Function):
class LSTMCell(function.Function):
"""LSTMCell function."""
def __init__(self, key, dev, **kwargs):
super(LSTMCell, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'LSTMCell', 'arguments': {}}
def forward(self, input, cx):
outputs = [self.alloc() for _ in range(2)]
return self.dispatch([input, cx], outputs)
......@@ -428,7 +404,7 @@ class NLLLoss(_Loss):
'axis': 1,
'reduction': self.reduction,
'ignore_index': self.ignore_index,
}
},
}
......@@ -437,7 +413,6 @@ class Pad(function.Function):
def __init__(self, key, dev, **kwargs):
super(Pad, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
self.value = kwargs.get('value', 0.)
self.mode = kwargs.get('mode', 'CONSTANT')
......@@ -447,42 +422,28 @@ class Pad(function.Function):
'arguments': {
'mode': self.mode,
'value': self.value,
'pads_descs': [
'${{HANDLE}}/pads[{}]'
.format(n) for n in range(self.ndim * 2)
],
}
'pads_desc': '${HANDLE}/pads',
},
}
def feed(self, ws, handle, pads):
for i, e in enumerate(pads):
self.feed_arg(
ws,
'{}/pads[{}]'.format(handle, i),
e, 'int64'
)
def setup(self, ws, handle, pads):
self.feed_arg(ws, '%s/pads' % handle, pads, 'int64')
def forward(self, input, pads):
return self.dispatch(
[input], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, pads),
self.setup(ws, handle, pads),
)
class Pool2d(_PoolNd):
"""Pool2d function."""
def __init__(self, key, dev, **kwargs):
super(Pool2d, self).__init__(key, dev, **kwargs)
class PRelu(function.Function):
"""PRelu function."""
def __init__(self, key, dev, **kwargs):
super(PRelu, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'PRelu',
......@@ -552,9 +513,6 @@ class Relu(_Activation):
class Relu6(_Activation):
"""Relu6 function."""
def __init__(self, key, dev, **kwargs):
super(Relu6, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'Relu',
......@@ -583,30 +541,24 @@ class Resize(function.Function):
'mode': self.mode,
'align_corners': self.align_corners,
'data_format': 'NCHW',
'sizes_descs': [
'${{HANDLE}}/sizes[{}]'
.format(n) for n in range(self.num_sizes)],
'scales_descs': [
'${{HANDLE}}/scales[{}]'
.format(n) for n in range(self.num_scales)],
}
'sizes_desc': '${HANDLE}/sizes'
if self.num_sizes > 0 else None,
'scales_desc': '${HANDLE}/scales'
if self.num_scales > 0 else None,
},
}
def feed(self, ws, handle, sizes, scales):
for i in range(self.num_sizes):
self.feed_arg(
ws, '{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64')
for i in range(self.num_scales):
self.feed_arg(
ws, '{}/scales[{}]'.format(handle, i),
scales[i], 'float32')
def setup(self, ws, handle, sizes, scales):
if sizes is not None:
self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
if scales is not None:
self.feed_arg(ws, '%s/scales' % handle, scales, 'float32')
def forward(self, input, sizes=None, scales=None):
return self.dispatch(
[input], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, sizes, scales),
self.setup(ws, handle, sizes, scales),
)
......@@ -646,15 +598,12 @@ class RNNParamSet(function.Function):
class SigmoidCrossEntropy(_Loss):
"""SigmoidCrossEntropy function."""
def __init__(self, key, dev, **kwargs):
super(SigmoidCrossEntropy, self).__init__(key, dev, **kwargs)
def attributes(self):
return {
'op_type': 'SigmoidCrossEntropy',
'arguments': {
'reduction': self.reduction,
}
},
}
......@@ -693,7 +642,7 @@ class SmoothL1Loss(_Loss):
'arguments': {
'beta': float(self.beta),
'reduction': self.reduction,
}
},
}
......@@ -727,7 +676,7 @@ class SparseSoftmaxCrossEntropy(_Loss):
'axis': 1,
'reduction': self.reduction,
'ignore_index': self.ignore_index,
}
},
}
......
......@@ -42,38 +42,25 @@ class ArgReduce(function.Function):
class Assign(function.Function):
"""Assign function."""
def __init__(self, key, dev, **kwargs):
super(Assign, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self):
return {
'op_type': 'Assign',
'arguments': {
'starts_descs': [
'${{HANDLE}}/starts[{}]'
.format(n) for n in range(self.ndim)],
'sizes_descs': [
'${{HANDLE}}/sizes[{}]'
.format(n) for n in range(self.ndim)],
'starts_desc': '${HANDLE}/starts',
'sizes_desc': '${HANDLE}/sizes',
},
}
def feed(self, ws, handle, starts, sizes):
for i in range(self.ndim):
self.feed_arg(
ws, '{}/starts[{}]'.format(handle, i),
starts[i], 'int64')
self.feed_arg(
ws, '{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64')
def setup(self, ws, handle, starts, sizes):
self.feed_arg(ws, '%s/starts' % handle, starts, 'int64')
self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
def forward(self, out, starts, sizes, input):
self._check_device([input, out])
return self.dispatch(
[out, input], [self.alloc(out)],
callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes),
self.setup(ws, handle, starts, sizes),
no_grad=True,
check_device=False,
)
......@@ -128,7 +115,6 @@ class ChannelNormalize(function.Function):
def __init__(self, key, dev, **kwargs):
super(ChannelNormalize, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
self.ndim = kwargs.get('ndim', 0)
self.mean = kwargs.get('mean', None)
self.std = kwargs.get('std', None)
self.dtype = kwargs.get('dtype', 'float32')
......@@ -141,23 +127,18 @@ class ChannelNormalize(function.Function):
'mean': self.mean,
'std': self.std,
'dtype': self.dtype,
'perm_descs': [
'${{HANDLE}}/perm[{}]'
.format(n) for n in range(self.ndim)],
'perm_desc': '${HANDLE}/perm',
}
}
def feed(self, ws, handle, perm):
for i in range(self.ndim):
self.feed_arg(
ws, '{}/perm[{}]'.format(handle, i),
perm[i], 'int64')
def setup(self, ws, handle, perm):
self.feed_arg(ws, '%s/perm' % handle, perm, 'int64')
def forward(self, input, perm):
return self.dispatch(
[input], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, perm),
self.setup(ws, handle, perm),
)
......@@ -226,31 +207,20 @@ class Cumulative(function.Function):
class Expand(function.Function):
"""Expand function."""
def __init__(self, key, dev, **kwargs):
super(Expand, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self):
return {
'op_type': 'Expand',
'arguments': {
'dims_descs': [
'${{HANDLE}}/dims[{}]'
.format(n) for n in range(self.ndim)],
},
'arguments': {'dims_desc': '${HANDLE}/dims'},
}
def feed(self, ws, handle, times):
for i in range(self.ndim):
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
times[i], 'int64')
def setup(self, ws, handle, times):
self.feed_arg(ws, '%s/dims' % handle, times, 'int64')
def forward(self, input, dims):
return self.dispatch(
[input], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, dims),
self.setup(ws, handle, dims),
)
......@@ -299,12 +269,6 @@ class IndexSelect(function.Function):
class MaskedAssign(function.Function):
"""MaskedAssign function."""
def __init__(self, key, dev, **kwargs):
super(MaskedAssign, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'MaskedAssign', 'arguments': {}}
def forward(self, out, mask, input):
return self.dispatch([out, input, mask], [self.alloc(out)])
......@@ -312,12 +276,6 @@ class MaskedAssign(function.Function):
class MaskedSelect(function.Function):
"""MaskedSelect function."""
def __init__(self, key, dev, **kwargs):
super(MaskedSelect, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'MaskedSelect', 'arguments': {}}
def forward(self, input, mask, out=None):
return self.dispatch([input, mask], [self.alloc(out)])
......@@ -347,12 +305,6 @@ class Multinomial(function.Function):
class NonZero(function.Function):
"""NonZero function."""
def __init__(self, key, dev, **kwargs):
super(NonZero, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'NonZero', 'arguments': {}}
def forward(self, input, out=None):
return self.dispatch([input], [self.alloc(out)], no_grad=True)
......@@ -401,68 +353,46 @@ class Reduce(function.Function):
class Reshape(function.Function):
"""Reshape function."""
def __init__(self, key, dev, **kwargs):
super(Reshape, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self):
return {
'op_type': 'Reshape',
'arguments': {
'dims_descs': [
'${{HANDLE}}/dims[{}]'
.format(n) for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
def feed(self, ws, handle, shape):
for i in range(self.ndim):
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
shape[i], 'int64')
def setup(self, ws, handle, shape):
self.feed_arg(ws, '%s/dims' % handle, shape, 'int64')
def forward(self, input, shape, out=None):
return self.dispatch(
[input], [self.alloc(out)],
callback=lambda ws, handle:
self.feed(ws, handle, shape),
self.setup(ws, handle, shape),
)
class Slice(function.Function):
"""Slice function."""
def __init__(self, key, dev, **kwargs):
super(Slice, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self):
return {
'op_type': 'Slice',
'arguments': {
'starts_descs': [
'${{HANDLE}}/starts[{}]'
.format(n) for n in range(self.ndim)],
'sizes_descs': [
'${{HANDLE}}/sizes[{}]'
.format(n) for n in range(self.ndim)],
'starts_desc': '${HANDLE}/starts',
'sizes_desc': '${HANDLE}/sizes',
},
}
def feed(self, ws, handle, starts, sizes):
for i in range(self.ndim):
self.feed_arg(
ws, '{}/starts[{}]'.format(handle, i),
starts[i], 'int64')
self.feed_arg(
ws, '{}/sizes[{}]'.format(handle, i),
sizes[i], 'int64')
def setup(self, ws, handle, starts, sizes):
self.feed_arg(ws, '%s/starts' % handle, starts, 'int64')
self.feed_arg(ws, '%s/sizes' % handle, sizes, 'int64')
def forward(self, input, starts, sizes):
return self.dispatch(
[input], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, starts, sizes)
self.setup(ws, handle, starts, sizes)
)
......@@ -551,33 +481,22 @@ class Squeeze(function.Function):
class Tile(function.Function):
"""Tile function."""
def __init__(self, key, dev, **kwargs):
super(Tile, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
def attributes(self):
return {
'op_type': 'Tile',
'arguments': {
'repeats_descs': [
'${{HANDLE}}/repeats[{}]'
.format(n) for n in range(self.ndim)],
'repeats_desc': '${HANDLE}/repeats',
},
}
def feed(self, ws, handle, repeats):
for i in range(self.ndim):
self.feed_arg(
ws,
'{}/repeats[{}]'.format(handle, i),
repeats[i], 'int64',
)
def setup(self, ws, handle, repeats):
self.feed_arg(ws, '%s/repeats' % handle, repeats, 'int64')
def forward(self, input, times):
return self.dispatch(
[input], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, times),
self.setup(ws, handle, times),
)
......@@ -592,23 +511,19 @@ class Transpose(function.Function):
return {
'op_type': 'Transpose',
'arguments': {
'perm_descs': [
'${{HANDLE}}/perm[{}]'
.format(n) for n in range(self.ndim)],
'perm_desc': '${HANDLE}/perm'
if self.ndim > 0 else None,
},
}
def feed(self, ws, handle, perm):
for i in range(self.ndim):
self.feed_arg(
ws, '{}/perm[{}]'.format(handle, i),
perm[i], 'int64')
def setup(self, ws, handle, perm):
self.feed_arg(ws, '%s/perm' % handle, perm, 'int64')
def forward(self, input, perm):
return self.dispatch(
[input], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, perm),
self.setup(ws, handle, perm),
)
......@@ -683,11 +598,5 @@ class UnSqueeze(function.Function):
class Where(function.Function):
"""Where function."""
def __init__(self, key, dev, **kwargs):
super(Where, self).__init__(key, dev, **kwargs)
def attributes(self):
return {'op_type': 'Where', 'arguments': {}}
def forward(self, condition, x, y):
return self.dispatch([x, y, condition], [self.alloc()])
......@@ -22,20 +22,16 @@ class _Initializer(function.Function):
def __init__(self, key, dev, **kwargs):
super(_Initializer, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
self.dtype = kwargs.get('dtype', 'float32')
def feed(self, ws, handle, shape):
for i in range(self.ndim):
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
shape[i], 'int64')
def setup(self, ws, handle, shape):
self.feed_arg(ws, '%s/dims' % handle, shape, 'int64')
def forward(self, out, shape, shape_like=None):
return self.dispatch(
[] if shape_like is None else [shape_like], [out],
callback=lambda ws, handle:
self.feed(ws, handle, shape),
self.setup(ws, handle, shape),
no_grad=True,
)
......@@ -53,9 +49,7 @@ class Eye(_Initializer):
'arguments': {
'k': self.k,
'dtype': self.dtype,
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -73,9 +67,7 @@ class Fill(_Initializer):
'arguments': {
'dtype': self.dtype,
'value': float(self.value),
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -85,7 +77,6 @@ class LinSpace(function.Function):
def __init__(self, key, dev, **kwargs):
super(LinSpace, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
self.num_intervals = kwargs.get('num_intervals', 1)
self.dtype = kwargs.get('dtype', 'int64')
self.axis = kwargs.get('axis', 0)
......@@ -96,36 +87,22 @@ class LinSpace(function.Function):
'arguments': {
'axis': self.axis,
'dtype': self.dtype,
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'start_descs': [
'${{HANDLE}}/start[{}]'
.format(n) for n in range(self.num_intervals)],
'stop_descs': [
'${{HANDLE}}/stop[{}]'
.format(n) for n in range(self.num_intervals)],
'dims_desc': '${HANDLE}/dims',
'start_desc': '${HANDLE}/start',
'stop_desc': '${HANDLE}/stop',
}
}
def feed(self, ws, handle, shape, starts, stops):
for i, dim in enumerate(shape):
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
dim, 'int64')
for i in range(len(starts)):
self.feed_arg(
ws, '{}/start[{}]'.format(handle, i),
starts[i], 'float64')
self.feed_arg(
ws, '{}/stop[{}]'.format(handle, i),
stops[i], 'float64')
def setup(self, ws, handle, shape, starts, stops):
self.feed_arg(ws, '%s/dims' % handle, shape, 'int64')
self.feed_arg(ws, '%s/start' % handle, starts, 'float64')
self.feed_arg(ws, '%s/stop' % handle, stops, 'float64')
def forward(self, shape, starts, stops, out=None):
return self.dispatch(
[], [self.alloc(out)],
callback=lambda ws, handle:
self.feed(ws, handle, shape, starts, stops),
self.setup(ws, handle, shape, starts, stops),
no_grad=True,
)
......@@ -146,14 +123,14 @@ class Permutation(function.Function):
}
}
def feed(self, ws, handle, limit):
def setup(self, ws, handle, limit):
self.feed_arg(ws, '{}/limit'.format(handle), limit, 'int64')
def forward(self, limit, out=None):
return self.dispatch(
[], [self.alloc(out)],
callback=lambda ws, handle:
self.feed(ws, handle, limit),
self.setup(ws, handle, limit),
no_grad=True,
)
......@@ -173,9 +150,7 @@ class RandomNormal(_Initializer):
'dtype': self.dtype,
'mean': float(self.mean),
'std': float(self.std),
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -195,9 +170,7 @@ class RandomUniform(_Initializer):
'dtype': self.dtype,
'low': float(self.low),
'high': float(self.high),
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'dims_desc': '${HANDLE}/dims',
},
}
......@@ -207,7 +180,6 @@ class Range(function.Function):
def __init__(self, key, dev, **kwargs):
super(Range, self).__init__(key, dev, **kwargs)
self.num_args = kwargs.get('num_args', 3)
self.dtype = kwargs.get('dtype', 'int64')
def attributes(self):
......@@ -215,22 +187,17 @@ class Range(function.Function):
'op_type': 'Range',
'arguments': {
'dtype': self.dtype,
'slice_descs': [
'${{HANDLE}}/slice[{}]'
.format(n) for n in range(self.num_args)],
'slice_desc': '${HANDLE}/slice',
}
}
def feed(self, ws, handle, slice_args):
for i in range(len(slice_args)):
self.feed_arg(
ws, '{}/slice[{}]'.format(handle, i),
slice_args[i], 'float64')
def setup(self, ws, handle, slice_args):
self.feed_arg(ws, '%s/slice' % handle, slice_args, 'float64')
def forward(self, slice_args, out=None):
return self.dispatch(
[], [self.alloc(out)],
callback=lambda ws, handle:
self.feed(ws, handle, slice_args),
self.setup(ws, handle, slice_args),
no_grad=True,
)
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!