Commit 77179032 by Ting PAN

Refactor Shape Module

1 parent 2d1b7752
...@@ -17,16 +17,16 @@ class InitializeOp: public Operator<Context> { ...@@ -17,16 +17,16 @@ class InitializeOp: public Operator<Context> {
public: public:
InitializeOp(const OperatorDef& op_def, Workspace* ws) InitializeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
static_shape(OperatorBase::GetRepeatedArg<int>("static_shape")), dims_desc(OperatorBase::GetRepeatedArg<string>("dims")),
dynamic_shape(OperatorBase::GetSingleArg<string>("dynamic_shape", "")) {} shape_desc(OperatorBase::GetSingleArg<string>("shape", "")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
vector<string> dims_desc;
string shape_desc;
TensorFiller filler; TensorFiller filler;
vector<int> static_shape;
string dynamic_shape;
}; };
template <class Context> template <class Context>
......
...@@ -16,24 +16,19 @@ class ArangeOp final : public Operator<Context> { ...@@ -16,24 +16,19 @@ class ArangeOp final : public Operator<Context> {
public: public:
ArangeOp(const OperatorDef& op_def, Workspace* ws) ArangeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
start(OperatorBase::GetSingleArg<int>("static_start", 0)), start_desc(OperatorBase::GetSingleArg<string>("start", "")),
stop(OperatorBase::GetSingleArg<int>("static_stop", -1)), stop_desc(OperatorBase::GetSingleArg<string>("stop", "")),
step(OperatorBase::GetSingleArg<int>("static_step", 1)), step_desc(OperatorBase::GetSingleArg<string>("step", "")),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")) { dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")) {}
dynamic_start_ = OperatorBase::GetSingleArg<string>("dynamic_start", "");
dynamic_stop_ = OperatorBase::GetSingleArg<string>("dynamic_stop", "");
dynamic_step_ = OperatorBase::GetSingleArg<string>("dynamic_step", "");
}
void RunOnDevice() override;
void Reshape(); void Reshape();
void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
string start_desc, stop_desc, step_desc, dtype;
TIndex start, stop, step, count; TIndex start, stop, step, count;
Tensor* dynamic_start, *dynamic_stop, *dynamic_step;
string dynamic_start_, dynamic_stop_, dynamic_step_;
string dtype;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -16,19 +16,13 @@ class TileOp : public Operator<Context> { ...@@ -16,19 +16,13 @@ class TileOp : public Operator<Context> {
public: public:
TileOp(const OperatorDef& op_def, Workspace* ws) TileOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
multiples(OperatorBase::GetRepeatedArg<int>("multiples")) { multiples_desc(OperatorBase::GetRepeatedArg<string>("multiples")) {}
for (int i = 0; i < multiples.size(); i++)
if (multiples[i] > 1)
process_axes.push_back({ multiples[i], i });
std::sort(process_axes.begin(), process_axes.end());
}
void RunOnDevice() override; void RunOnDevice() override;
template<typename T> void TileRunWithType(); template<typename T> void TileRunWithType();
protected: protected:
vector<int> multiples; vector<string> multiples_desc;
vector< pair<int, int> > process_axes;
TIndex axis, multiple, outer_dim, ex_inner_dim; TIndex axis, multiple, outer_dim, ex_inner_dim;
Tensor* dest, *source; Tensor* dest, *source;
}; };
...@@ -38,12 +32,7 @@ class TileGradientOp : public Operator<Context> { ...@@ -38,12 +32,7 @@ class TileGradientOp : public Operator<Context> {
public: public:
TileGradientOp(const OperatorDef& op_def, Workspace* ws) TileGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
multiples(OperatorBase::GetRepeatedArg<int>("multiples")) { multiples_desc(OperatorBase::GetRepeatedArg<string>("multiples")) {
for (int i = 0; i < multiples.size(); i++)
if (multiples[i] > 1)
process_axes.push_back({ multiples[i], i });
std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end());
DISABLE_SHARE_GRADIENT; DISABLE_SHARE_GRADIENT;
} }
...@@ -51,8 +40,7 @@ class TileGradientOp : public Operator<Context> { ...@@ -51,8 +40,7 @@ class TileGradientOp : public Operator<Context> {
template<typename T> void TileRunWithType(); template<typename T> void TileRunWithType();
protected: protected:
vector<int> multiples; vector<string> multiples_desc;
vector< pair<int, int> > process_axes;
TIndex axis, multiple, outer_dim, ex_inner_dim; TIndex axis, multiple, outer_dim, ex_inner_dim;
Tensor* dest, *source; Tensor* dest, *source;
}; };
......
...@@ -81,13 +81,17 @@ class FusedBatchNormOp : public Operator<Context> { ...@@ -81,13 +81,17 @@ class FusedBatchNormOp : public Operator<Context> {
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {} use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void Setup() { NOT_IMPLEMENTED; } void Setup();
void RunOnDevice() override { NOT_IMPLEMENTED; } void RunOnDevice() override;
template <typename T> void RunWithType() { NOT_IMPLEMENTED; } template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected: protected:
float momentum, eps; float momentum, eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex axis, N, C, S, NC, NS; TIndex axis, N, C, S, NC, NS;
string data_format; string data_format;
int use_stats; int use_stats;
...@@ -103,15 +107,19 @@ class FusedBatchNormGradientOp : public Operator<Context> { ...@@ -103,15 +107,19 @@ class FusedBatchNormGradientOp : public Operator<Context> {
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))), eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { } use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { }
void Setup() { NOT_IMPLEMENTED; } void Setup();
void ShareGradient() override; void ShareGradient() override;
void RunOnDevice() override { NOT_IMPLEMENTED; } void RunOnDevice() override;
template <typename T> void RunWithType() { NOT_IMPLEMENTED; } template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected: protected:
float eps; float eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex axis, N, C, S, NC, NS; TIndex axis, N, C, S, NC, NS;
string data_format; string data_format;
int use_stats; int use_stats;
......
...@@ -16,8 +16,7 @@ class BilinearResizeOp : public Operator<Context> { ...@@ -16,8 +16,7 @@ class BilinearResizeOp : public Operator<Context> {
public: public:
BilinearResizeOp(const OperatorDef& op_def, Workspace* ws) BilinearResizeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")), dsize_desc(OperatorBase::GetRepeatedArg<string>("dsize")),
dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)), fy(OperatorBase::GetSingleArg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)), fx(OperatorBase::GetSingleArg<float>("fx", -1.0)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) { data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
...@@ -29,8 +28,7 @@ class BilinearResizeOp : public Operator<Context> { ...@@ -29,8 +28,7 @@ class BilinearResizeOp : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
vector<int> static_dsize; vector<string> dsize_desc;
vector<string> dynamic_dsize;
float fy, fx; float fy, fx;
string data_format; string data_format;
TIndex n, c, h, w, out_h, out_w, spatial_axis; TIndex n, c, h, w, out_h, out_w, spatial_axis;
......
...@@ -22,8 +22,7 @@ class ConvOpBase : public Operator<Context> { ...@@ -22,8 +22,7 @@ class ConvOpBase : public Operator<Context> {
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")), padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
num_output(OperatorBase::GetSingleArg<int>("num_output", 1)), num_output(OperatorBase::GetSingleArg<int>("num_output", 1)),
group(OperatorBase::GetSingleArg<int>("group", 1)), group(OperatorBase::GetSingleArg<int>("group", 1)),
static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")), output_dims_desc(OperatorBase::GetRepeatedArg<string>("output_shape")) {
dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")) {
if (data_format == "NCHW") spatial_axis = 2; if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1; else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format; else LOG(FATAL) << "Unknown data format: " << data_format;
...@@ -42,8 +41,7 @@ class ConvOpBase : public Operator<Context> { ...@@ -42,8 +41,7 @@ class ConvOpBase : public Operator<Context> {
TIndex conv_in_channels, conv_out_channels; TIndex conv_in_channels, conv_out_channels;
TIndex conv_out_spatial_dim, kernel_dim; TIndex conv_out_spatial_dim, kernel_dim;
TIndex col_offset, output_offset, weight_offset, x_offset, y_offset; TIndex col_offset, output_offset, weight_offset, x_offset, y_offset;
vector<int> static_dsize; vector<string> output_dims_desc;
vector<string> dynamic_dsize;
bool is_1x1; bool is_1x1;
void Setup(); void Setup();
......
...@@ -16,8 +16,7 @@ class NNResizeOp : public Operator<Context> { ...@@ -16,8 +16,7 @@ class NNResizeOp : public Operator<Context> {
public: public:
NNResizeOp(const OperatorDef& op_def, Workspace* ws) NNResizeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")), dsize_desc(OperatorBase::GetRepeatedArg<string>("dsize")),
dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)), fy(OperatorBase::GetSingleArg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)), fx(OperatorBase::GetSingleArg<float>("fx", -1.0)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) { data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
...@@ -30,8 +29,7 @@ class NNResizeOp : public Operator<Context> { ...@@ -30,8 +29,7 @@ class NNResizeOp : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
vector<int> static_dsize; vector<string> dsize_desc;
vector<string> dynamic_dsize;
float fy, fx; float fy, fx;
string data_format; string data_format;
TIndex n, c, h, w, out_h, out_w, spatial_axis; TIndex n, c, h, w, out_h, out_w, spatial_axis;
......
...@@ -323,10 +323,10 @@ void At(const int count, ...@@ -323,10 +323,10 @@ void At(const int count,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const T* indices, const int* indices,
const T* x, const T* x,
T* y, T* y,
Context* context); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void AtGrad(const int count, void AtGrad(const int count,
...@@ -334,10 +334,9 @@ void AtGrad(const int count, ...@@ -334,10 +334,9 @@ void AtGrad(const int count,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const T* indices, const int* indices,
const T* dy, const T* dy,
T* dx, T* dx);
Context* context);
/******************** ndarray.concat ********************/ /******************** ndarray.concat ********************/
......
...@@ -410,7 +410,7 @@ class Tensor(object): ...@@ -410,7 +410,7 @@ class Tensor(object):
""" """
def wrapper_indices(indices): def wrapper_indices(indices):
tensor = Tensor(GetTensorName()) tensor = Tensor(GetTensorName())
ws.FeedTensor(tensor, np.array(indices, dtype=np.float32)) ws.FeedTensor(tensor, np.array(indices, dtype=np.int32))
return tensor return tensor
if not isinstance(item, tuple): if not isinstance(item, tuple):
...@@ -422,8 +422,7 @@ class Tensor(object): ...@@ -422,8 +422,7 @@ class Tensor(object):
output.shape[0] = 1 output.shape[0] = 1
return output return output
else: else:
# ND Crop raise TypeError('Unsupported type of indices: {}'.format(type(item)))
item = (item, )
starts = [] starts = []
ends = [] ends = []
output_dims = [] output_dims = []
...@@ -853,6 +852,21 @@ class Tensor(object): ...@@ -853,6 +852,21 @@ class Tensor(object):
""" """
raise NotImplementedError('Implemented in <vm.tensorflow.framework.tensor_shape>') raise NotImplementedError('Implemented in <vm.tensorflow.framework.tensor_shape>')
def eval(self, feed_dict=None):
"""Run and return the computing results of this tensor.
Parameters
----------
feed_dict : dict
The values to feed.
Returns
-------
numpy.ndarray
The values of this tensor in the backend.
"""
raise NotImplementedError('Implemented in <vm.theano.compile.function>')
############################################ ############################################
# # # #
# MISC # # MISC #
...@@ -970,6 +984,39 @@ class Tensor(object): ...@@ -970,6 +984,39 @@ class Tensor(object):
elif nout == 1: return outputs[0] elif nout == 1: return outputs[0]
else: return None else: return None
@classmethod
def Convert(cls, value, dtype='float32'):
"""Convert the given value to a tensor.
Parameters
----------
value : numerical type
The value to convert.
dtype : str
The data type of the tensor.
Returns
-------
Tensor
The tensor converted with given value.
"""
if isinstance(value, Tensor):
return value
else:
if isinstance(value, (list, tuple)):
np_value = np.array(value, dtype=dtype)
elif isinstance(value, np.ndarray):
np_value = value.astype(dtype=dtype)
else:
try:
np_value = np.array(value, dtype=dtype)
except:
raise TypeError('{} value can not be converted to tensor.'.format(type(value)))
tensor = Tensor(shape=list(np_value.shape), dtype=dtype)
tensor.set_value(np_value)
return tensor
def Fill(self, type, **kwargs): def Fill(self, type, **kwargs):
"""Fill self with the specific type of filler. """Fill self with the specific type of filler.
......
...@@ -13,12 +13,14 @@ List Brief ...@@ -13,12 +13,14 @@ List Brief
============================== ============================================================================= ============================== =============================================================================
`Tensor.name`_ Return or Set the name. `Tensor.name`_ Return or Set the name.
`Tensor.shape`_ Return or Set the shape. `Tensor.shape`_ Return or Set the shape.
`Tensor.get_shape`_ Return the shape.
`Tensor.dtype`_ Return or Set the data type. `Tensor.dtype`_ Return or Set the data type.
`Tensor.set_value`_ Feed the values to C++ backend. `Tensor.set_value`_ Feed the values to C++ backend.
`Tensor.get_value`_ Fetch the values from C++ backend. `Tensor.get_value`_ Fetch the values from C++ backend.
`Tensor.copy`_ Return a Tensor with same content. `Tensor.copy`_ Return a Tensor with same content.
`Tensor.reshape`_ Reshape the dimensions of input. `Tensor.reshape`_ Reshape the dimensions of input.
`Tensor.dimshuffle`_ Shuffle the dimensions. `Tensor.dimshuffle`_ Shuffle the dimensions.
`Tensor.eval`_ Run and return the computing results of this tensor.
`Tensor.CreateOperator`_ Construct a new Tensor with specific operator descriptor. `Tensor.CreateOperator`_ Construct a new Tensor with specific operator descriptor.
`Tensor.Fill`_ Fill self with the specific type of filler. `Tensor.Fill`_ Fill self with the specific type of filler.
`Tensor.PrintExpressions`_ Return the stringified internal expressions. `Tensor.PrintExpressions`_ Return the stringified internal expressions.
...@@ -102,12 +104,14 @@ API Reference ...@@ -102,12 +104,14 @@ API Reference
.. _Tensor.name: #dragon.core.tensor.Tensor.name .. _Tensor.name: #dragon.core.tensor.Tensor.name
.. _Tensor.shape: #dragon.core.tensor.Tensor.shape .. _Tensor.shape: #dragon.core.tensor.Tensor.shape
.. _Tensor.get_shape: #dragon.core.tensor.Tensor.get_shape
.. _Tensor.dtype: #dragon.core.tensor.Tensor.dtype .. _Tensor.dtype: #dragon.core.tensor.Tensor.dtype
.. _Tensor.set_value: #dragon.core.tensor.Tensor.set_value .. _Tensor.set_value: #dragon.core.tensor.Tensor.set_value
.. _Tensor.get_value: #dragon.core.tensor.Tensor.get_value .. _Tensor.get_value: #dragon.core.tensor.Tensor.get_value
.. _Tensor.copy: #dragon.core.tensor.Tensor.copy .. _Tensor.copy: #dragon.core.tensor.Tensor.copy
.. _Tensor.reshape: #dragon.core.tensor.Tensor.reshape .. _Tensor.reshape: #dragon.core.tensor.Tensor.reshape
.. _Tensor.dimshuffle: #dragon.core.tensor.Tensor.dimshuffle .. _Tensor.dimshuffle: #dragon.core.tensor.Tensor.dimshuffle
.. _Tensor.eval: #dragon.core.tensor.Tensor.eval
.. _Tensor.CreateOperator: #dragon.core.tensor.Tensor.CreateOperator .. _Tensor.CreateOperator: #dragon.core.tensor.Tensor.CreateOperator
.. _Tensor.Fill: #dragon.core.tensor.Tensor.Fill .. _Tensor.Fill: #dragon.core.tensor.Tensor.Fill
.. _Tensor.PrintExpressions: #dragon.core.tensor.Tensor.PrintExpressions .. _Tensor.PrintExpressions: #dragon.core.tensor.Tensor.PrintExpressions
......
...@@ -6,34 +6,48 @@ ...@@ -6,34 +6,48 @@
from . import * from . import *
def _wrap_input_shape(arguments, shape):
if isinstance(shape, Tensor):
arguments['extra_inputs'] = shape
arguments['shape'] = shape.name
elif isinstance(shape, (list, tuple)):
arguments['extra_inputs'] = [Tensor.Convert(dim, dtype='int32') for dim in shape]
arguments['dims'] = [dim.name for dim in arguments['extra_inputs']]
arguments['shape'] = None
else:
raise TypeError('Unsupported type of shape: {}'.format(type(shape)))
return arguments
def _wrap_output_shape(output, shape):
if not isinstance(shape, Tensor):
if any(isinstance(dim, Tensor) for dim in shape): return output
output.shape = [dim for dim in shape]
return output
def Fill(shape, value=0, **kwargs): def Fill(shape, value=0, **kwargs):
"""Return a Tensor with specific value filled. """Return a Tensor with specific value filled.
Parameters Parameters
---------- ----------
shape : list, tuple or Tensor shape : list, tuple or Tensor
The shape of the new tensor. The output shape.
value : basic numerical type value : basic numerical type
The value of the new tensor. The value to fill.
Returns Returns
------- -------
Tensor Tensor
The value-filled Tensor. The constant-filled tensor.
""" """
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['value'] = float(value) arguments['value'] = float(value)
if not isinstance(shape, Tensor): arguments = _wrap_input_shape(arguments, shape)
arguments['static_shape'] = shape
else:
arguments['dynamic_shape'] = shape.name
arguments['extra_inputs'] = shape
del arguments['shape']
output = Tensor.CreateOperator([], nout=1, op_type='Fill', **arguments) output = Tensor.CreateOperator([], nout=1, op_type='Fill', **arguments)
output.shape = arguments['static_shape'] if 'static_shape' in arguments else None return _wrap_output_shape(output, shape)
return output
def RandomUniform(shape, low=-1.0, high=1.0, **kwargs): def RandomUniform(shape, low=-1.0, high=1.0, **kwargs):
...@@ -57,16 +71,9 @@ def RandomUniform(shape, low=-1.0, high=1.0, **kwargs): ...@@ -57,16 +71,9 @@ def RandomUniform(shape, low=-1.0, high=1.0, **kwargs):
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['low'] = float(low) arguments['low'] = float(low)
arguments['high'] = float(high) arguments['high'] = float(high)
if not isinstance(shape, Tensor): arguments = _wrap_input_shape(arguments, shape)
arguments['static_shape'] = shape
else:
arguments['dynamic_shape'] = shape.name
arguments['extra_inputs'] = shape
del arguments['shape']
output = Tensor.CreateOperator([], nout=1, op_type='RandomUniform', **arguments) output = Tensor.CreateOperator([], nout=1, op_type='RandomUniform', **arguments)
output.shape = arguments['static_shape'] if 'static_shape' in arguments else None return _wrap_output_shape(output, shape)
return output
def RandomNormal(shape, mean=0.0, std=1.0, **kwargs): def RandomNormal(shape, mean=0.0, std=1.0, **kwargs):
...@@ -90,16 +97,9 @@ def RandomNormal(shape, mean=0.0, std=1.0, **kwargs): ...@@ -90,16 +97,9 @@ def RandomNormal(shape, mean=0.0, std=1.0, **kwargs):
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['mean'] = float(mean) arguments['mean'] = float(mean)
arguments['std'] = float(std) arguments['std'] = float(std)
if not isinstance(shape, Tensor): arguments = _wrap_input_shape(arguments, shape)
arguments['static_shape'] = shape
else:
arguments['dynamic_shape'] = shape.name
arguments['extra_inputs'] = shape
del arguments['shape']
output = Tensor.CreateOperator([], nout=1, op_type='RandomNormal', **arguments) output = Tensor.CreateOperator([], nout=1, op_type='RandomNormal', **arguments)
output.shape = arguments['static_shape'] if 'static_shape' in arguments else None return _wrap_output_shape(output, shape)
return output
def TruncatedNormal(shape, mean=0.0, std=1.0, **kwargs): def TruncatedNormal(shape, mean=0.0, std=1.0, **kwargs):
...@@ -127,16 +127,9 @@ def TruncatedNormal(shape, mean=0.0, std=1.0, **kwargs): ...@@ -127,16 +127,9 @@ def TruncatedNormal(shape, mean=0.0, std=1.0, **kwargs):
arguments['std'] = float(std) arguments['std'] = float(std)
arguments['low'] = float(mean - 2.0 * std) arguments['low'] = float(mean - 2.0 * std)
arguments['high'] = float(mean + 2.0 * std) arguments['high'] = float(mean + 2.0 * std)
if not isinstance(shape, Tensor): arguments = _wrap_input_shape(arguments, shape)
arguments['static_shape'] = shape
else:
arguments['dynamic_shape'] = shape.name
arguments['extra_inputs'] = shape
del arguments['shape']
output = Tensor.CreateOperator([], nout=1, op_type='TruncatedNormal', **arguments) output = Tensor.CreateOperator([], nout=1, op_type='TruncatedNormal', **arguments)
output.shape = arguments['static_shape'] if 'static_shape' in arguments else None return _wrap_output_shape(output, shape)
return output
def GlorotUniform(shape, scale=3.0, mode='FAN_IN', **kwargs): def GlorotUniform(shape, scale=3.0, mode='FAN_IN', **kwargs):
...@@ -162,16 +155,9 @@ def GlorotUniform(shape, scale=3.0, mode='FAN_IN', **kwargs): ...@@ -162,16 +155,9 @@ def GlorotUniform(shape, scale=3.0, mode='FAN_IN', **kwargs):
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['scale'] = float(scale) arguments['scale'] = float(scale)
arguments['mode'] = mode.lower() arguments['mode'] = mode.lower()
if not isinstance(shape, Tensor): arguments = _wrap_input_shape(arguments, shape)
arguments['static_shape'] = shape
else:
arguments['dynamic_shape'] = shape.name
arguments['extra_inputs'] = shape
del arguments['shape']
output = Tensor.CreateOperator([], nout=1, op_type='GlorotUniform', **arguments) output = Tensor.CreateOperator([], nout=1, op_type='GlorotUniform', **arguments)
output.shape = arguments['static_shape'] if 'static_shape' in arguments else None return _wrap_output_shape(output, shape)
return output
def GlorotNormal(shape, scale=2.0, mode='FAN_IN', **kwargs): def GlorotNormal(shape, scale=2.0, mode='FAN_IN', **kwargs):
...@@ -197,13 +183,6 @@ def GlorotNormal(shape, scale=2.0, mode='FAN_IN', **kwargs): ...@@ -197,13 +183,6 @@ def GlorotNormal(shape, scale=2.0, mode='FAN_IN', **kwargs):
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['scale'] = float(scale) arguments['scale'] = float(scale)
arguments['mode'] = mode.lower() arguments['mode'] = mode.lower()
if not isinstance(shape, Tensor): arguments = _wrap_input_shape(arguments, shape)
arguments['static_shape'] = shape
else:
arguments['dynamic_shape'] = shape.name
arguments['extra_inputs'] = shape
del arguments['shape']
output = Tensor.CreateOperator([], nout=1, op_type='GlorotNormal', **arguments) output = Tensor.CreateOperator([], nout=1, op_type='GlorotNormal', **arguments)
output.shape = arguments['static_shape'] if 'static_shape' in arguments else None return _wrap_output_shape(output, shape)
return output \ No newline at end of file
\ No newline at end of file
...@@ -470,7 +470,7 @@ def Tile(inputs, multiples, **kwargs): ...@@ -470,7 +470,7 @@ def Tile(inputs, multiples, **kwargs):
---------- ----------
input : Tensor input : Tensor
The input tensor. The input tensor.
multiples : list of int multiples : list
The multiple of each axis. The multiple of each axis.
Returns Returns
...@@ -481,15 +481,21 @@ def Tile(inputs, multiples, **kwargs): ...@@ -481,15 +481,21 @@ def Tile(inputs, multiples, **kwargs):
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['extra_inputs'] = [Tensor.Convert(multiple, dtype='int32') for multiple in multiples]
arguments['multiples'] = [multiple.name for multiple in arguments['extra_inputs']]
output = Tensor.CreateOperator(nout=1, op_type='Tile', **arguments) output = Tensor.CreateOperator(nout=1, op_type='Tile', **arguments)
if inputs.shape is not None: if inputs.shape is not None:
if len(inputs.shape) != len(multiples): if len(inputs.shape) != len(multiples):
raise ValueError('input ndim is {}, but multiples provide {}'. \ raise ValueError('The num of dimensions of input is {}, but provided {}.'
format(len(inputs.shape), len(multiples))) .format(len(inputs.shape), len(multiples)))
output.shape = inputs.shape[:] output.shape = inputs.shape[:]
for i, multiple in enumerate(multiples): for i, multiple in enumerate(multiples):
if output.shape[i] is None or \
isinstance(output.shape[i], Tensor):
output.shape[i] = None
else:
output.shape[i] *= multiple output.shape[i] *= multiple
return output return output
...@@ -755,7 +761,7 @@ def Arange(start, stop=None, step=1, dtype='FLOAT32', **kwargs): ...@@ -755,7 +761,7 @@ def Arange(start, stop=None, step=1, dtype='FLOAT32', **kwargs):
step : int or Tensor step : int or Tensor
The interval between two elements. The interval between two elements.
dtype : str dtype : str
The data type. ``FLOAT32`` or ``INT32``. The data type. ``float32`` or ``int32``.
Returns Returns
------- -------
...@@ -764,30 +770,26 @@ def Arange(start, stop=None, step=1, dtype='FLOAT32', **kwargs): ...@@ -764,30 +770,26 @@ def Arange(start, stop=None, step=1, dtype='FLOAT32', **kwargs):
""" """
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['extra_inputs'] = [] arguments['extra_inputs'] = [Tensor.Convert(start, dtype='int32'),
if not isinstance(start, Tensor): arguments['static_start'] = int(start) Tensor.Convert(step, dtype='int32')]
else: arguments['start'] = arguments['extra_inputs'][0].name
arguments['dynamic_start'] = start.name arguments['step'] = arguments['extra_inputs'][1].name
arguments['extra_inputs'].append(start)
if stop is not None: if stop is not None:
if not isinstance(stop, Tensor): arguments['static_stop'] = int(stop) arguments['extra_inputs'].append(Tensor.Convert(stop, dtype='int32'))
else: arguments['stop'] = arguments['extra_inputs'][-1].name
arguments['dynamic_stop'] = stop.name arguments['dtype'] = arguments['dtype'].upper()
arguments['extra_inputs'].append(stop)
del arguments['stop']
if not isinstance(step, Tensor): arguments['static_step'] = int(step)
else:
arguments['dynamic_step'] = step.name
arguments['extra_inputs'].append(step)
del arguments['start']; del arguments['step']
output = Tensor.CreateOperator([], nout=1, op_type='Arange', **arguments) output = Tensor.CreateOperator([], nout=1, op_type='Arange', **arguments)
if 'static_start' in arguments and \ if not isinstance(start, Tensor) and \
'static_step' in arguments: not isinstance(step, Tensor):
if 'dynamic_stop' not in arguments: if stop is not None:
if stop is None: stop = start; start = 0 if isinstance(stop, Tensor):
count = (stop - start - 1) / step + 1 return output
output.shape = [np.long(count)] else:
stop = start
start = 0
count = int((stop - start - 1) / step) + 1
output.shape = [count]
return output return output
\ No newline at end of file
...@@ -88,6 +88,7 @@ def Conv2d(inputs, num_output, kernel_size, ...@@ -88,6 +88,7 @@ def Conv2d(inputs, num_output, kernel_size,
spatial_axis = 2 if data_format == 'NCHW' else 1 spatial_axis = 2 if data_format == 'NCHW' else 1
output.shape[channel_axis] = num_output output.shape[channel_axis] = num_output
for i in xrange(2): for i in xrange(2):
input_size = output.shape[i + spatial_axis]
k = arguments['kernel_size'][i] if i < len(arguments['kernel_size']) \ k = arguments['kernel_size'][i] if i < len(arguments['kernel_size']) \
else arguments['kernel_size'][-1] else arguments['kernel_size'][-1]
s = arguments['stride'][i] if i < len(arguments['stride']) \ s = arguments['stride'][i] if i < len(arguments['stride']) \
...@@ -99,10 +100,9 @@ def Conv2d(inputs, num_output, kernel_size, ...@@ -99,10 +100,9 @@ def Conv2d(inputs, num_output, kernel_size,
dk = d * (k - 1) + 1 dk = d * (k - 1) + 1
dp = 2 * p dp = 2 * p
if padding == 'SAME': if padding == 'SAME':
input_size = output.shape[i + spatial_axis] output.shape[i + spatial_axis] = int((input_size + s - 1) / s)
output_size = (input_size + s - 1) / float(s) else:
dp = int(max(0, (output_size - 1) * s + k - input_size)) output.shape[i + spatial_axis] = int((input_size + dp - dk) / s) + 1
output.shape[i + spatial_axis] = int(output.shape[i + spatial_axis] + dp - dk / s) + 1
return output return output
...@@ -173,15 +173,8 @@ def Conv2dTranspose(inputs, num_output, kernel_size, ...@@ -173,15 +173,8 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
if output_shape is not None: if output_shape is not None:
if not isinstance(output_shape, list): if not isinstance(output_shape, list):
raise TypeError('The output shape should be a list.') raise TypeError('The output shape should be a list.')
if isinstance(output_shape[0], Tensor): arguments['extra_inputs'] = [Tensor.Convert(dim, dtype='int32') for dim in output_shape]
arguments['dynamic_dsize'] = [] arguments['output_shape'] = [dim.name for dim in arguments['extra_inputs']]
arguments['extra_inputs'] = list(output_shape)
for dim in output_shape:
arguments['dynamic_dsize'].append(dim)
else:
arguments['static_dsize'] = []
for dim in output_shape:
arguments['static_dsize'].append(int(dim))
if not isinstance(arguments['kernel_size'], list): if not isinstance(arguments['kernel_size'], list):
arguments['kernel_size'] = [arguments['kernel_size']] arguments['kernel_size'] = [arguments['kernel_size']]
...@@ -216,9 +209,10 @@ def Conv2dTranspose(inputs, num_output, kernel_size, ...@@ -216,9 +209,10 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
else: else:
if output_shape is None: if output_shape is None:
raise ValueError('The output shape must be specified if using SAME padding algorithm.') raise ValueError('The output shape must be specified if using SAME padding algorithm.')
if 'dynamic_dsize' in arguments: if isinstance(output_shape[i + spatial_axis], Tensor):
output.shape = None output.shape = None
return output return output
else:
output.shape[i + spatial_axis] = output_shape[i + spatial_axis] output.shape[i + spatial_axis] = output_shape[i + spatial_axis]
return output return output
...@@ -433,14 +427,11 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): ...@@ -433,14 +427,11 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: {}'.format(data_format)) raise ValueError('Unsupported data format: {}'.format(data_format))
if arguments['dsize'] is not None: if dsize is not None:
if isinstance(arguments['dsize'][0], Tensor): if len(dsize) != 2:
arguments['dynamic_dsize'] = [arguments['dsize'][0].name, raise ValueError('The dsize should be a list with 2 elements.')
arguments['dsize'][1].name] arguments['extra_inputs'] = [Tensor.Convert(size, dtype='int32') for size in dsize]
arguments['extra_inputs'] = list(arguments['dsize']) arguments['dsize'] = [size.name for size in arguments['extra_inputs']]
else:
arguments['static_size'] = arguments['dsize']
del arguments['dsize']
if dsize is None and (fy == -1.0 or fx == -1.0): if dsize is None and (fy == -1.0 or fx == -1.0):
raise RuntimeError('The dsize or fy/fx should be specified either.') raise RuntimeError('The dsize or fy/fx should be specified either.')
...@@ -450,12 +441,18 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs): ...@@ -450,12 +441,18 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
if inputs.shape is not None: if inputs.shape is not None:
if len(inputs.shape) != 4: if len(inputs.shape) != 4:
raise ValueError('The inputs should be a 4d Tensor.') raise ValueError('The inputs should be a 4d Tensor.')
if 'dynamic_dsize' not in arguments: possible_to_infer_shape = True
if dsize is not None:
for size in dsize:
if isinstance(size, Tensor):
possible_to_infer_shape = False
if possible_to_infer_shape:
output.shape = inputs.shape[:] output.shape = inputs.shape[:]
spatial_axis = 2 if data_format == 'NCHW' else 1 spatial_axis = 2 if data_format == 'NCHW' else 1
for i in xrange(2): for i in xrange(2):
output_dim = output.shape[spatial_axis + i] output_dim = output.shape[spatial_axis + i]
if 'static_size' in arguments: if dsize is not None:
output_dim = dsize[i] output_dim = dsize[i]
else: else:
output_dim = int(float(output_dim) * ([fy, fx])[i]) output_dim = int(float(output_dim) * ([fy, fx])[i])
...@@ -494,14 +491,11 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs ...@@ -494,14 +491,11 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
if data_format not in ('NCHW', 'NHWC'): if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: {}'.format(data_format)) raise ValueError('Unsupported data format: {}'.format(data_format))
if arguments['dsize'] is not None: if dsize is not None:
if isinstance(arguments['dsize'][0], Tensor): if len(dsize) != 2:
arguments['dynamic_dsize'] = [arguments['dsize'][0].name, raise ValueError('The dsize should be a list with 2 elements.')
arguments['dsize'][1].name] arguments['extra_inputs'] = [Tensor.Convert(size, dtype='int32') for size in dsize]
arguments['extra_inputs'] = list(arguments['dsize']) arguments['dsize'] = [size.name for size in arguments['extra_inputs']]
else:
arguments['static_size'] = arguments['dsize']
del arguments['dsize']
if dsize is None and (fy == -1.0 or fx == -1.0): if dsize is None and (fy == -1.0 or fx == -1.0):
raise RuntimeError('The dsize or fy/fx should be specified either.') raise RuntimeError('The dsize or fy/fx should be specified either.')
...@@ -511,12 +505,18 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs ...@@ -511,12 +505,18 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs
if inputs.shape is not None: if inputs.shape is not None:
if len(inputs.shape) != 4: if len(inputs.shape) != 4:
raise ValueError('The inputs should be a 4d Tensor.') raise ValueError('The inputs should be a 4d Tensor.')
if 'dynamic_dsize' not in arguments: possible_to_infer_shape = True
if dsize is not None:
for size in dsize:
if isinstance(size, Tensor):
possible_to_infer_shape = False
if possible_to_infer_shape:
output.shape = inputs.shape[:] output.shape = inputs.shape[:]
spatial_axis = 2 if data_format == 'NCHW' else 1 spatial_axis = 2 if data_format == 'NCHW' else 1
for i in xrange(2): for i in xrange(2):
output_dim = output.shape[spatial_axis + i] output_dim = output.shape[spatial_axis + i]
if 'static_size' in arguments: if dsize is not None:
output_dim = dsize[i] output_dim = dsize[i]
else: else:
output_dim = int(float(output_dim) * ([fy, fx])[i]) output_dim = int(float(output_dim) * ([fy, fx])[i])
......
...@@ -36,7 +36,7 @@ find_packages('dragon') ...@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules() find_modules()
setup(name = 'dragon', setup(name = 'dragon',
version='0.2', version='0.2.1.1',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework', description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/neopenx/Dragon', url='https://github.com/neopenx/Dragon',
author='Ting Pan', author='Ting Pan',
......
...@@ -13,16 +13,22 @@ void InitializeOp<Context>::RunWithType() { ...@@ -13,16 +13,22 @@ void InitializeOp<Context>::RunWithType() {
template <class Context> template <class Context>
void InitializeOp<Context>::RunOnDevice() { void InitializeOp<Context>::RunOnDevice() {
vector<TIndex> dims; vector<TIndex> dims;
if (dynamic_shape.empty()) { if (shape_desc.empty()) {
for (auto& dim : static_shape) dims.push_back(dim); // determine the shape from dimensions
for (auto& dim_desc : dims_desc) {
Tensor* dim = ws()->GetTensor(dim_desc);
CHECK_EQ(dim->count(), 1) << "\nThe dimension should be a scalar.";
CHECK(dim->IsType<int>()) << "\nThe type of dimension should be int32.";
dims.push_back(dim->template data<int, CPUContext>()[0]);
}
} else { } else {
auto* shape_data = ws()->GetTensor(dynamic_shape) // determine the shape from given shape
->template data<float, CPUContext>(); Tensor* shape = ws()->GetTensor(shape_desc);
TIndex ndim = ws()->GetTensor(dynamic_shape)->count(); CHECK(shape->IsType<int>()) << "\nThe type of shape should be int32.";
for (int i = 0; i < ndim; i++) dims.push_back(shape_data[i]); auto* shape_data = shape->template data<int, CPUContext>();
for (int i = 0; i < shape->count(); i++) dims.push_back(shape_data[i]);
} }
output(0)->Reshape(dims); output(0)->Reshape(dims);
RunWithType<float>(); RunWithType<float>();
} }
......
...@@ -6,43 +6,24 @@ namespace dragon { ...@@ -6,43 +6,24 @@ namespace dragon {
template <class Context> template <class Context>
void ArangeOp<Context>::Reshape() { void ArangeOp<Context>::Reshape() {
if (!dynamic_start_.empty()) { // parse start & step & stop
dynamic_start = ws()->GetTensor(dynamic_start_); Tensor* t = ws()->GetTensor(start_desc);
CHECK_EQ(dynamic_start->count(), 1) CHECK_EQ(t->count(), 1) << "\nThe start should be a scalar";
<< "The start should be a scalar"; CHECK(t->IsType<int>()) << "\nThe type of start should be int32.";
if (dynamic_start->IsType<int>()) { start = t->template data<int, CPUContext>()[0];
start = dynamic_start->template data<int, CPUContext>()[0];
} else if (dynamic_start->IsType<float>()) { t = ws()->GetTensor(step_desc);
start = dynamic_start->template data<float, CPUContext>()[0]; CHECK_EQ(t->count(), 1) << "\nThe step should be a scalar";
} else { CHECK(t->IsType<int>()) << "\nThe type of step should be int32.";
LOG(FATAL) << "Unsupported types of start."; step = t->template data<int, CPUContext>()[0];
}
} if (!stop_desc.empty()) {
if (!dynamic_stop_.empty()) { t = ws()->GetTensor(stop_desc);
dynamic_stop = ws()->GetTensor(dynamic_stop_); CHECK_EQ(t->count(), 1) << "\nThe stop should be a scalar";
CHECK_EQ(dynamic_stop->count(), 1) CHECK(t->IsType<int>()) << "\nThe type of stop should be int32.";
<< "The stop should be a scalar"; stop = t->template data<int, CPUContext>()[0];
if (dynamic_stop->IsType<int>()) { } else { stop = start; start = 0; }
stop = dynamic_stop->template data<int, CPUContext>()[0];
} else if (dynamic_stop->IsType<float>()) {
stop = dynamic_stop->template data<float, CPUContext>()[0];
} else {
LOG(FATAL) << "Unsupported types of stop.";
}
}
if (!dynamic_step_.empty()) {
dynamic_step = ws()->GetTensor(dynamic_step_);
CHECK_EQ(dynamic_step->count(), 1)
<< "The step should be a scalar";
if (dynamic_step->IsType<int>()) {
step = dynamic_step->template data<int, CPUContext>()[0];
} else if (dynamic_step->IsType<float>()) {
step = dynamic_step->template data<float, CPUContext>()[0];
} else {
LOG(FATAL) << "Unsupported types of step.";
}
}
if (stop == -1) { stop = start; start = 0; }
count = (stop - start - 1) / step + 1; count = (stop - start - 1) / step + 1;
output(0)->Reshape(vector<TIndex>(1, count)); output(0)->Reshape(vector<TIndex>(1, count));
} }
......
...@@ -8,12 +8,11 @@ namespace dragon { ...@@ -8,12 +8,11 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void AtOp<Context>::RunWithType() { void AtOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* indices = input(1).template mutable_data<T, Context>(); auto* indices = input(1).template mutable_data<int, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::CanonicalAxis<T, Context>(input(1).count(), x_slice_dim, indices); kernel::CanonicalAxis<int, Context>(input(1).count(), x_slice_dim, indices);
kernel::At<T, Context>(output(0)->count(), outer_dim, inner_dim, kernel::At<T, Context>(output(0)->count(), outer_dim, inner_dim,
x_slice_dim, x_slice_dim, y_slice_dim,
y_slice_dim,
indices, indices,
Xdata, Xdata,
Ydata, Ydata,
...@@ -30,7 +29,9 @@ void AtOp<Context>::RunOnDevice() { ...@@ -30,7 +29,9 @@ void AtOp<Context>::RunOnDevice() {
inner_dim = input(0).count(axis + 1); inner_dim = input(0).count(axis + 1);
output(0)->Reshape(output_dims); output(0)->Reshape(output_dims);
CHECK(input(1).template IsType<int>()) << "\nThe type of indices should be int32.";
if (input(0).template IsType<float>()) RunWithType<float>(); if (input(0).template IsType<float>()) RunWithType<float>();
else if (input(0).template IsType<int>()) RunWithType<int>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
...@@ -42,12 +43,15 @@ OPERATOR_SCHEMA(At).NumInputs(2).NumOutputs(1); ...@@ -42,12 +43,15 @@ OPERATOR_SCHEMA(At).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T> template <class Context> template <typename T>
void AtGradientOp<Context>::RunWithType() { void AtGradientOp<Context>::RunWithType() {
auto* indices = input(1).template data<T, Context>(); auto* indices = input(1).template data<int, Context>();
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
if (!acc_grad) math::Set<T, Context>(output(0)->count(), 0, dXdata); if (!acc_grad) math::Set<T, Context>(output(0)->count(), 0, dXdata);
kernel::AtGrad<T, Context>(input(-1).count(), outer_dim, inner_dim, kernel::AtGrad<T, Context>(input(-1).count(), outer_dim, inner_dim,
x_slice_dim, y_slice_dim, indices, dYdata, dXdata, &ctx()); x_slice_dim, y_slice_dim,
indices,
dYdata,
dXdata);
} }
template <class Context> template <class Context>
...@@ -58,7 +62,9 @@ void AtGradientOp<Context>::RunOnDevice() { ...@@ -58,7 +62,9 @@ void AtGradientOp<Context>::RunOnDevice() {
inner_dim = input(0).count(axis + 1); inner_dim = input(0).count(axis + 1);
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
CHECK(input(1).template IsType<int>()) << "\nThe type of indices should be int32.";
if (input(0).template IsType<float>()) RunWithType<float>(); if (input(0).template IsType<float>()) RunWithType<float>();
else if (input(0).template IsType<int>()) RunWithType<int>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
......
...@@ -7,12 +7,12 @@ namespace dragon { ...@@ -7,12 +7,12 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void RandomPickOp<Context>::RunWithType() { void RandomPickOp<Context>::RunWithType() {
auto* indices = pick_indices->template mutable_data<T, CPUContext>(); auto* indices = pick_indices->template mutable_data<int, CPUContext>();
for (int i = 0; i < pick_indices->count(); i++) for (int i = 0; i < pick_indices->count(); i++)
indices[i] = T((*rand_generator())() % x_slice_dim); indices[i] = int((*rand_generator())() % x_slice_dim);
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
indices = pick_indices->template mutable_data<T, Context>(); indices = pick_indices->template mutable_data<int, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::At<T, Context>(output(0)->count(), outer_dim, inner_dim, kernel::At<T, Context>(output(0)->count(), outer_dim, inner_dim,
x_slice_dim, x_slice_dim,
...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(RandomPick).NumInputs(1).NumOutputs(2); ...@@ -53,7 +53,7 @@ OPERATOR_SCHEMA(RandomPick).NumInputs(1).NumOutputs(2);
template <class Context> template <typename T> template <class Context> template <typename T>
void RandomPickGradientOp<Context>::RunWithType() { void RandomPickGradientOp<Context>::RunWithType() {
auto* indices = pick_indices->template data<T, Context>(); auto* indices = pick_indices->template data<int, Context>();
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(output(0)->count(), 0, dXdata); math::Set<T, Context>(output(0)->count(), 0, dXdata);
...@@ -62,8 +62,7 @@ void RandomPickGradientOp<Context>::RunWithType() { ...@@ -62,8 +62,7 @@ void RandomPickGradientOp<Context>::RunWithType() {
y_slice_dim, y_slice_dim,
indices, indices,
dYdata, dYdata,
dXdata, dXdata);
&ctx());
} }
template <class Context> template <class Context>
......
...@@ -8,7 +8,7 @@ void ShapeOp<Context>::RunOnDevice() { ...@@ -8,7 +8,7 @@ void ShapeOp<Context>::RunOnDevice() {
output(0)->Reshape(vector<TIndex>(1, input(0).ndim())); output(0)->Reshape(vector<TIndex>(1, input(0).ndim()));
// forward // forward
auto* Ydata = output(0)->template mutable_data<float, CPUContext>(); auto* Ydata = output(0)->template mutable_data<int, CPUContext>();
for (int i = 0; i < input(0).ndim(); i++) Ydata[i] = input(0).dim(i); for (int i = 0; i < input(0).ndim(); i++) Ydata[i] = input(0).dim(i);
} }
......
...@@ -25,7 +25,16 @@ void TileOp<Context>::TileRunWithType() { ...@@ -25,7 +25,16 @@ void TileOp<Context>::TileRunWithType() {
template <class Context> template <class Context>
void TileOp<Context>::RunOnDevice() { void TileOp<Context>::RunOnDevice() {
CHECK_EQ(multiples.size(), input(0).ndim()); // parse tasks from desc
CHECK_EQ(multiples_desc.size(), input(0).ndim())
<< "\nThe num of dimensions of input is " << input(0).ndim()
<< ", but provided " << multiples_desc.size() << " multiples.";
vector< pair<int, int> > process_axes;
for (int i = 0; i < multiples_desc.size(); i++) {
int mult = ws()->GetTensor(multiples_desc[i])->template data<int, CPUContext>()[0];
if (mult > 1) process_axes.push_back({ mult, i });
}
std::sort(process_axes.begin(), process_axes.end());
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
...@@ -81,7 +90,17 @@ void TileGradientOp<Context>::TileRunWithType() { ...@@ -81,7 +90,17 @@ void TileGradientOp<Context>::TileRunWithType() {
template <class Context> template <class Context>
void TileGradientOp<Context>::RunOnDevice() { void TileGradientOp<Context>::RunOnDevice() {
CHECK_EQ(multiples.size(), input(-1).ndim()); // parse tasks from desc
CHECK_EQ(multiples_desc.size(), input(-1).ndim())
<< "\nThe num of dimensions of input is " << input(-1).ndim()
<< ", but provided " << multiples_desc.size() << " multiples.";
vector< pair<int, int> > process_axes;
for (int i = 0; i < multiples_desc.size(); i++) {
int mult = ws()->GetTensor(multiples_desc[i])->template data<int, CPUContext>()[0];
if (mult > 1) process_axes.push_back({ mult, i });
}
std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end());
// do nothing // do nothing
if (process_axes.size() == 0) { if (process_axes.size() == 0) {
......
...@@ -173,52 +173,6 @@ void CuDNNBatchNormGradientOp<Context>::Setup() { ...@@ -173,52 +173,6 @@ void CuDNNBatchNormGradientOp<Context>::Setup() {
} }
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNBatchNormGradientOp<Context>::InferenceRunWithType() {
if (output(0)->name() != "ignore") {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
stddev = ws()->GetBuffer();
stddev->ReshapeLike(input(0));
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* hVar_data = input(2).template data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
// compute stddev
ctx().template Copy<T, Context, Context>(var->count(), tVar_data, hVar_data);
math::AddScalar<T, Context>(var->count(), this->eps, tVar_data);
math::Sqrt<T, Context>(var->count(), tVar_data, tVar_data);
// divide scale by stddev
math::Div<T, Context>(var->count(), Sdata, tVar_data, tVar_data);
// compute dE/dY \cot (scale / std(X))
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
1.0, NMul_data, tVar_data,
0.0, NC_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
1.0, NC_data, SMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1,
1.0, NSMul_data, tVar_data,
0.0, Std_data);
}
math::Mul<T, Context>(output(0)->count(), dYdata, Std_data, dXdata);
ws()->ReleaseBuffer(stddev);
}
}
template <class Context> template <typename T>
void CuDNNBatchNormGradientOp<Context>::TrainingRunWithType() { void CuDNNBatchNormGradientOp<Context>::TrainingRunWithType() {
// determine the bn desc // determine the bn desc
if (input(0).ndim() == 2) { if (input(0).ndim() == 2) {
...@@ -288,6 +242,52 @@ void CuDNNBatchNormGradientOp<Context>::TrainingRunWithType() { ...@@ -288,6 +242,52 @@ void CuDNNBatchNormGradientOp<Context>::TrainingRunWithType() {
} }
} }
template <class Context> template <typename T>
void CuDNNBatchNormGradientOp<Context>::InferenceRunWithType() {
if (output(0)->name() != "ignore") {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
stddev = ws()->GetBuffer();
stddev->ReshapeLike(input(0));
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* hVar_data = input(2).template data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
// compute stddev
ctx().template Copy<T, Context, Context>(var->count(), tVar_data, hVar_data);
math::AddScalar<T, Context>(var->count(), this->eps, tVar_data);
math::Sqrt<T, Context>(var->count(), tVar_data, tVar_data);
// divide scale by stddev
math::Div<T, Context>(var->count(), Sdata, tVar_data, tVar_data);
// compute dE/dY \cot (scale / std(X))
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
1.0, NMul_data, tVar_data,
0.0, NC_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
1.0, NC_data, SMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1,
1.0, NSMul_data, tVar_data,
0.0, Std_data);
}
math::Mul<T, Context>(output(0)->count(), dYdata, Std_data, dXdata);
ws()->ReleaseBuffer(stddev);
}
}
template <class Context> template <class Context>
void CuDNNBatchNormGradientOp<Context>::RunOnDevice() { void CuDNNBatchNormGradientOp<Context>::RunOnDevice() {
Setup(); Setup();
......
...@@ -34,23 +34,13 @@ void BilinearResizeOp<Context>::RunWithType() { ...@@ -34,23 +34,13 @@ void BilinearResizeOp<Context>::RunWithType() {
template <class Context> template <class Context>
void BilinearResizeOp<Context>::RunOnDevice() { void BilinearResizeOp<Context>::RunOnDevice() {
dims = input(0).dims(); dims = input(0).dims();
if (dynamic_dsize.size() > 0) { if (dsize_desc.size() > 0) {
CHECK_EQ(dynamic_dsize.size(), 2) CHECK_EQ(dsize_desc.size(), 2) << "\nThe dsize should be a scalar with 2 elements.";
<< "\nThe dsize should be a scalar with 2 elements.";
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
Tensor* t = ws()->GetTensor(dynamic_dsize[i]); Tensor* dsize = ws()->GetTensor(dsize_desc[i]);
if (t->IsType<int>()) { CHECK(dsize->IsType<int>()) << "\nThe type of dsize should be int32.";
dims[spatial_axis + i] = t->template data<int, CPUContext>()[0]; dims[spatial_axis + i] = dsize->template data<int, CPUContext>()[0];
} else if (t->IsType<float>()) {
dims[spatial_axis + i] = t->template data<float, CPUContext>()[0];
} else {
LOG(FATAL) << "Unsupported types of dsize.";
}
} }
} else if (static_dsize.size() > 0) {
CHECK_EQ(static_dsize.size(), 2)
<< "\nThe dsize should be a scalar with 2 elements.";
for (int i = 0; i < 2; i++) dims[spatial_axis + i] = static_dsize[i];
} else { } else {
CHECK(fy != -1.0 && fx != -1.0) CHECK(fy != -1.0 && fx != -1.0)
<< "\nThe fx and fy should be set."; << "\nThe fx and fy should be set.";
......
...@@ -29,15 +29,14 @@ void ConvOpBase<Context>::ComputeOutputShape() { ...@@ -29,15 +29,14 @@ void ConvOpBase<Context>::ComputeOutputShape() {
const TIndex output_dim = stride[i] * (input_dim - 1) + dilated_kernel - 2 * pad[i]; const TIndex output_dim = stride[i] * (input_dim - 1) + dilated_kernel - 2 * pad[i];
output_shape.push_back(output_dim); output_shape.push_back(output_dim);
} else { } else {
TIndex output_dim = -1; CHECK(output_dims_desc.size() > 0)
if (dynamic_dsize.size() > 0) { << "\nThe output shape must be specified if using SAME padding algorithm.";
NOT_IMPLEMENTED; CHECK_EQ((int)output_dims_desc.size(), num_spatial_axes + 2)
} else if (static_dsize.size() > 0) { << "\nThe len of output shape should be " << num_spatial_axes + 2
if ((int)static_dsize.size() != num_spatial_axes + 2) << ", but got " << output_dims_desc.size() << ".";
LOG(FATAL) << "The len of output shape should be " << num_spatial_axes + 2 Tensor* t = ws()->GetTensor(output_dims_desc[spatial_axis + i]);
<< ", but got " << static_dsize.size(); CHECK(t->IsType<int>()) << "\nThe type of output shape should be int32.";
output_dim = static_dsize[spatial_axis + i]; TIndex output_dim = t->template data<int, CPUContext>()[0];
} else LOG(FATAL) << "The output shape must be specified if using SAME padding algorithm.";
TIndex padding_needed = stride[i] * (input_dim - 1) + dilated_kernel - output_dim; TIndex padding_needed = stride[i] * (input_dim - 1) + dilated_kernel - output_dim;
CHECK_GE(padding_needed, 0) CHECK_GE(padding_needed, 0)
<< "\nThe output shape is incorrect." << "\nThe output shape is incorrect."
......
...@@ -34,23 +34,13 @@ void NNResizeOp<Context>::RunWithType() { ...@@ -34,23 +34,13 @@ void NNResizeOp<Context>::RunWithType() {
template <class Context> template <class Context>
void NNResizeOp<Context>::RunOnDevice() { void NNResizeOp<Context>::RunOnDevice() {
vector<TIndex> dims = input(0).dims(); vector<TIndex> dims = input(0).dims();
if (dynamic_dsize.size() > 0) { if (dsize_desc.size() > 0) {
CHECK_EQ(dynamic_dsize.size(), 2) CHECK_EQ(dsize_desc.size(), 2) << "\nThe dsize should be a scalar with 2 elements.";
<< "\nThe dsize should be a scalar with 2 elements.";
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
Tensor* t = ws()->GetTensor(dynamic_dsize[i]); Tensor* dsize = ws()->GetTensor(dsize_desc[i]);
if (t->IsType<int>()) { CHECK(dsize->IsType<int>()) << "\nThe type of dsize should be int32.";
dims[spatial_axis + i] = t->template data<int, CPUContext>()[0]; dims[spatial_axis + i] = dsize->template data<int, CPUContext>()[0];
} else if (t->IsType<float>()) {
dims[spatial_axis + i] = t->template data<float, CPUContext>()[0];
} else {
LOG(FATAL) << "Unsupported types of dsize.";
}
} }
} else if (static_dsize.size() > 0) {
CHECK_EQ(static_dsize.size(), 2)
<< "\nThe dsize should be a scalar with 2 elements.";
for (int i = 0; i < 2; i++) dims[spatial_axis + i] = static_dsize[i];
} else { } else {
CHECK(fy != -1.0 && fx != -1.0) CHECK(fy != -1.0 && fx != -1.0)
<< "\nThe fx and fy should be set."; << "\nThe fx and fy should be set.";
......
...@@ -150,6 +150,16 @@ template <> void Add<float, CPUContext>(const int n, ...@@ -150,6 +150,16 @@ template <> void Add<float, CPUContext>(const int n,
#endif // WITH_SSE #endif // WITH_SSE
} }
template <> void Add<int, CPUContext>(const int n,
const int* a,
const int* b,
int* y) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(n))
#endif
for (int i = 0; i < n; ++i) y[i] = a[i] + b[i];
}
template <> void Sub<float, CPUContext>(const int n, template <> void Sub<float, CPUContext>(const int n,
const float* a, const float* a,
const float* b, const float* b,
......
...@@ -904,22 +904,23 @@ template<> void Argmin<float, CPUContext>(const int count, ...@@ -904,22 +904,23 @@ template<> void Argmin<float, CPUContext>(const int count,
/******************** ndarray.at ********************/ /******************** ndarray.at ********************/
template <> void CanonicalAxis<float, CPUContext>(const int count, const int dim, float* y) { template <> void CanonicalAxis<int, CPUContext>(const int count, const int dim, int* y) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
for (int i = 0; i < count; ++i) if (y[i] < 0) y[i] += dim; for (int i = 0; i < count; ++i) if (y[i] < 0) y[i] += dim;
} }
template <> void At<float, CPUContext>(const int count, template <typename T>
void _At(const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const float* indices, const int* indices,
const float* x, const T* x,
float* y, T* y,
CPUContext* context) { CPUContext* ctx) {
TIndex x_offset, y_offset, x_idx_offset, y_idx_offset; TIndex x_offset, y_offset, x_idx_offset, y_idx_offset;
for (int i = 0; i < y_slice_dim; ++i) { for (int i = 0; i < y_slice_dim; ++i) {
y_idx_offset = i; y_idx_offset = i;
...@@ -927,22 +928,51 @@ template <> void At<float, CPUContext>(const int count, ...@@ -927,22 +928,51 @@ template <> void At<float, CPUContext>(const int count,
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim; x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim;
y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim; y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim;
context->Copy<float, CPUContext, CPUContext>(inner_dim, ctx->Copy<T, CPUContext, CPUContext>(inner_dim,
y + y_offset, y + y_offset,
x + x_offset); x + x_offset);
} }
} }
} }
template <> void AtGrad<float, CPUContext>(const int count, template <> void At<float, CPUContext>(const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const float* indices, const int* indices,
const float* dy, const float* x,
float* dx, float* y,
CPUContext* context) { CPUContext* ctx) {
_At<float>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, x, y, ctx);
}
template <> void At<int, CPUContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
const int y_slice_dim,
const int* indices,
const int* x,
int* y,
CPUContext* ctx) {
_At<int>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, x, y, ctx);
}
template <typename T>
void _AtGrad(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
const int y_slice_dim,
const int* indices,
const T* dy,
T* dx) {
TIndex x_offset, y_offset, x_idx_offset, y_idx_offset; TIndex x_offset, y_offset, x_idx_offset, y_idx_offset;
for (int i = 0; i < y_slice_dim; ++i) { for (int i = 0; i < y_slice_dim; ++i) {
y_idx_offset = i; y_idx_offset = i;
...@@ -950,7 +980,7 @@ template <> void AtGrad<float, CPUContext>(const int count, ...@@ -950,7 +980,7 @@ template <> void AtGrad<float, CPUContext>(const int count,
for (int n = 0; n < outer_dim; ++n) { for (int n = 0; n < outer_dim; ++n) {
x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim; x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim;
y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim; y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim;
math::Add<float, CPUContext>(inner_dim, math::Add<T, CPUContext>(inner_dim,
dy + y_offset, dy + y_offset,
dx + x_offset, dx + x_offset,
dx + x_offset); dx + x_offset);
...@@ -958,6 +988,32 @@ template <> void AtGrad<float, CPUContext>(const int count, ...@@ -958,6 +988,32 @@ template <> void AtGrad<float, CPUContext>(const int count,
} }
} }
template <> void AtGrad<float, CPUContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
const int y_slice_dim,
const int* indices,
const float* dy,
float* dx) {
_AtGrad<float>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, dy, dx);
}
template <> void AtGrad<int, CPUContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
const int y_slice_dim,
const int* indices,
const int* dy,
int* dx) {
_AtGrad<int>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, dy, dx);
}
/******************** ndarray.concat ********************/ /******************** ndarray.concat ********************/
template <> void Concat<float, CPUContext>(const int count, template <> void Concat<float, CPUContext>(const int count,
......
...@@ -1574,8 +1574,8 @@ __global__ void _CanonicalAxis(const int count, const int dim, T* y) { ...@@ -1574,8 +1574,8 @@ __global__ void _CanonicalAxis(const int count, const int dim, T* y) {
} }
} }
template <> void CanonicalAxis<float, CUDAContext>(const int count, const int dim, float* y) { template <> void CanonicalAxis<int, CUDAContext>(const int count, const int dim, int* y) {
_CanonicalAxis<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count, dim, y); _CanonicalAxis<int> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count, dim, y);
CUDA_POST_KERNEL_CHECK; CUDA_POST_KERNEL_CHECK;
} }
...@@ -1585,7 +1585,7 @@ __global__ void _At(const int count, ...@@ -1585,7 +1585,7 @@ __global__ void _At(const int count,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const T* indices, const int* indices,
const T* x, const T* x,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_KERNEL_LOOP(idx, count) {
...@@ -1604,18 +1604,30 @@ template <> void At<float, CUDAContext>(const int count, ...@@ -1604,18 +1604,30 @@ template <> void At<float, CUDAContext>(const int count,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const float* indices, const int* indices,
const float* x, const float* x,
float* y, float* y,
CUDAContext* context) { CUDAContext* context) {
_At<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count, _At<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
outer_dim, outer_dim, inner_dim,
inner_dim, x_slice_dim, y_slice_dim,
x_slice_dim, indices, x, y);
y_slice_dim, CUDA_POST_KERNEL_CHECK;
indices, }
x,
y); template <> void At<int, CUDAContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
const int y_slice_dim,
const int* indices,
const int* x,
int* y,
CUDAContext* context) {
_At<int> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, x, y);
CUDA_POST_KERNEL_CHECK; CUDA_POST_KERNEL_CHECK;
} }
...@@ -1625,7 +1637,7 @@ __global__ void _AtGrad(const int count, ...@@ -1625,7 +1637,7 @@ __global__ void _AtGrad(const int count,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const T* indices, const int* indices,
const T* dy, const T* dy,
T* dx) { T* dx) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_KERNEL_LOOP(idx, count) {
...@@ -1644,18 +1656,28 @@ template <> void AtGrad<float, CUDAContext>(const int count, ...@@ -1644,18 +1656,28 @@ template <> void AtGrad<float, CUDAContext>(const int count,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const float* indices, const int* indices,
const float* dy, const float* dy,
float* dx, float* dx) {
CUDAContext* context) {
_AtGrad<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count, _AtGrad<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
outer_dim, outer_dim, inner_dim,
inner_dim, x_slice_dim, y_slice_dim,
x_slice_dim, indices, dy, dx);
y_slice_dim, CUDA_POST_KERNEL_CHECK;
indices, }
dy,
dx); template <> void AtGrad<int, CUDAContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
const int y_slice_dim,
const int* indices,
const int* dy,
int* dx) {
_AtGrad<int> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, dy, dx);
CUDA_POST_KERNEL_CHECK; CUDA_POST_KERNEL_CHECK;
} }
...@@ -3769,6 +3791,12 @@ __global__ void _ROIPooling(const int count, ...@@ -3769,6 +3791,12 @@ __global__ void _ROIPooling(const int count,
roi += n * 5; roi += n * 5;
int im_idx = roi[0]; int im_idx = roi[0];
if (im_idx < 0) {
y[idx] = 0;
mask[idx] = 0;
continue;
}
int x1 = round(roi[1] * spatial_scale); int x1 = round(roi[1] * spatial_scale);
int y1 = round(roi[2] * spatial_scale); int y1 = round(roi[2] * spatial_scale);
int x2 = round(roi[3] * spatial_scale); int x2 = round(roi[3] * spatial_scale);
...@@ -3802,8 +3830,8 @@ __global__ void _ROIPooling(const int count, ...@@ -3802,8 +3830,8 @@ __global__ void _ROIPooling(const int count,
max_val = x[x_idx]; max_val = x[x_idx];
max_idx = x_idx; max_idx = x_idx;
} }
} //end w }
} // end h }
y[idx] = max_val; y[idx] = max_val;
mask[idx] = max_idx; mask[idx] = max_idx;
...@@ -3857,7 +3885,6 @@ __global__ void _ROIPoolingGrad(const int count, ...@@ -3857,7 +3885,6 @@ __global__ void _ROIPoolingGrad(const int count,
const T* cur_roi = roi + n * 5; const T* cur_roi = roi + n * 5;
const int im_idx_spec = cur_roi[0]; const int im_idx_spec = cur_roi[0];
// ignore wrong im_batch_idx
if (im_idx != im_idx_spec) continue; if (im_idx != im_idx_spec) continue;
int x1 = round(cur_roi[1] * spatial_scale); int x1 = round(cur_roi[1] * spatial_scale);
...@@ -3895,9 +3922,9 @@ __global__ void _ROIPoolingGrad(const int count, ...@@ -3895,9 +3922,9 @@ __global__ void _ROIPoolingGrad(const int count,
if (mask_off[pool_idx] == (h * width + w)) { if (mask_off[pool_idx] == (h * width + w)) {
diff += dy_off[pool_idx]; diff += dy_off[pool_idx];
} }
} // end pw }
} // end ph }
} // end n }
dx[idx] = diff; dx[idx] = diff;
} }
} }
...@@ -3949,6 +3976,13 @@ __global__ void _ROIAlign(const int count, ...@@ -3949,6 +3976,13 @@ __global__ void _ROIAlign(const int count,
roi += n * 5; roi += n * 5;
int roi_batch_ind = roi[0]; int roi_batch_ind = roi[0];
if (roi_batch_ind < 0) {
y[idx] = 0;
mask_h[idx] = 0;
mask_w[idx] = 0;
continue;
}
T roi_start_w = (roi[1]) * spatial_scale; T roi_start_w = (roi[1]) * spatial_scale;
T roi_start_h = (roi[2]) * spatial_scale; T roi_start_h = (roi[2]) * spatial_scale;
T roi_end_w = (roi[3]) * spatial_scale; T roi_end_w = (roi[3]) * spatial_scale;
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!