Commit 418e0c0a by Ting PAN

Add AssignOp

1 parent 52402169
Showing with 986 additions and 365 deletions
......@@ -61,7 +61,8 @@ List Brief
`Tensor.__le__`_ x.__le__() <=> x <= y
`Tensor.__eq__`_ x.__eq__() <=> x == y
`Tensor.__repr__`_ Return the information (name/shape).
`Tensor.__getitem__`_ Return a Tensor with specific indices.
`Tensor.__getitem__`_ Return the value at the specific indices.
`Tensor.__setitem__`_ Set the value at the specific indices.
`Tensor.__call__`_ Return the expressions for displaying.
============================== =============================================================================
......@@ -120,6 +121,7 @@ API Reference
.. _Tensor.__eq__: #dragon.core.tensor.Tensor.__eq__
.. _Tensor.__repr__: #dragon.core.tensor.Tensor.__repr__
.. _Tensor.__getitem__: #dragon.core.tensor.Tensor.__getitem__
.. _Tensor.__setitem__: #dragon.core.tensor.Tensor.__setitem__
.. _Tensor.__call__: #dragon.core.tensor.Tensor.__call__
.. _Tensor.name: #dragon.core.tensor.Tensor.name
......
......@@ -165,7 +165,8 @@ Control Flow
=============== ======================================================================
List Brief
=============== ======================================================================
`Copy`_ Copy A to B.
`Copy`_ Copy the *value* to *ref*.
`Assign`_ Assign the *value* to *ref*.
`Equal`_ *Equal* Comparing between A and B.
`Less`_ *Less* Comparing between A and B.
`LessEqual`_ *LessEqual* Comparing between A and B.
......@@ -306,6 +307,7 @@ List Brief
.. _Multinomial: operators/array.html#dragon.operators.array.Multinomial
.. _Copy: operators/control_flow.html#dragon.operators.control_flow.Copy
.. _Assign: operators/control_flow.html#dragon.operators.control_flow.Assign
.. _Equal: operators/control_flow.html#dragon.operators.control_flow.Equal
.. _Less: operators/control_flow.html#dragon.operators.control_flow.Less
.. _LessEqual: operators/control_flow.html#dragon.operators.control_flow.LessEqual
......
......@@ -38,15 +38,13 @@ class GatherGradientOp final : public Operator<Context> {
public:
GatherGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)),
zero_grad(OperatorBase::Arg<bool>("zero_grad", true)) {}
axis(OperatorBase::Arg<int64_t>("axis", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
bool zero_grad;
int64_t axis, outer_dim, inner_dim, x_slice_dim, y_slice_dim;
};
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_CONTROL_FLOW_ASSIGN_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_ASSIGN_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class AssignOp final : public Operator<Context> {
public:
AssignOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
GET_ARGUMENTS_WITH_DESC(int64_t, starts);
GET_ARGUMENTS_WITH_DESC(int64_t, sizes);
}
USE_OPERATOR_FUNCTIONS;
void Setup();
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
vector<int64_t> st, ed, x_dimsV;
Tensor startsT, y_stridesT, x_dimsT, fake_x;
DECLARE_ARGUMENTS_WITH_DESC(int64_t, starts);
DECLARE_ARGUMENTS_WITH_DESC(int64_t, sizes);
};
DEFINE_ARGUMENTS_WITH_DESC(int64_t, AssignOp, starts);
DEFINE_ARGUMENTS_WITH_DESC(int64_t, AssignOp, sizes);
} // namespace dragon
#endif // DRAGON_OPERATORS_CONTROL_FLOW_ASSIGN_OP_H_
\ No newline at end of file
......@@ -91,17 +91,26 @@ void Square(
*/
template <typename T, class Context>
void Pow(
void Set(
const int n,
const float exp,
const T alpha,
T* y,
Context* ctx);
template <typename T, class Context>
void BroadcastSet(
const int rows,
const int cols,
const int type,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
void Set(
void Pow(
const int n,
const T alpha,
const float exp,
const T* x,
T* y,
Context* ctx);
......
......@@ -37,6 +37,26 @@ MATH_UTILS_DECL T Cube(const T x) {
return x * x * x;
}
template <typename T>
inline void ArgPartition(
const int count,
const int kth,
const bool descend,
const T* v,
std::vector<int64_t>& indices) {
indices.resize(count);
std::iota(indices.begin(), indices.end(), 0);
if (descend) {
std::nth_element(
indices.begin(), indices.begin() + kth, indices.end(),
[&v](int64_t i1, int64_t i2) { return v[i1] > v[i2]; });
} else {
std::nth_element(
indices.begin(), indices.begin() + kth, indices.end(),
[&v](int64_t i1, int64_t i2) { return v[i1] < v[i2]; });
}
}
} // namespace math
void IncreaseIndexInDims(
......
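Conceptually, the new math::ArgPartition helper is a partial argsort over indices built on std::nth_element. A hedged NumPy analogue, not part of this commit, just to illustrate the selection it performs:

import numpy as np

scores = np.array([0.3, 0.9, 0.1, 0.7])
kth = 2
order = np.argpartition(-scores, kth)  # descend=True: larger scores placed first
top_k = order[:kth]                    # indices of the 2 largest scores, unordered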
......@@ -34,6 +34,9 @@ void AddProtoMethods(pybind11::module& m) {
}).def("SerializeAs", [](
OperatorDef* self) {
return pybind11::bytes(self->SerializeAsString());
}).def("__repr__", [](
OperatorDef* self) {
return self->DebugString();
}).def("add_input", [](
OperatorDef* self,
const string& input) {
......
......@@ -384,7 +384,7 @@ class OperatorHelper(object):
###############################################
# #
# NDArray #
# Array #
# #
###############################################
......
......@@ -425,26 +425,13 @@ class Tensor(object):
return 'Tensor("{}", shape={}, dtype={})' \
.format(self.name, shape_str, self.dtype)
def __getitem__(self, item):
"""Return a Tensor with specific indices.
Parameters
----------
item : int, slice or Tensor
The indices.
Returns
-------
Tensor
The output tensor.
"""
def _process_indices(self, item):
if not isinstance(item, (slice, tuple)):
if not isinstance(item, int):
raise ValueError('The index should be an integer.')
item = (item,)
if not isinstance(item, tuple): item = tuple([item])
starts = []; sizes = []
starts, sizes = [], []
for ix, it in enumerate(item):
if isinstance(it, slice):
# Handle start
......@@ -457,20 +444,36 @@ class Tensor(object):
sizes.append(it.stop - starts[-1])
if sizes[-1] == 0:
raise ValueError(
'The cropping starts and ends of axis {} '
'The starts and ends of axis {} '
'can not be equal, got {}:{}.'
.format(ix, starts[-1], it.stop))
# Handle step
if it.step is not None:
raise NotImplementedError('Cropping with step has not been implemented yet. ')
raise NotImplementedError(
'Indexing with step has not been implemented yet. ')
elif isinstance(it, int):
starts.append(it)
sizes.append(0)
else:
raise TypeError('Unsupported type of indices: {}'.format(type(type(it))))
raise TypeError('Unsupported type of indices: {}'.format(type(it)))
return starts, sizes
output = self.CreateOperator('Crop', self, starts=starts, sizes=sizes)
def __getitem__(self, item):
"""Return the value at the specific indices.
Parameters
----------
item : int or slice
The indices.
Returns
-------
Tensor
The output tensor.
"""
starts, sizes = self._process_indices(item)
output = self.CreateOperator('Crop', self, starts=starts, sizes=sizes)
if self.shape is not None:
output_shape, squeeze_shape = self.shape[:], []
for ix in range(len(sizes)):
......@@ -479,9 +482,29 @@ class Tensor(object):
if dim != -1: squeeze_shape.append(dim)
if len(squeeze_shape) == 0: output.shape = []
else: output.shape = squeeze_shape[:]
return output
def __setitem__(self, key, value):
"""Set the value at the specific indices.
Parameters
----------
key : int or slice
The indices.
value : Tensor, number or sequence
The value.
Returns
-------
None
"""
starts, sizes = self._process_indices(key)
if not isinstance(value, Tensor):
value = self._from_constants(value)
return self.CreateOperator('Assign', [value],
existing_outputs=[self], starts=starts, sizes=sizes)
def _from_constants(self, value):
if not isinstance(value, np.ndarray):
try:
......
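As a hypothetical usage sketch of the two methods above (the tensor name, shape, and the Tensor(name, shape, dtype) constructor signature are assumptions), __getitem__ lowers to the Crop operator while __setitem__ lowers to the new Assign operator:

from dragon.core.tensor import Tensor

# Assumed constructor signature: Tensor(name, shape, dtype).
x = Tensor('x', shape=[4, 3], dtype='float32')
y = x[1:3, 0]   # Crop with starts=[1, 0], sizes=[2, 0]; the integer axis is squeezed
x[0:2] = 1.0    # Assign with starts=[0], sizes=[2]; 1.0 is wrapped via _from_constants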
......@@ -385,7 +385,8 @@ def ResetTensor(tensor):
def RunGraph(
graph_name, inputs=(), outputs=[],
stage=None, return_outputs=True):
stage=None, return_outputs=True,
):
"""Run the specific graph.
Parameters
......@@ -516,7 +517,8 @@ def ExportMetaGraph(graph_def):
def Snapshot(
tensors, filename,
prefix='', suffix='.bin',
format='default'):
format='default',
):
"""Snapshot tensors into a binary file.
Parameters
......
......@@ -108,7 +108,7 @@ def Elu(inputs, alpha=1.0, **kwargs):
def SElu(inputs, **kwargs):
"""Scaled Exponential Linear Unit function. `[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
**Type Constraints**: *float32*
**Type Constraints**: (*float16*, *float32*)
Parameters
----------
......
......@@ -484,8 +484,9 @@ def Accumulate(inputs, alpha=1., beta=1., **kwargs):
inputs : sequence of Tensor
The inputs, i.e., the *x*.
alpha : float, optional, default=1.
The alpha value.
The value of alpha.
beta : float, optional, default=1.
The value of beta.
Returns
-------
......
......@@ -17,7 +17,7 @@ from . import *
@OpSchema.Inputs(1)
def Gather(inputs, indices, axis=0, zero_grad=True, **kwargs):
def Gather(inputs, indices, axis=0, **kwargs):
"""Gather the input according to the indices along the given axis.
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
......@@ -30,8 +30,6 @@ def Gather(inputs, indices, axis=0, zero_grad=True, **kwargs):
The indices to form output tensor.
axis : int, optional
The start axis, can be negative.
zero_grad : bool, optional
Whether to accumulate the gradients.
Returns
-------
......@@ -49,10 +47,14 @@ def Gather(inputs, indices, axis=0, zero_grad=True, **kwargs):
@OpSchema.Inputs(1)
@ArgumentHelper.RepeatedDesc('starts')
@ArgumentHelper.RepeatedDesc('sizes')
def Crop(inputs, starts, sizes, start_axis=None, offsets=None, shape_like=None, **kwargs):
def Crop(
inputs, starts=None, sizes=None,
start_axis=None, offsets=None,
shape_like=None, **kwargs
):
"""Crop the input according to the given starts and sizes.
Set ``starts`` and ``sizes`` to *None* if using ``start_axis``, ``offsets`` and ``shape_like``.
The value of ``sizes`` could be set to *-1* (to end) or *0* (squeeze).
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
......@@ -60,10 +62,10 @@ def Crop(inputs, starts, sizes, start_axis=None, offsets=None, shape_like=None,
----------
inputs : Tensor
The input tensor.
starts : int, Tensor, sequence of (int, Tensor)
The starts.
sizes : int, Tensor, sequence of (int, Tensor)
The crop sizes.
starts : sequence of (int, Tensor), optional
The start pos of each dimension.
sizes : sequence of (int, Tensor), optional
The size of each dimension.
start_axis : int, optional
The axis to start.
offsets : int, sequence of, optional
......
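A hedged sketch of the ``sizes`` convention documented above: *-1* runs to the end of the axis and *0* keeps a single index and squeezes it. The tensor name and the ``dragon.ops`` alias are assumptions:

import dragon as dg

x = dg.Tensor('x', shape=[4, 8, 3], dtype='float32')
# Rows 1..2, the rest of axis 1 from index 2 onward, and index 0 of axis 2 (squeezed).
y = dg.ops.Crop(x, starts=[1, 2, 0], sizes=[2, -1, 0])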
......@@ -18,29 +18,63 @@ from . import *
@OpSchema.Inputs(2)
def Copy(inputs, **kwargs):
"""Copy A to B.
"""Copy the ``value`` to ``ref``.
The sizes of ``value`` and ``ref`` should be the same.
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
Parameters
----------
inputs : sequence of Tensor
The inputs, A and B respectively.
The ``ref`` and ``value`` respectively.
Returns
-------
Tensor
The output tensor, i.e., B(taking values of A).
The ``ref``.
"""
arguments = ParseArgs(locals())
arguments['existing_outputs'] = [arguments['inputs'][1]]
arguments['inputs'] = [arguments['inputs'][0]]
arguments['existing_outputs'] = [arguments['inputs'][0]]
arguments['inputs'] = [arguments['inputs'][1]]
return Tensor.CreateOperator('Copy', **arguments)
@OpSchema.ConvertConstantInputs()
@OpSchema.Inputs(2)
@ArgumentHelper.RepeatedDesc('starts')
@ArgumentHelper.RepeatedDesc('sizes')
def Assign(inputs, starts=None, sizes=None, **kwargs):
"""Assign the ``value`` to ``ref``.
The value of ``sizes`` could be set to *-1* (to end) or *0* (squeeze).
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
Parameters
----------
inputs : sequence of Tensor
The ``ref`` and ``value`` respectively.
starts : sequence of (int, Tensor), optional
The start pos of each dimension.
sizes : sequence of (int, Tensor), optional
The size of each dimension.
Returns
-------
Tensor
The ``ref``.
"""
arguments = ParseArgs(locals())
arguments['existing_outputs'] = [arguments['inputs'][0]]
arguments['inputs'] = [arguments['inputs'][1]]
return Tensor.CreateOperator('Assign', **arguments)
@OpSchema.ConvertConstantInputs()
@OpSchema.Inputs(2)
def Equal(inputs, to_uint8=False, **kwargs):
"""``Equal`` comparing between A and B.
......
......@@ -141,6 +141,7 @@ Multinomial = array_ops.Multinomial
# Control Flow
Copy = control_flow_ops.Copy
Assign = control_flow_ops.Assign
Equal = control_flow_ops.Equal
Less = control_flow_ops.Less
LessEqual = control_flow_ops.LessEqual
......
......@@ -57,6 +57,9 @@ class _BatchNorm(Module):
}
}
def half(self):
return self # Float32 parameters are required
def extra_repr(self):
return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
'track_running_stats={track_running_stats}'.format(**self.__dict__)
......
......@@ -51,6 +51,13 @@ class _GroupNorm(Module):
}
}
def half(self):
return self # Float32 parameters are required
def extra_repr(self):
return '{num_features}, eps={eps}, group={group}, ' \
'affine={affine}'.format(**self.__dict__)
def forward(self, input):
inputs = [input] + self.inputs
self.unify_devices(inputs)
......
......@@ -78,6 +78,20 @@ class RNNBase(Module):
}
}
def extra_repr(self):
s = '{input_size}, {hidden_size}'
if self.num_layers != 1:
s += ', num_layers={num_layers}'
if self.bias is not True:
s += ', bias={bias}'
if self.batch_first is not False:
s += ', batch_first={batch_first}'
if self.dropout != 0:
s += ', dropout={dropout}'
if self.bidirectional is not False:
s += ', bidirectional={bidirectional}'
return s.format(**self.__dict__)
def make_meta_from_phase(self, phase):
def reset_meta(self, phase):
self._module_key = None
......
......@@ -21,6 +21,7 @@ from dragon.vm.torch.ops.modules.control_flow import Compare
from dragon.vm.torch.ops.modules.arithmetic import (
Fundamental, Log, Exp, Sqrt,
Accumulate,
MM, FullyConnected,
Maximum, Minimum, Clamp,
)
......@@ -31,12 +32,12 @@ from dragon.vm.torch.ops.modules.init import (
from dragon.vm.torch.ops.modules.array import (
Reshape, Squeeze, UnSqueeze, Permute,
Indexing, Repeat, Concat, Gather,
Indexing, Assigning, Repeat, Concat, Gather,
Reduce, ArgReduce, OneHot, Multinomial,
)
from dragon.vm.torch.ops.modules.update import (
Accumulate, Collective, Update,
Accumulate as _Accumulate, Collective, Update,
)
from dragon.vm.torch.ops.modules.vision import (
......@@ -46,6 +47,7 @@ from dragon.vm.torch.ops.modules.vision import (
__all__ = [
'add', 'sub', 'mul', 'div',
'accumulate',
'maximum', 'minimum', 'clamp',
'log', 'exp', 'sqrt',
'mm', 'xw_plus_b',
......@@ -317,6 +319,32 @@ def sqrt(input, out=None):
return module.forward(input, out)
def accumulate(input, alpha=1., beta=1., out=None):
"""Compute *out = alpha * input + beta * out*
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
alpha : float, optional, default=1.
The value of alpha.
beta : float, optional, default=1.
The value of beta.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
dev = MakeDevice(inputs=[input])
key = 'Accumulate/{}/alpha:{}/beta:{}'.format(dev, alpha, beta)
module = get_module(Accumulate, key, dev, alpha=alpha, beta=beta)
return module.forward(input, out)
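A hedged usage sketch of the semantics above (*out = alpha * input + beta * out*), assuming the builtin is re-exported at the ``dragon.vm.torch`` level like the other ops and that ``ones``/``zeros`` behave as in PyTorch:

import dragon.vm.torch as torch

x = torch.ones(2, 3)
y = torch.zeros(2, 3)
# In-place running average: y <- 0.1 * x + 0.9 * y
torch.accumulate(x, alpha=0.1, beta=0.9, out=y)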
def mm(mat1, mat2, transA=False, transB=False, out=None):
"""Performs a matrix multiplication of the matrices ``mat1`` and ``mat2.``
......@@ -458,6 +486,19 @@ def _indexing(input, starts, sizes):
return module.forward(input, starts, sizes)
def _assigning(output, input, starts, sizes):
if not isinstance(input, Tensor):
if isinstance(input, (tuple, list)):
input = Tensor(input, dtype=output.dtype, device=output.device)
else:
input = WrapScalar(input, output.dtype, output.device)
n_starts, n_sizes = len(starts), len(sizes)
dev = MakeDevice(inputs=[input])
key = 'Assign/{}/n_starts:{}/n_sizes:{}'.format(dev, n_starts, n_sizes)
module = get_module(Assigning, key, dev, n_starts=n_starts, n_sizes=n_sizes)
return module.forward(input, output, starts, sizes)
def _compare(input, other, operation, out=None):
if not isinstance(other, Tensor):
other = WrapScalar(other, input.dtype, input.device)
......@@ -1074,7 +1115,7 @@ def _accumulate(grads):
if not isinstance(grads, (list, tuple)): grads = [grads]
dev = MakeDevice(inputs=grads)
key = 'Accumulate/{}/alpha:1./beta:1.'.format(dev)
module = get_module(Accumulate, key, dev)
module = get_module(_Accumulate, key, dev)
return module.forward(grads)
......
......@@ -163,3 +163,24 @@ class FullyConnected(BaseModule):
self.unify_devices(inputs)
outputs = [y] if y else [self.register_output()]
return self.run(inputs, outputs)
class Accumulate(BaseModule):
def __init__(self, key, dev, **kwargs):
super(Accumulate, self).__init__(key, dev, **kwargs)
self.alpha = kwargs.get('alpha', 1.)
self.beta = kwargs.get('beta', 1.)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Accumulate',
'arguments': {
'alpha': self.alpha,
'beta': self.beta,
},
}
def forward(self, x, y=None):
outputs = [y] if y else [self.register_output()]
return self.run([x], outputs, auto_grad=False)
\ No newline at end of file
......@@ -56,6 +56,42 @@ class Indexing(BaseModule):
return self.run(inputs, outputs, callback=callback)
class Assigning(BaseModule):
"""This module imports the *AssignOp* from backend.
Arbitrary lengths of starts and sizes will be taken.
"""
def __init__(self, key, dev, **kwargs):
super(Assigning, self).__init__(key, dev, **kwargs)
self.n_starts = kwargs.get('n_starts', 0)
self.n_sizes = kwargs.get('n_sizes', 0)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Assign',
'arguments': {
'starts_desc': [
'${{ANCHOR}}/starts[{}]'.format(n)
for n in range(self.n_starts)],
'sizes_desc': [
'${{ANCHOR}}/sizes[{}]'.format(n)
for n in range(self.n_sizes)],
},
}
def update_arguments(self, A, starts, sizes):
for i, e in enumerate(starts):
self.set_argument_i64('{}/starts[{}]'.format(A, i), e)
self.set_argument_i64('{}/sizes[{}]'.format(A, i), sizes[i])
def forward(self, x, y, starts, sizes):
self.unify_devices([x, y])
callback = lambda A: self.update_arguments(A, starts, sizes)
return self.run([x], [y], callback=callback, auto_grad=False)
class Concat(BaseModule):
"""This module imports the *ConcatOp* from backend.
......@@ -100,7 +136,6 @@ class Gather(BaseModule):
'op_type': 'Gather',
'arguments': {
'axis': self.axis,
'zero_grad': True,
},
}
......
......@@ -23,7 +23,7 @@ from dragon.vm.torch.ops.builtin import (
_fundamental, _rfundamental,
log, exp, sqrt, clamp,
_reshape, squeeze, unsqueeze,
_permute, _repeat, _indexing, narrow,
_permute, _repeat, _indexing, _assigning, narrow,
mean, sum, max, min,
gt, lt, eq, ge, le,
)
......@@ -83,6 +83,7 @@ Tensor.le = lambda *args, **kwargs: le(*args, **kwargs)
Tensor.eq = lambda *args, **kwargs: eq(*args, **kwargs)
Tensor.narrow = lambda *args, **kwargs: narrow(*args, **kwargs)
Tensor._indexing = lambda *args, **kwargs: _indexing(*args, **kwargs)
Tensor._assigning = lambda *args, **kwargs: _assigning(*args, **kwargs)
Tensor.half = lambda self: _type_to(self, dtype='float16', inplace=False)
......
......@@ -475,20 +475,7 @@ class Tensor(object):
# PyGC will detect them automatically
TensorPool.put(self.name)
def __getitem__(self, item):
"""Return a Tensor with specific indices.
Parameters
----------
item : int, slice or Tensor
The indices.
Returns
-------
Tensor
The output tensor.
"""
def _process_indices(self, item):
if not isinstance(item, (slice, tuple)):
# + value[?]
if not isinstance(item, int):
......@@ -498,7 +485,7 @@ class Tensor(object):
# + value[?:?]
item = tuple([item])
# + value[?:?, ?:?, ...]
starts = []; sizes = []
starts, sizes = [], []
for ix, it in enumerate(item):
if isinstance(it, slice):
# Handle start
......@@ -511,19 +498,55 @@ class Tensor(object):
sizes.append(it.stop - starts[-1])
if sizes[-1] == 0:
raise ValueError(
'The cropping starts and ends of axis {} '
'The starts and ends of axis {} '
'can not be equal, got {}:{}.'
.format(ix, starts[-1], it.stop))
# Handle step
if it.step is not None:
raise NotImplementedError('Indexing with step has not been implemented yet. ')
raise NotImplementedError(
'Indexing with step has not been implemented yet. ')
elif isinstance(it, int):
starts.append(it)
sizes.append(0)
else:
raise TypeError('Unsupported type of indices: {}'.format(type(type(it))))
raise TypeError('Unsupported type of indices: {}'.format(type(it)))
return starts, sizes
def __getitem__(self, item):
"""Return the value at the specific indices.
Parameters
----------
item : int or slice
The indices.
Returns
-------
Tensor
The output tensor.
"""
starts, sizes = self._process_indices(item)
return self._indexing(starts, sizes)
def __setitem__(self, key, value):
"""Set the value at the specific indices.
Parameters
----------
key : int, slice
The indices.
value : dragon.vm.torch.Tensor, number or sequence
The value.
Returns
-------
None
"""
starts, sizes = self._process_indices(key)
return self._assigning(value, starts, sizes)
def __hash__(self):
return id(self)
......
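A hypothetical mirror of the dragon.Tensor example for the vm.torch front end; creation via ``torch.zeros`` is an assumption:

import dragon.vm.torch as torch

x = torch.zeros(4, 3)
x[1:3] = 1.0   # _process_indices -> _assigning -> AssignOp
row = x[0]     # _indexing (Crop); the integer axis is squeezed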
......@@ -5,81 +5,8 @@ namespace dragon {
namespace rcnn {
/******************** Proposal ********************/
template <> void GenerateProposals<float, CPUContext>(
const int A,
const int feat_h,
const int feat_w,
const int stride,
const float im_h,
const float im_w,
const float min_box_h,
const float min_box_w,
const float* scores,
const float* bbox_deltas,
const float* anchors,
float* proposals,
CPUContext* ctx) {
float* proposal = proposals;
const int K = feat_h * feat_w;
for (int h = 0; h < feat_h; ++h) {
for (int w = 0; w < feat_w; ++w) {
const float x = (float)w * stride;
const float y = (float)h * stride;
// bbox_deltas: [1, A, 4, K]
const float* bbox_delta = bbox_deltas + h * feat_w + w;
// scores: [1, A, K]
const float* score = scores + h * feat_w + w;
for (int a = 0; a < A; ++a) {
const float dx = bbox_delta[(a * 4 + 0) * K];
const float dy = bbox_delta[(a * 4 + 1) * K];
const float d_log_w = bbox_delta[(a * 4 + 2) * K];
const float d_log_h = bbox_delta[(a * 4 + 3) * K];
proposal[0] = x + anchors[a * 4 + 0];
proposal[1] = y + anchors[a * 4 + 1];
proposal[2] = x + anchors[a * 4 + 2];
proposal[3] = y + anchors[a * 4 + 3];
proposal[4] = BBoxTransform<float>(
dx, dy, d_log_w, d_log_h,
im_w, im_h, min_box_w, min_box_h,
proposal) * score[a * K];
proposal += 5;
}
}
}
}
template <> void GenerateProposals_v2<float, CPUContext>(
const int total_anchors,
const float im_h,
const float im_w,
const float min_box_h,
const float min_box_w,
const float* scores,
const float* bbox_deltas,
float* proposals,
CPUContext* ctx) {
float* proposal = proposals;
for (int i = 0; i < total_anchors; ++i) {
// bbox_deltas: [1, 4, total_anchors]
// scores: [1, total_anchors]
const float dx = bbox_deltas[i];
const float dy = bbox_deltas[total_anchors + i];
const float d_log_w = bbox_deltas[2 * total_anchors + i];
const float d_log_h = bbox_deltas[3 * total_anchors + i];
proposal[4] = BBoxTransform<float>(
dx, dy, d_log_w, d_log_h,
im_w, im_h, min_box_w, min_box_h,
proposal) * scores[i];
proposal += 5;
}
}
/******************** NMS ********************/
template <typename T>
T iou(const T A[], const T B[]) {
T IoU(const T A[], const T B[]) {
if (A[0] > B[2] || A[1] > B[3] ||
A[2] < B[0] || A[3] < B[1]) return 0;
const T x1 = std::max(A[0], B[0]);
......@@ -99,7 +26,7 @@ template <> void ApplyNMS<float, CPUContext>(
const int max_keeps,
const float thresh,
const float* boxes,
int* keep_indices,
int64_t* keep_indices,
int& num_keep,
CPUContext* ctx) {
int count = 0;
......@@ -110,7 +37,7 @@ template <> void ApplyNMS<float, CPUContext>(
keep_indices[count++] = i;
if (count == max_keeps) break;
for (int j = i + 1; j < num_boxes; ++j)
if (!is_dead[j] && iou(&boxes[i * 5],
if (!is_dead[j] && IoU(&boxes[i * 5],
&boxes[j * 5]) > thresh)
is_dead[j] = 1;
}
......
......@@ -7,156 +7,11 @@ namespace dragon {
namespace rcnn {
/******************** BBox ********************/
template <typename T>
__device__ int _BBoxTransform(
const T dx,
const T dy,
const T d_log_w,
const T d_log_h,
const T im_w,
const T im_h,
const T min_box_w,
const T min_box_h,
T* bbox) {
const T w = bbox[2] - bbox[0] + (T)1;
const T h = bbox[3] - bbox[1] + (T)1;
const T ctr_x = bbox[0] + (T)0.5 * w;
const T ctr_y = bbox[1] + (T)0.5 * h;
const T pred_ctr_x = dx * w + ctr_x;
const T pred_ctr_y = dy * h + ctr_y;
const T pred_w = exp(d_log_w) * w;
const T pred_h = exp(d_log_h) * h;
bbox[0] = pred_ctr_x - (T)0.5 * pred_w;
bbox[1] = pred_ctr_y - (T)0.5 * pred_h;
bbox[2] = pred_ctr_x + (T)0.5 * pred_w;
bbox[3] = pred_ctr_y + (T)0.5 * pred_h;
bbox[0] = max((T)0, min(bbox[0], im_w - (T)1));
bbox[1] = max((T)0, min(bbox[1], im_h - (T)1));
bbox[2] = max((T)0, min(bbox[2], im_w - (T)1));
bbox[3] = max((T)0, min(bbox[3], im_h - (T)1));
const T box_w = bbox[2] - bbox[0] + (T)1;
const T box_h = bbox[3] - bbox[1] + (T)1;
return (box_w >= min_box_w) * (box_h >= min_box_h);
}
/******************** Proposal ********************/
template <typename T>
__global__ void _GenerateProposals(
const int nthreads,
const int A,
const int feat_h,
const int feat_w,
const int stride,
const float im_h,
const float im_w,
const float min_box_h,
const float min_box_w,
const T* scores,
const T* bbox_deltas,
const T* anchors,
T* proposals) {
CUDA_1D_KERNEL_LOOP(idx, nthreads) {
const int h = idx / A / feat_w;
const int w = (idx / A) % feat_w;
const int a = idx % A;
const T x = w * stride;
const T y = h * stride;
const T* bbox_delta = bbox_deltas + h * feat_w + w;
const T* score = scores + h * feat_w + w;
const int K = feat_h * feat_w;
const T dx = bbox_delta[(a * 4 + 0) * K];
const T dy = bbox_delta[(a * 4 + 1) * K];
const T d_log_w = bbox_delta[(a * 4 + 2) * K];
const T d_log_h = bbox_delta[(a * 4 + 3) * K];
T* proposal = proposals + idx * 5;
proposal[0] = x + anchors[a * 4 + 0];
proposal[1] = y + anchors[a * 4 + 1];
proposal[2] = x + anchors[a * 4 + 2];
proposal[3] = y + anchors[a * 4 + 3];
proposal[4] = _BBoxTransform(
dx, dy, d_log_w, d_log_h,
im_w, im_h, min_box_w, min_box_h,
proposal) * score[a * K];
}
}
template <> void GenerateProposals<float, CUDAContext>(
const int A,
const int feat_h,
const int feat_w,
const int stride,
const float im_h,
const float im_w,
const float min_box_h,
const float min_box_w,
const float* scores,
const float* bbox_deltas,
const float* anchors,
float* proposals,
CUDAContext* ctx) {
const auto num_proposals = A * feat_h * feat_w;
_GenerateProposals<float>
<< < CUDA_BLOCKS(num_proposals), CUDA_THREADS,
0, ctx->cuda_stream() >> >
(num_proposals, A, feat_h, feat_w, stride,
im_h, im_w, min_box_h, min_box_w,
scores, bbox_deltas, anchors, proposals);
}
template <typename T>
__global__ void _GenerateProposals_v2(
const int nthreads,
const float im_h,
const float im_w,
const float min_box_h,
const float min_box_w,
const T* scores,
const T* bbox_deltas,
T* proposals) {
CUDA_1D_KERNEL_LOOP(idx, nthreads) {
const float dx = bbox_deltas[idx];
const float dy = bbox_deltas[nthreads + idx];
const float d_log_w = bbox_deltas[2 * nthreads + idx];
const float d_log_h = bbox_deltas[3 * nthreads + idx];
T* proposal = proposals + idx * 5;
proposal[4] = _BBoxTransform(
dx, dy, d_log_w, d_log_h,
im_w, im_h, min_box_w, min_box_h,
proposal) * scores[idx];
}
}
template <> void GenerateProposals_v2<float, CUDAContext>(
const int total_anchors,
const float im_h,
const float im_w,
const float min_box_h,
const float min_box_w,
const float* scores,
const float* bbox_deltas,
float* proposals,
CUDAContext* ctx) {
_GenerateProposals_v2<float>
<< < CUDA_BLOCKS(total_anchors), CUDA_THREADS,
0, ctx->cuda_stream() >> >
(total_anchors, im_h, im_w, min_box_h, min_box_w,
scores, bbox_deltas, proposals);
}
/******************** NMS ********************/
#define DIV_UP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define NMS_BLOCK_SIZE 64
template <typename T>
__device__ T iou(const T* A, const T* B) {
__device__ T _IoU(const T* A, const T* B) {
const T x1 = max(A[0], B[0]);
const T y1 = max(A[1], B[1]);
const T x2 = min(A[2], B[2]);
......@@ -200,7 +55,7 @@ __global__ void nms_mask(
unsigned long long mask_j = 0;
const int di_start = (i_start == j_start) ? (tid + 1) : 0;
for (int di = di_start; di < di_end; ++di)
if (iou(box_j, boxes_i + di * 4) > nms_thresh)
if (_IoU(box_j, boxes_i + di * 4) > nms_thresh)
mask_j |= 1ULL << di;
mask[(j_start + tid) * num_blocks + bid] = mask_j;
}
......@@ -212,7 +67,7 @@ void _ApplyNMS(
const int max_keeps,
const float thresh,
const T* boxes,
int* keep_indices,
int64_t* keep_indices,
int& num_keep,
CUDAContext* ctx) {
const int num_blocks = DIV_UP(num_boxes, NMS_BLOCK_SIZE);
......@@ -229,7 +84,7 @@ void _ApplyNMS(
<< < blocks, NMS_BLOCK_SIZE,
0, ctx->cuda_stream() >> > (num_boxes,
thresh, (T*)boxes_dev, (uint64_t*)mask_dev);
CUDA_CHECK(cudaPeekAtLastError());
ctx->FinishDeviceCompution();
std::vector<uint64_t> mask_host(num_boxes * num_blocks);
CUDA_CHECK(cudaMemcpy(&mask_host[0], mask_dev,
......@@ -259,7 +114,7 @@ template <> void ApplyNMS<float, CUDAContext>(
const int max_keeps,
const float thresh,
const float* boxes,
int* keep_indices,
int64_t* keep_indices,
int& num_keep,
CUDAContext* ctx) {
_ApplyNMS<float>(num_boxes, max_keeps, thresh,
......
......@@ -90,57 +90,95 @@ inline void GenerateAnchors(
template <typename T>
inline void GenerateGridAnchors(
const int A,
const int num_proposals,
const int num_anchors,
const int feat_h,
const int feat_w,
const int stride,
const int base_offset,
const T* anchors,
const int64_t* indices,
T* proposals) {
T* proposal = proposals;
for (int a = 0; a < A; ++a) {
for (int h = 0; h < feat_h; ++h) {
for (int w = 0; w < feat_w; ++w) {
const T x = (T)w * stride;
const T y = (T)h * stride;
proposal[0] = x + anchors[a * 4 + 0];
proposal[1] = y + anchors[a * 4 + 1];
proposal[2] = x + anchors[a * 4 + 2];
proposal[3] = y + anchors[a * 4 + 3];
proposal += 5;
}
T x, y;
int idx_3d, a, h, w;
int idx_range = num_anchors * feat_h * feat_w;
for (int i = 0; i < num_proposals; ++i) {
idx_3d = (int)indices[i] - base_offset;
if (idx_3d >= 0 && idx_3d < idx_range) {
w = idx_3d % feat_w;
h = (idx_3d / feat_w) % feat_h;
a = idx_3d / feat_w / feat_h;
x = (T)w * stride, y = (T)h * stride;
auto* A = anchors + a * 4;
auto* P = proposals + i * 5;
P[0] = x + A[0], P[1] = y + A[1];
P[2] = x + A[2], P[3] = y + A[3];
}
}
}
/******************** Proposal ********************/
template <typename T, class Context>
void GenerateProposals(
const int A,
const int feat_h,
const int feat_w,
const int stride,
template <typename T>
void GenerateSSProposals(
const int K,
const int num_proposals,
const float im_h,
const float im_w,
const float min_box_h,
const float min_box_w,
const T* scores,
const T* bbox_deltas,
const T* anchors,
T* proposals,
Context* ctx);
const T* deltas,
const int64_t* indices,
T* proposals) {
int64_t index, a, k;
const float* delta;
float* proposal = proposals;
float dx, dy, d_log_w, d_log_h;
for (int i = 0; i < num_proposals; ++i) {
index = indices[i];
a = index / K, k = index % K;
delta = deltas + k;
dx = delta[(a * 4 + 0) * K];
dy = delta[(a * 4 + 1) * K];
d_log_w = delta[(a * 4 + 2) * K];
d_log_h = delta[(a * 4 + 3) * K];
proposal[4] = BBoxTransform<float>(
dx, dy, d_log_w, d_log_h,
im_w, im_h, min_box_w, min_box_h,
proposal) * scores[index];
proposal += 5;
}
}
template <typename T, class Context>
void GenerateProposals_v2(
const int total_anchors,
template <typename T>
void GenerateMSProposals(
const int num_candidates,
const int num_proposals,
const float im_h,
const float im_w,
const float min_box_h,
const float min_box_w,
const T* scores,
const T* bbox_deltas,
T* proposals,
Context* ctx);
const T* deltas,
const int64_t* indices,
T* proposals) {
int64_t index;
float* proposal = proposals;
float dx, dy, d_log_w, d_log_h;
for (int i = 0; i < num_proposals; ++i) {
index = indices[i];
dx = deltas[index];
dy = deltas[num_candidates + index];
d_log_w = deltas[2 * num_candidates + index];
d_log_h = deltas[3 * num_candidates + index];
proposal[4] = BBoxTransform<float>(
dx, dy, d_log_w, d_log_h,
im_w, im_h, min_box_w, min_box_h,
proposal) * scores[index];
proposal += 5;
}
}
template <typename T>
inline void SortProposals(
......@@ -174,7 +212,7 @@ inline void RetrieveRoIs(
const int num_rois,
const int roi_batch_ind,
const T* proposals,
const int* roi_indices,
const int64_t* roi_indices,
T* rois) {
for (int i = 0; i < num_rois; ++i) {
const T* proposal = proposals + roi_indices[i] * 5;
......@@ -248,7 +286,7 @@ void ApplyNMS(
const int max_keeps,
const T thresh,
const T* boxes,
int* keep_indices,
int64_t* keep_indices,
int& num_keep,
Context* ctx);
......
......@@ -40,12 +40,12 @@ class ProposalOp final : public Operator<Context> {
template <typename T> void RunWithType();
protected:
vector<int64_t> strides;
vector<int64_t> strides, indices, roi_indices;
vector<float> ratios, scales;
int64_t pre_nms_top_n, post_nms_top_n, min_size, num_images;
int64_t min_level, max_level, canonical_level, canonical_scale;
float nms_thresh;
Tensor anchors_, proposals_, roi_indices_, nms_mask_;
Tensor anchors_, proposals_, nms_mask_;
};
} // namespace dragon
......
......@@ -49,6 +49,18 @@ template<> void ReluGrad<float, CPUContext>(
}
}
/*! ReluGrad <T = float16, Device = CPU> */
template<> void ReluGrad<float16, CPUContext>(
const int count,
const float slope,
const float16* dy,
const float16* y,
float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
} // namespace kernel
} // namespace dragon
\ No newline at end of file
......@@ -13,11 +13,18 @@ namespace kernel {
template <typename T>
__global__ void _Relu(
const int count,
const float slope,
const T slope,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(idx, count) {
y[idx] = x[idx] > 0 ? x[idx] : x[idx] * slope;
#if __CUDA_ARCH__ >= 350
y[idx] = __ldg(x + idx) > 0 ?
__ldg(x + idx) :
__ldg(x + idx) * slope;
#else
y[idx] = x[idx] > 0 ?
x[idx] : x[idx] * slope;
#endif
}
}
......@@ -35,8 +42,7 @@ template<> void Relu<float, CUDAContext>(
/*! Relu <T = float16, Device = CUDA> */
template <typename T>
__global__ void _ReluHalf(
template <> __global__ void _Relu<half>(
const int count,
const half slope,
const half* x,
......@@ -44,14 +50,14 @@ __global__ void _ReluHalf(
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530
y[idx] = __hgt(x[idx], kZero) ?
x[idx] : __hmul(x[idx], slope);
y[idx] = __hgt(__ldg(x + idx), kZero) ?
__ldg(x + idx) : __hmul(
__ldg(x + idx), slope);
#endif
}
}
template <typename T>
__global__ void _ReluHalf2(
template <> __global__ void _Relu<half2>(
const int count,
const half2 slope,
const half2* x,
......@@ -59,8 +65,9 @@ __global__ void _ReluHalf2(
const half2 kZero = __float2half2_rn(0.f);
CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530
y[idx] = __hbgt2(x[idx], kZero) ?
x[idx] : __hmul2(x[idx], slope);
y[idx] = __hbgt2(__ldg(x + idx), kZero) ?
__ldg(x + idx) : __hmul2(
__ldg(x + idx), slope);
#endif
}
}
......@@ -72,14 +79,14 @@ template<> void Relu<float16, CUDAContext>(
float16* y,
CUDAContext* ctx) {
if ((count & 1) == 0) {
_ReluHalf2<half2>
_Relu<half2>
<< < CUDA_BLOCKS(count >> 1), CUDA_THREADS,
0, ctx->cuda_stream() >> >
(count >> 1, cast::to<half2>(slope),
reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y));
} else {
_ReluHalf<half>
_Relu<half>
<< < CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >
(count, cast::to<half>(slope),
......@@ -98,9 +105,17 @@ __global__ void _ReluGrad(
const T* y,
T* dx) {
CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530
dx[idx] = __ldg(dy + idx) * (
(__ldg(y + idx) > 0) +
slope * (__ldg(y + idx) <= 0)
);
#else
dx[idx] = dy[idx] * (
(y[idx] > 0) + slope * (y[idx] <= 0)
(y[idx] > 0) +
slope * (y[idx] <= 0)
);
#endif
}
}
......@@ -117,6 +132,40 @@ template<> void ReluGrad<float, CUDAContext>(
(count, slope, dy, y, dx);
}
/*! ReluGrad <T = float16, Device = CUDA> */
template <> __global__ void _ReluGrad<half>(
const int count,
const float slope,
const half* dy,
const half* y,
half* dx) {
const half kZero = __float2half(0.f);
CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530
dx[idx] = __hmul(__ldg(dy + idx), __float2half(
__hgt(__ldg(y + idx), kZero) +
slope * __hle(__ldg(y + idx), kZero))
);
#endif
}
}
template<> void ReluGrad<float16, CUDAContext>(
const int count,
const float slope,
const float16* dy,
const float16* y,
float16* dx,
CUDAContext* ctx) {
_ReluGrad<half>
<< < CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >
(count, slope, reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(y),
reinterpret_cast<half*>(dx));
}
} // namespace kernel
} // namespace dragon
......
......@@ -21,6 +21,16 @@ template<> void SElu<float, CPUContext>(
}
}
/*! SElu <T = float16, Device = CPU> */
template<> void SElu<float16, CPUContext>(
const int count,
const float16* x,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
/*! SEluGrad <T = float32, Device = CPU> */
template<> void SEluGrad<float, CPUContext>(
......@@ -38,6 +48,17 @@ template<> void SEluGrad<float, CPUContext>(
}
}
/*! SEluGrad <T = float16, Device = CPU> */
template<> void SEluGrad<float16, CPUContext>(
const int count,
const float16* dy,
const float16* y,
float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
} // namespace kernel
} // namespace dragon
\ No newline at end of file
#ifdef WITH_CUDA
#include "core/context_cuda.h"
#include "utils/cast.h"
#include "utils/op_kernel.h"
namespace dragon {
......@@ -15,8 +16,15 @@ __global__ void _SElu(
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(idx, count) {
y[idx] = x[idx] > 0 ? 1.0507f * x[idx] :
#if __CUDA_ARCH__ >= 350
y[idx] = __ldg(x + idx) > 0 ?
1.0507f * __ldg(x + idx) :
1.7581f * (exp(__ldg(x + idx)) - 1);
#else
y[idx] = x[idx] > 0 ?
1.0507f * x[idx] :
1.7581f * (exp(x[idx]) - 1);
#endif
}
}
......@@ -31,6 +39,34 @@ template<> void SElu<float, CUDAContext>(
(count, x, y);
}
/*! SElu <T = float16, Device = CUDA> */
template <> __global__ void _SElu<half>(
const int count,
const half* x,
half* y) {
CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530
const float x32 = __half2float(x[idx]);
y[idx] = __float2half(x32 > 0 ?
1.0507f * x32 : 1.7581f * (
exp(x32) - 1));
#endif
}
}
template<> void SElu<float16, CUDAContext>(
const int count,
const float16* x,
float16* y,
CUDAContext* ctx) {
_SElu<half>
<< < CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >
(count, reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
/*! SEluGrad <T = float32, Device = CUDA> */
template <typename T>
......@@ -40,8 +76,17 @@ __global__ void _SEluGrad(
const T* y,
T* dx) {
CUDA_1D_KERNEL_LOOP(idx, count) {
dx[idx] = y[idx] > 0 ? 1.0507f * dy[idx] :
(1.7581f + y[idx]) * dy[idx];
#if __CUDA_ARCH__ >= 350
dx[idx] = __ldg(y + idx) > 0 ?
1.0507f * __ldg(dy + idx) :
(1.7581f + __ldg(y + idx))
* __ldg(dy + idx);
#else
dx[idx] = y[idx] > 0 ?
1.0507f * dy[idx] :
(1.7581f + y[idx])
* dy[idx];
#endif
}
}
......@@ -57,6 +102,37 @@ template<> void SEluGrad<float, CUDAContext>(
(count, dy, y, dx);
}
/*! SEluGrad <T = float16, Device = CUDA> */
template<> __global__ void _SEluGrad<half>(
const int count,
const half* dy,
const half* y,
half* dx) {
CUDA_1D_KERNEL_LOOP(idx, count) {
#if __CUDA_ARCH__ >= 530
const float y32 = __half2float(y[idx]);
dx[idx] = __float2half(
y32 > 0 ? 1.0507f * __half2float(dy[idx]) :
(1.7581f + y32) * __half2float(dy[idx]));
#endif
}
}
template<> void SEluGrad<float16, CUDAContext>(
const int count,
const float16* dy,
const float16* y,
float16* dx,
CUDAContext* ctx) {
_SEluGrad<half>
<< < CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >
(count, reinterpret_cast<const half*>(dy),
reinterpret_cast<const half*>(y),
reinterpret_cast<half*>(dx));
}
} // namespace kernel
} // namespace dragon
......
#include "utils/op_kernel.h"
#include "utils/math_utils.h"
namespace dragon {
namespace kernel {
/*! Assign <T = ?, Device = CPU> */
template <typename T>
void _Assign(
const int count,
const int ndims,
const int* x_dims,
const int* y_strides,
const int* starts,
const T* x,
T* y) {
vector<int> index(ndims, 0); int y_idx;
for (int x_idx = 0; x_idx < count; ++x_idx) {
y_idx = 0;
for (int d = ndims - 1; d >= 0; --d) {
y_idx += (index[d] + starts[d]) * y_strides[d];
}
y[y_idx] = x[x_idx];
utils::IncreaseIndexInDims(ndims, x_dims, index.data());
}
}
/*! Kernel Launchers */
#define DEFINE_ASSIGN_KERNEL_LAUNCHER(T) \
template<> void Assign<T, CPUContext>( \
const int count, \
const int ndims, \
const int* x_dims, \
const int* y_strides, \
const int* starts, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_Assign<T>(count, ndims, x_dims, \
y_strides, starts, x, y); \
}
DEFINE_ASSIGN_KERNEL_LAUNCHER(bool);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int8_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(uint8_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int64_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(float16);
DEFINE_ASSIGN_KERNEL_LAUNCHER(float);
DEFINE_ASSIGN_KERNEL_LAUNCHER(double);
#undef DEFINE_ASSIGN_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
\ No newline at end of file
#ifdef WITH_CUDA
#include "core/context_cuda.h"
#include "utils/op_kernel.h"
namespace dragon {
namespace kernel {
#define FIXED_DIVISOR_DIV_MOD(d, n, q, r) \
do { \
const auto n_copy = n; \
*q = n_copy / d; \
*r = n_copy % d; \
} while (0)
/*! Assign <T = ?, Device = CUDA> */
template<typename T>
__global__ void _Assign(
const int nthreads,
const int ndims,
const int* x_dims,
const int* y_strides,
const int* starts,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(x_idx, nthreads) {
int y_idx = 0, tmp = x_idx;
#pragma unroll
for (int d = ndims - 1; d >= 0; --d) {
int r;
#if __CUDA_ARCH__ >= 350
FIXED_DIVISOR_DIV_MOD(__ldg(x_dims + d), tmp, &tmp, &r);
y_idx += (r + __ldg(starts + d)) * __ldg(y_strides + d);
#else
FIXED_DIVISOR_DIV_MOD(x_dims[d], tmp, &tmp, &r);
y_idx += (r + starts[d]) * y_strides[d];
#endif
}
y[y_idx] = x[x_idx];
}
}
/*! Kernel Launchers */
#define DEFINE_ASSIGN_KERNEL_LAUNCHER(T) \
template<> void Assign<T, CUDAContext>( \
const int count, \
const int ndims, \
const int* x_dims, \
const int* y_strides, \
const int* starts, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
_Assign<T> \
<< < CUDA_BLOCKS(count), CUDA_THREADS, \
0, ctx->cuda_stream() >> > \
(count, ndims, x_dims, y_strides, starts, x, y); \
}
DEFINE_ASSIGN_KERNEL_LAUNCHER(bool);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int8_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(uint8_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int64_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(float16);
DEFINE_ASSIGN_KERNEL_LAUNCHER(float);
DEFINE_ASSIGN_KERNEL_LAUNCHER(double);
#undef FIXED_DIVISOR_DIV_MOD
#undef DEFINE_ASSIGN_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // WITH_CUDA
\ No newline at end of file
......@@ -79,21 +79,21 @@ void ONNXBackend::ONNXTensorToArgument(
Argument* dtype,
Argument* values) {
if (onnx_tensor.data_type() == TensorProto::FLOAT16) {
/*! floa16 - float_data */
/*! float16: raw_data => floats */
dtype->set_s("float16");
auto* floats = values->mutable_floats();
CHECK((TryConvertingTensorRawValues_v2<
google::protobuf::uint16, float>(onnx_tensor, floats)))
<< "Excepted the raw data to store the FLOAT16.";
} else if (onnx_tensor.data_type() == TensorProto::FLOAT) {
/*! float32 - float_data */
/*! float32: float_data | raw_data => floats */
dtype->set_s("float32");
auto* floats = values->mutable_floats();
if (!TryConvertingTensorRawValues<float>(onnx_tensor, floats)) {
floats->CopyFrom(onnx_tensor.float_data());
}
} else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
/*! float64 - double_data */
/*! float64: double_data | raw_data => floats */
dtype->set_s("float64");
google::protobuf::RepeatedField<double> tmp;
const auto* src = &tmp;
......@@ -102,43 +102,43 @@ void ONNXBackend::ONNXTensorToArgument(
}
for (const auto i : *src) values->add_floats(i);
} else if (onnx_tensor.data_type() == TensorProto::INT64) {
/*! <int64> - int64 - int64_data */
/*! int64: int64_data | raw_data => ints */
dtype->set_s("int64");
ConvertIntegralValue<google::protobuf::int64>(onnx_tensor, values);
} else if (onnx_tensor.data_type() == TensorProto::UINT64) {
/*! <uint64> - uint64 - uint64_data */
/*! uint64: uint64_data | raw_data => ints */
dtype->set_s("uint64");
ConvertIntegralValue<google::protobuf::uint64>(onnx_tensor, values);
} else if (onnx_tensor.data_type() == TensorProto::UINT32) {
/*! <uint64> - uint32 - uint64_data */
/*! uint32: uint64_data | raw_data => ints */
dtype->set_s("uint32");
ConvertIntegralValue<google::protobuf::uint64>(onnx_tensor, values);
} else if (onnx_tensor.data_type() == TensorProto::BOOL) {
/*! <T> - bool - int32_data */
/*! bool: int32_data | raw_data => ints */
dtype->set_s("bool");
ConvertIntegralValue<google::protobuf::int8>(onnx_tensor, values);
} else if (onnx_tensor.data_type() == TensorProto::UINT8) {
/*! <T> - uint8 - int32_data */
/*! uint8: int32_data | raw_data => ints */
dtype->set_s("uint8");
ConvertIntegralValue<google::protobuf::uint8>(onnx_tensor, values);
} else if (onnx_tensor.data_type() == TensorProto::INT8) {
/*! <T> - int8 - int32_data */
/*! int8: int32_data | raw_data => ints */
dtype->set_s("int8");
ConvertIntegralValue<google::protobuf::int8>(onnx_tensor, values);
} else if (onnx_tensor.data_type() == TensorProto::UINT16) {
/*! <T> - uint16 - int32_data */
/*! uint16: int32_data | raw_data => ints */
dtype->set_s("uint16");
ConvertIntegralValue<google::protobuf::uint16>(onnx_tensor, values);
} else if (onnx_tensor.data_type() == TensorProto::INT16) {
/*! <T> - int16 - int32_data */
/*! int16: int32_data | raw_data => ints */
dtype->set_s("int16");
ConvertIntegralValue<google::protobuf::int16>(onnx_tensor, values);
} else if (onnx_tensor.data_type() == TensorProto::INT32) {
/*! <T> - int32 - int32_data */
/*! int32: int32_data | raw_data => ints */
dtype->set_s("int32");
ConvertIntegralValue<google::protobuf::int32>(onnx_tensor, values);
} else if (onnx_tensor.data_type() == TensorProto::STRING) {
/*! <string> - string - string_data */
/*! string: string_data => strings */
dtype->set_s("string");
auto* strings = values->mutable_strings();
strings->CopyFrom(onnx_tensor.string_data());
......
......@@ -44,7 +44,8 @@ void ReluGradientOp<Context>::RunOnDevice() {
Output(0)->ReshapeLike(Input(0));
if (XIsType(Input(0), float)) RunWithType<float>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
else if (XIsType(Input(0), float16)) RunWithType<float16>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
}
DEPLOY_CPU(ReluGradient);
......
......@@ -16,7 +16,8 @@ void SEluOp<Context>::RunOnDevice() {
Output(0)->ReshapeLike(Input(0));
if (XIsType(Input(0), float)) RunWithType<float>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
else if (XIsType(Input(0), float16)) RunWithType<float16>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
}
DEPLOY_CPU(SElu);
......@@ -40,7 +41,8 @@ void SEluGradientOp<Context>::RunOnDevice() {
Output(0)->ReshapeLike(Input(0));
if (XIsType(Input(0), float)) RunWithType<float>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32" });
else if (XIsType(Input(0), float16)) RunWithType<float16>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
}
DEPLOY_CPU(SEluGradient);
......
......@@ -72,11 +72,8 @@ void GatherGradientOp<Context>::RunWithType() {
auto* dYdata = Input(-1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>();
// Zero the gradients Optionally
if (zero_grad) {
math::Set(Output(0)->count(),
cast::to<T>(0.f), dXdata, ctx());
}
kernel::GatherGrad(
outer_dim, inner_dim,
......
#include "core/workspace.h"
#include "utils/op_kernel.h"
#include "utils/math_utils.h"
#include "utils/math_functions.h"
#include "operators/control_flow/assign_op.h"
namespace dragon {
#define TENSOR_FROM_VECTOR(tensor, vec, T) \
{ \
tensor.Reshape({ (int64_t)vec.size() }); \
auto* data = tensor.template mutable_data<T, CPUContext>(); \
for (int i = 0; i < vec.size(); i++) data[i] = (T)vec[i]; \
}
template <class Context> template <typename T>
void AssignOp<Context>::RunWithType() {
const T* Xdata = nullptr;
auto* XDS = x_dimsT.template data<int, Context>();
auto* YSS = y_stridesT.template data<int, Context>();
auto* STS = startsT.template data<int, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>();
if (Input(0).count() < fake_x.count()) {
int rows, cols;
auto* WSdata = ws()->template caches
<T, Context>({ fake_x.count() })[0];
auto* RXdata = Input(0).template data<T, Context>();
if (utils::IsRowwiseBroadcast(
fake_x.dims(), Input(0).dims(),
&rows, &cols)) {
math::BroadcastSet(rows, cols, 0,
RXdata, WSdata, ctx());
} else if (utils::IsColwiseBroadcast(
fake_x.dims(), Input(0).dims(),
&rows, &cols)) {
math::BroadcastSet(rows, cols, 1,
RXdata, WSdata, ctx());
} else {
LOG(FATAL) << "Could not broadcast "
<< Input(0).DimString() << " to "
<< fake_x.DimString();
}
Xdata = WSdata;
} else if (Input(0).count() == fake_x.count()) {
Xdata = Input(0).template data<T, Context>();
} else {
LOG(FATAL) << "Could not assign "
<< Input(0).DimString() << " to "
<< Output(0)->DimString();
}
// Apply a simple Nd-Broadcast solution
kernel::Assign(fake_x.count(), x_dimsT.count(),
XDS, YSS, STS, Xdata, Ydata, ctx());
}
template <class Context>
void AssignOp<Context>::Setup() {
st.assign((size_t)Output(0)->ndim(), 0);
ed.assign(st.size(), 0);
// Determine the starts
int n_starts = GET_ARGUMENTS_SIZE(starts);
for (int i = 0; i < st.size(); i++)
if (i < n_starts) st[i] = starts(i);
// Determine the ends
int n_sizes = GET_ARGUMENTS_SIZE(sizes);
for (int i = 0; i < ed.size(); i++) {
ed[i] = Output(0)->dim(i);
if (i < n_sizes) {
auto len = sizes(i);
if (len > 0) { ed[i] = st[i] + len; }
else if (len == 0) { ed[i] = st[i] + 1; }
}
}
// Check starts and ends
for (int i = 0; i < st.size(); i++) {
CHECK(st[i] >= 0 && st[i] < Output(0)->dim(i))
<< "\nThe assigning starts at the pos " << st[i] << " of axis " << i << ", "
<< "while the dimension of this axis is " << Output(0)->dim(i) << ".";
CHECK(ed[i] > 0 && ed[i] <= Output(0)->dim(i))
<< "\nThe assigning ends at the pos " << ed[i] << " of axis " << i << ", "
<< "while the dimension of this axis is " << Output(0)->dim(i) << ".";
}
x_dimsV = Output(0)->dims();
for (int i = 0; i < st.size(); i++)
x_dimsV[i] = ed[i] - st[i];
fake_x.Reshape(x_dimsV);
}
template <class Context>
void AssignOp<Context>::RunOnDevice() {
Setup();
TENSOR_FROM_VECTOR(y_stridesT, Output(0)->strides(), int);
TENSOR_FROM_VECTOR(x_dimsT, x_dimsV, int);
TENSOR_FROM_VECTOR(startsT, st, int);
if (XIsType(Input(0), bool)) RunWithType<bool>();
else if (XIsType(Input(0), int8_t)) RunWithType<int8_t>();
else if (XIsType(Input(0), uint8_t)) RunWithType<uint8_t>();
else if (XIsType(Input(0), int)) RunWithType<int>();
else if (XIsType(Input(0), int64_t)) RunWithType<int64_t>();
else if (XIsType(Input(0), float16)) RunWithType<float16>();
else if (XIsType(Input(0), float)) RunWithType<float>();
else if (XIsType(Input(0), double)) RunWithType<double>();
else LOG(FATAL) << DTypeHelper(Input(0), {
"bool", "int8", "uint8", "int32", "int64",
"float16", "float32", "float64",
});
}
DEPLOY_CPU(Assign);
#ifdef WITH_CUDA
DEPLOY_CUDA(Assign);
#endif
OPERATOR_SCHEMA(Assign).NumInputs(1).NumOutputs(1);
NO_GRADIENT(Assign);
} // namespace dragon
\ No newline at end of file
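For illustration, a small Python sketch (not part of this commit) of the start/end arithmetic performed in AssignOp::Setup(): a positive size gives an explicit length, 0 selects a single index, and a negative size (e.g. -1) runs to the end of the axis.

def assign_region(y_dims, starts, sizes):
    """Mirror of AssignOp::Setup() for an output with dimensions `y_dims`."""
    st, ed = [], []
    for i, dim in enumerate(y_dims):
        s = starts[i] if i < len(starts) else 0
        e = dim
        if i < len(sizes):
            if sizes[i] > 0:    e = s + sizes[i]  # explicit length
            elif sizes[i] == 0: e = s + 1         # single index
            # negative sizes keep e = dim, i.e. run to the end
        st.append(s); ed.append(e)
    return st, ed

# assign_region([4, 8], starts=[1], sizes=[2, -1]) -> ([1, 0], [3, 8])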
......@@ -81,6 +81,35 @@ DEFINE_SET_FUNC(float);
DEFINE_SET_FUNC(double);
#undef DEFINE_SET_FUNC
#define DEFINE_BROADCAST_SET_FUNC(T) \
template <> void BroadcastSet<T, CPUContext>( \
const int rows, \
const int cols, \
const int type, \
const T* x, \
T* y, \
CPUContext* ctx) { \
if (type == 0) { \
/*! Row - BroadcastX */ \
EigenArrayMap<T>(y, cols, rows).colwise() = \
ConstEigenVectorArrayMap<T>(x, cols); \
} else if (type == 1) { \
/*! Col - BroadcastX */ \
EigenArrayMap<T>(y, cols, rows).rowwise() = \
ConstEigenVectorArrayMap<T>(x, rows).transpose(); \
} \
}
DEFINE_BROADCAST_SET_FUNC(bool);
DEFINE_BROADCAST_SET_FUNC(int8_t);
DEFINE_BROADCAST_SET_FUNC(uint8_t);
DEFINE_BROADCAST_SET_FUNC(int);
DEFINE_BROADCAST_SET_FUNC(int64_t);
DEFINE_BROADCAST_SET_FUNC(float16);
DEFINE_BROADCAST_SET_FUNC(float);
DEFINE_BROADCAST_SET_FUNC(double);
#undef DEFINE_BROADCAST_SET_FUNC
/*! y = x^e */
#define DEFINE_POWX_FUNC(T) \
......
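A conceptual NumPy sketch (not part of the commit) of what the new BroadcastSet produces for a row-major (rows, cols) output: type 0 tiles a length-cols vector across the rows, type 1 tiles a length-rows vector across the columns. The CUDA kernels _RowBroadcastSet and _ColBroadcastSet below implement the same mapping.

import numpy as np

rows, cols = 3, 4
x_row = np.arange(cols)                   # type == 0: x holds `cols` values
y0 = np.tile(x_row, (rows, 1))            # every row of y is a copy of x

x_col = np.arange(rows)                   # type == 1: x holds `rows` values
y1 = np.tile(x_col[:, None], (1, cols))   # every column of y is a copy of x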
......@@ -130,6 +130,28 @@ __global__ void _Set(
}
template <typename T>
__global__ void _RowBroadcastSet(
const int count,
const int cols,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(idx, count) {
y[idx] = x[idx % cols];
}
}
template <typename T>
__global__ void _ColBroadcastSet(
const int count,
const int cols,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(idx, count) {
y[idx] = x[idx / cols];
}
}
template <typename T>
__global__ void _Pow(
const int n,
const T exp,
......@@ -212,6 +234,40 @@ DEFINE_SET_FUNC(float);
DEFINE_SET_FUNC(double);
#undef DEFINE_SET_FUNC
#define DEFINE_BROADCAST_SET_FUNC(T) \
template <> void BroadcastSet<T, CUDAContext>( \
const int rows, \
const int cols, \
const int type, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
auto n = rows * cols; \
if (type == 0) { \
/*! Row - BroadcastX */ \
_RowBroadcastSet<T> \
<< < CUDA_BLOCKS(n), CUDA_THREADS, \
0, ctx->cuda_stream() >> > \
(n, cols, x, y); \
} else if (type == 1) { \
/*! Col - BroadcastX */ \
_ColBroadcastSet<T> \
<< < CUDA_BLOCKS(n), CUDA_THREADS, \
0, ctx->cuda_stream() >> > \
(n, cols, x, y); \
} \
}
DEFINE_BROADCAST_SET_FUNC(bool);
DEFINE_BROADCAST_SET_FUNC(int8_t);
DEFINE_BROADCAST_SET_FUNC(uint8_t);
DEFINE_BROADCAST_SET_FUNC(int);
DEFINE_BROADCAST_SET_FUNC(int64_t);
DEFINE_BROADCAST_SET_FUNC(float16);
DEFINE_BROADCAST_SET_FUNC(float);
DEFINE_BROADCAST_SET_FUNC(double);
#undef DEFINE_BROADCAST_SET_FUNC
/*! y = x^e */
#define DEFINE_POWX_FUNC(T) \
......