Commit 1d03e8e2 by Ting PAN

Optimize GatherOp

1 parent c5def39b
Showing with 341 additions and 366 deletions
...@@ -283,14 +283,16 @@ code.docutils.literal:hover { ...@@ -283,14 +283,16 @@ code.docutils.literal:hover {
dt { dt {
font-weight: 700; font-weight: 700;
background: #e7f2fa; background: #f7f7f7;
border-bottom: solid #0079b2; border-bottom: solid #0079b2;
border-radius: 1px; border-radius: 8px;
margin-bottom: 20px; margin-bottom: 20px;
padding: 8px;
width: 75%;
} }
dt:target, .highlighted { dt:target, .highlighted {
background-color: #e7f2fa; background-color: #f7f7f7;
border-bottom: 3px solid #c7254e; border-bottom: 3px solid #c7254e;
} }
...@@ -299,7 +301,7 @@ dt:target:before { ...@@ -299,7 +301,7 @@ dt:target:before {
content: ''; content: '';
display: block; display: block;
height: 65px; height: 65px;
margin: -20px 0 0; margin: -20px -8px 8px;
} }
dl.method dt { dl.method dt {
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
========================== ============================================================================= ========================== =============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
============================== ============================================================================= ============================== =============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
============================== ============================================================================= ============================== =============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
==================== ============================================================================= ==================== =============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
==================== ============================================================================= ==================== =============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
==================== ============================================================================= ==================== =============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
==================== ============================================================================= ==================== =============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
==================== ============================================================================= ==================== =============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
==================== ============================================================================= ==================== =============================================================================
List Brief List Brief
......
...@@ -112,8 +112,8 @@ List Brief ...@@ -112,8 +112,8 @@ List Brief
================================= ============================================================================= ================================= =============================================================================
Quick Shortcut Quick Reference
-------------- ---------------
==================== ============================================================================= ==================== =============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
========================= ============================================================================ ========================= ============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
========================= ============================================================================= ========================= =============================================================================
List Brief List Brief
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
.. toctree:: .. toctree::
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
==================== ============================================================================= ==================== =============================================================================
List Brief List Brief
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
:hidden: :hidden:
Quick Shortcut Quick Reference
-------------- ---------------
============================== ======================================================================= ============================== =======================================================================
List Brief List Brief
......
...@@ -39,15 +39,15 @@ class GatherGradientOp final : public Operator<Context> { ...@@ -39,15 +39,15 @@ class GatherGradientOp final : public Operator<Context> {
GatherGradientOp(const OperatorDef& def, Workspace* ws) GatherGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), : Operator<Context>(def, ws),
axis(OperatorBase::Arg<int64_t>("axis", 0)), axis(OperatorBase::Arg<int64_t>("axis", 0)),
acc_grad(OperatorBase::Arg<bool>("acc_gradient", false)) {} zero_grad(OperatorBase::Arg<bool>("zero_grad", true)) {}
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
bool zero_grad;
int64_t axis, outer_dim, inner_dim, x_slice_dim, y_slice_dim; int64_t axis, outer_dim, inner_dim, x_slice_dim, y_slice_dim;
bool acc_grad;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -601,32 +601,23 @@ void ArgMin( ...@@ -601,32 +601,23 @@ void ArgMin(
/*! ndarray.gather */ /*! ndarray.gather */
template <typename T, class Context> template <typename T, class Context>
void CanonicalAxis(
const int count,
const int dim,
T* y,
Context* ctx);
template <typename T, class Context>
void Gather( void Gather(
const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int64_t* indices,
const T* x, const T* x,
T* y, T* y,
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void GatherGrad( void GatherGrad(
const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int64_t* indices,
const T* dy, const T* dy,
T* dx, T* dx,
Context* ctx); Context* ctx);
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "core/common.h" #include "core/common.h"
#include "utils/proto_utils.h" #include "utils/proto_utils.h"
#include "utils/caffemodel.h" #include "utils/caffemodel.h"
#include "contrib/onnx/onnx_backend.h" #include "onnx/onnx_backend.h"
#include "dragon.h" #include "dragon.h"
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#ifndef DRAGON_PYTHON_PY_ONNX_H_ #ifndef DRAGON_PYTHON_PY_ONNX_H_
#define DRAGON_PYTHON_PY_ONNX_H_ #define DRAGON_PYTHON_PY_ONNX_H_
#include "contrib/onnx/onnx_backend.h" #include "onnx/onnx_backend.h"
#include "py_dragon.h" #include "py_dragon.h"
......
...@@ -270,7 +270,7 @@ def ExportMetaGraph(prefix=''): ...@@ -270,7 +270,7 @@ def ExportMetaGraph(prefix=''):
These text files will be saved as the following format: These text files will be saved as the following format:
``prefix/Graph_xxx.metatxt`` *prefix/Graph.metatxt*
Note that an empty prefix will leads to invalid exporting. Note that an empty prefix will leads to invalid exporting.
...@@ -293,12 +293,12 @@ def SetLoggingLevel(level): ...@@ -293,12 +293,12 @@ def SetLoggingLevel(level):
Parameters Parameters
---------- ----------
level : str level : {'DEBUG', 'INFO, 'WARNING', 'ERROR', 'FATAL'}, required
The level, ``DEBUG``, ``INFO``, ``WARNING``, ``ERROR`` or ``FATAL``. The logging level.
Notes Notes
----- -----
The default level is ``INFO``. The default level is *INFO*.
""" """
C.SetLogLevelCC(level) C.SetLogLevelCC(level)
......
...@@ -391,9 +391,12 @@ class OperatorHelper(object): ...@@ -391,9 +391,12 @@ class OperatorHelper(object):
@classmethod @classmethod
def _apply_Gather(cls, arguments, inputs, outputs): def _apply_Gather(cls, arguments, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
axis = arguments['axis']
try: try:
outputs[0].shape = inputs[0].shape[:] outputs[0].shape = \
outputs[0].shape[arguments['axis']] = None inputs[0].shape[:axis] + \
inputs[1].shape[:] + \
inputs[0].shape[axis + 1:]
except: except:
pass pass
return outputs return outputs
......
...@@ -17,10 +17,10 @@ from . import * ...@@ -17,10 +17,10 @@ from . import *
@OpSchema.Inputs(1) @OpSchema.Inputs(1)
def Gather(inputs, indices, axis=0, acc_gradient=False, **kwargs): def Gather(inputs, indices, axis=0, zero_grad=True, **kwargs):
"""Gather the input according to the indices along the given axis. """Gather the input according to the indices along the given axis.
**Type Constraints**: (*int32*, *float32*) **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
Parameters Parameters
---------- ----------
...@@ -30,7 +30,7 @@ def Gather(inputs, indices, axis=0, acc_gradient=False, **kwargs): ...@@ -30,7 +30,7 @@ def Gather(inputs, indices, axis=0, acc_gradient=False, **kwargs):
The indices to form output tensor. The indices to form output tensor.
axis : int, optional axis : int, optional
The start axis, can be negative. The start axis, can be negative.
acc_gradient : bool, optional zero_grad : bool, optional
Whether to accumulate the gradients. Whether to accumulate the gradients.
Returns Returns
...@@ -40,24 +40,10 @@ def Gather(inputs, indices, axis=0, acc_gradient=False, **kwargs): ...@@ -40,24 +40,10 @@ def Gather(inputs, indices, axis=0, acc_gradient=False, **kwargs):
""" """
arguments = ParseArgs(locals()) arguments = ParseArgs(locals())
arguments['inputs'], arguments['indices'] = \
arguments['inputs'], arguments['indices'] = [arguments['inputs'], [arguments['inputs'], Tensor.Convert(
Tensor.Convert(indices, dtype='int32')], None indices, dtype='int64')], None
return Tensor.CreateOperator('Gather', **arguments)
output = Tensor.CreateOperator('Gather', **arguments)
try:
output.shape = inputs.shape[:]
if not isinstance(indices, Tensor):
if not isinstance(indices, (list, tuple)):
indices = [indices]
output.shape[axis] = len(indices)
else:
output.shape[axis] = None
except:
pass
return output
@OpSchema.Inputs(1) @OpSchema.Inputs(1)
......
...@@ -283,9 +283,7 @@ def Pool2d( ...@@ -283,9 +283,7 @@ def Pool2d(
@OpSchema.Inputs(2) @OpSchema.Inputs(2)
def ROIPool(inputs, pool_h, pool_w, spatial_scale=1.0, **kwargs): def ROIPool(inputs, pool_h, pool_w, spatial_scale=1.0, **kwargs):
"""Max RoI Pooling. `[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_. """Max RoIPooling. `[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
The first dimension of input must be ``1``.
**Type Constraints**: (*float16*, *float32*) **Type Constraints**: (*float16*, *float32*)
...@@ -311,9 +309,7 @@ def ROIPool(inputs, pool_h, pool_w, spatial_scale=1.0, **kwargs): ...@@ -311,9 +309,7 @@ def ROIPool(inputs, pool_h, pool_w, spatial_scale=1.0, **kwargs):
@OpSchema.Inputs(2) @OpSchema.Inputs(2)
def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, sampling_ratio=2, **kwargs): def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, sampling_ratio=2, **kwargs):
"""AVG ROIAlign. `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_. """AVG RoIAlign. `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
The first dimension of input must be ``1``.
**Type Constraints**: (*float16*, *float32*) **Type Constraints**: (*float16*, *float32*)
......
...@@ -20,7 +20,7 @@ from multiprocessing import Process ...@@ -20,7 +20,7 @@ from multiprocessing import Process
class BlobFetcher(Process): class BlobFetcher(Process):
"""BlobFetcher is deployed to queue blobs from `DataTransformer`_. """BlobFetcher is deployed to queue blobs from `DataTransformer`_.
It is supported to form ``NHWC`` image blobs and ``1D`` label blobs. It is supported to form *NHWC* image blobs and *1d* label blobs.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
......
...@@ -26,7 +26,7 @@ from .blob_fetcher import BlobFetcher ...@@ -26,7 +26,7 @@ from .blob_fetcher import BlobFetcher
class DataBatch(object): class DataBatch(object):
"""DataBatch aims to prefetch data by ``Triple-Buffering``. """DataBatch aims to prefetch data by *Triple-Buffering*.
It takes full advantages of the Process/Thread of Python, It takes full advantages of the Process/Thread of Python,
which provides remarkable I/O speed up for scalable distributed training. which provides remarkable I/O speed up for scalable distributed training.
......
...@@ -236,4 +236,4 @@ class Parameters(object): ...@@ -236,4 +236,4 @@ class Parameters(object):
_param_names = param_name_dict() _param_names = param_name_dict()
layers = Layers() layers = Layers()
params = Parameters() params = Parameters()
\ No newline at end of file
...@@ -354,15 +354,14 @@ class Function(object): ...@@ -354,15 +354,14 @@ class Function(object):
# Store for future development # Store for future development
self.meta_graph = meta_graph self.meta_graph = meta_graph
self.graph_name = meta_graph.name
# Call c api to create graph # Call c api to create graph
ws.CreateGraph(meta_graph) self.graph_name = ws.CreateGraph(meta_graph)
# Bind a lambda callback to run this graph # Bind a lambda callback to run this graph
callback_inputs = self.inputs if explicit_inputs else [] callback_inputs = self.inputs if explicit_inputs else []
self.callback = lambda *args, **kwargs: \ self.callback = lambda *args, **kwargs: \
ws.RunGraph(meta_graph.name, (callback_inputs, args), self.outputs, **kwargs) ws.RunGraph(self.graph_name, (callback_inputs, args), self.outputs, **kwargs)
# Self return # Self return
return self return self
...@@ -386,7 +385,7 @@ def function(inputs=None, outputs=None, givens=None, updater=None): ...@@ -386,7 +385,7 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
---------- ----------
inputs : sequence of Tensor, optional inputs : sequence of Tensor, optional
The inputs to feed. The inputs to feed.
inputs : sequence of Tensor, optional outputs : sequence of Tensor, optional
The outputs to fetch. The outputs to fetch.
givens : dict of Tensor, optional givens : dict of Tensor, optional
The substitutions to use. The substitutions to use.
......
...@@ -60,6 +60,7 @@ class Gather(BaseModule): ...@@ -60,6 +60,7 @@ class Gather(BaseModule):
'n_inputs': 2, 'n_outputs': 1, 'n_inputs': 2, 'n_outputs': 1,
'arguments': { 'arguments': {
'axis': self.axis, 'axis': self.axis,
'zero_grad': True,
} }
} }
......
...@@ -188,16 +188,16 @@ inline void RetrieveRoIs( ...@@ -188,16 +188,16 @@ inline void RetrieveRoIs(
template <typename T> template <typename T>
inline int roi_level( inline int roi_level(
const int min_level, // e.g. 2 const int min_level,
const int max_level, // e.g. 5 const int max_level,
const int canonical_level, // e.g. 4 const int canonical_level,
const int canonical_scale, // e.g. 224 const int canonical_scale,
T* roi) { T* roi) {
T w = roi[3] - roi[1] + 1; T w = roi[3] - roi[1] + 1;
T h = roi[4] - roi[2] + 1; T h = roi[4] - roi[2] + 1;
// Refer the settings of paper // Refer the settings of paper
int level = canonical_level + (int)std::log( int level = canonical_level + std::log2(
std::max(std::sqrt(w * h), (T)1) / (T)canonical_scale); std::max(std::sqrt(w * h), (T)1) / (T)canonical_scale);
return std::min(max_level, std::max(min_level, level)); return std::min(max_level, std::max(min_level, level));
} }
......
...@@ -80,7 +80,7 @@ void ProposalOp<Context>::RunWithType( ...@@ -80,7 +80,7 @@ void ProposalOp<Context>::RunWithType(
anchors_.Reshape({ A, 4 }); anchors_.Reshape({ A, 4 });
rcnn::GenerateAnchors<BT>(strides[i], rcnn::GenerateAnchors<BT>(strides[i],
(int)ratios.size(), 1, &ratios[0], &scales[0], (int)ratios.size(), 1, &ratios[0], &scales[i],
anchors_.template mutable_data<BT, CPUContext>()); anchors_.template mutable_data<BT, CPUContext>());
rcnn::GenerateGridAnchors<BT>( rcnn::GenerateGridAnchors<BT>(
......
...@@ -6,134 +6,93 @@ namespace dragon { ...@@ -6,134 +6,93 @@ namespace dragon {
namespace kernel { namespace kernel {
/*! CanonicalAxis <T = int32, Device = CPU> */
template <> void CanonicalAxis<int, CPUContext>(
const int count,
const int dim,
int* y,
CPUContext* ctx) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) if (y[i] < 0) y[i] += dim;
}
/*! Gather <T = ?, Device = CPU> */ /*! Gather <T = ?, Device = CPU> */
template <typename T> template <typename T>
void _Gather( void _Gather(
const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int64_t* indices,
const T* x, const T* x,
T* y, T* y,
CPUContext* ctx) { CPUContext* ctx) {
int64_t x_offset, y_offset, x_idx_offset, y_idx_offset; int64_t x_offset, select_idx;
for (int i = 0; i < y_slice_dim; ++i) { for (int n = 0; n < outer_dim; ++n) {
y_idx_offset = i; for (int i = 0; i < y_slice_dim; ++i) {
x_idx_offset = indices[y_idx_offset]; select_idx = indices[i];
for (int n = 0; n < outer_dim; ++n) { select_idx = select_idx >= 0 ?
x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim; select_idx : select_idx + x_slice_dim;
y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim; x_offset = (n * x_slice_dim + select_idx) * inner_dim;
ctx->Copy<T, CPUContext, CPUContext>( ctx->Copy<T, CPUContext, CPUContext>(
inner_dim, y + y_offset, x + x_offset); inner_dim, y, x + x_offset);
y += inner_dim;
} }
} }
} }
/*! Gather <T = float32, Device = CPU> */
template <> void Gather<float, CPUContext>(
const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
const int y_slice_dim,
const int* indices,
const float* x,
float* y,
CPUContext* ctx) {
_Gather<float>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, indices, x, y, ctx);
}
/*! Gather <T = int32, Device = CPU> */
template <> void Gather<int, CPUContext>(
const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
const int y_slice_dim,
const int* indices,
const int* x,
int* y,
CPUContext* ctx) {
_Gather<int>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, indices, x, y, ctx);
}
/*! GatherGrad <T = ?, Device = CPU> */ /*! GatherGrad <T = ?, Device = CPU> */
template <typename T> template <typename T>
void _GatherGrad( void _GatherGrad(
const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int64_t* indices,
const T* dy, const T* dy,
T* dx, T* dx,
CPUContext* ctx) { CPUContext* ctx) {
int64_t x_offset, y_offset, x_idx_offset, y_idx_offset; int64_t x_offset, select_idx;
for (int i = 0; i < y_slice_dim; ++i) { for (int n = 0; n < outer_dim; ++n) {
y_idx_offset = i; for (int i = 0; i < y_slice_dim; ++i) {
x_idx_offset = indices[y_idx_offset]; select_idx = indices[i];
for (int n = 0; n < outer_dim; ++n) { select_idx = select_idx >= 0 ?
x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim; select_idx : select_idx + x_slice_dim;
y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim; x_offset = (n * x_slice_dim + select_idx) * inner_dim;
math::Add<T, CPUContext>(inner_dim, math::Add<T, CPUContext>(inner_dim,
dy + y_offset, dx + x_offset, dx + x_offset, ctx); dy, dx + x_offset, dx + x_offset, ctx);
dy += inner_dim;
} }
} }
} }
/*! GatherGrad <T = float32, Device = CPU> */ /*! Kernel Launchers */
template <> void GatherGrad<float, CPUContext>( #define DEFINE_GATHER_KERNEL_LAUNCHER(name, T) \
const int count, template <> void name<T, CPUContext>( \
const int outer_dim, const int outer_dim, \
const int inner_dim, const int inner_dim, \
const int x_slice_dim, const int x_slice_dim, \
const int y_slice_dim, const int y_slice_dim, \
const int* indices, const int64_t* indices, \
const float* dy, const T* x, \
float* dx, T* y, \
CPUContext* ctx) { CPUContext* ctx) { \
_GatherGrad<float>(count, outer_dim, inner_dim, _##name<T> \
x_slice_dim, y_slice_dim, indices, dy, dx, ctx); (outer_dim, inner_dim, x_slice_dim, \
} y_slice_dim, indices, x, y, ctx); \
}
/*! GatherGrad <T = int32, Device = CPU> */
template <> void GatherGrad<int, CPUContext>( DEFINE_GATHER_KERNEL_LAUNCHER(Gather, bool);
const int count, DEFINE_GATHER_KERNEL_LAUNCHER(Gather, int8_t);
const int outer_dim, DEFINE_GATHER_KERNEL_LAUNCHER(Gather, uint8_t);
const int inner_dim, DEFINE_GATHER_KERNEL_LAUNCHER(Gather, int);
const int x_slice_dim, DEFINE_GATHER_KERNEL_LAUNCHER(Gather, int64_t);
const int y_slice_dim, DEFINE_GATHER_KERNEL_LAUNCHER(Gather, float16);
const int* indices, DEFINE_GATHER_KERNEL_LAUNCHER(Gather, float);
const int* dy, DEFINE_GATHER_KERNEL_LAUNCHER(Gather, double);
int* dx,
CPUContext* ctx) { DEFINE_GATHER_KERNEL_LAUNCHER(GatherGrad, int8_t);
_GatherGrad<int>(count, outer_dim, inner_dim, DEFINE_GATHER_KERNEL_LAUNCHER(GatherGrad, uint8_t);
x_slice_dim, y_slice_dim, indices, dy, dx, ctx); DEFINE_GATHER_KERNEL_LAUNCHER(GatherGrad, int);
} DEFINE_GATHER_KERNEL_LAUNCHER(GatherGrad, int64_t);
DEFINE_GATHER_KERNEL_LAUNCHER(GatherGrad, float16);
DEFINE_GATHER_KERNEL_LAUNCHER(GatherGrad, float);
DEFINE_GATHER_KERNEL_LAUNCHER(GatherGrad, double);
#undef DEFINE_GATHER_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
......
...@@ -2,160 +2,176 @@ ...@@ -2,160 +2,176 @@
#include "core/context_cuda.h" #include "core/context_cuda.h"
#include "utils/op_kernel.h" #include "utils/op_kernel.h"
#include "utils/cub_device.h"
namespace dragon { namespace dragon {
namespace kernel { namespace kernel {
/*! CanonicalAxis <T = int32, Device = CUDA> */
template <typename T>
__global__ void _CanonicalAxis(
const int count,
const int dim,
T* y) {
CUDA_1D_KERNEL_LOOP(idx, count) {
if (y[idx] < 0) y[idx] += dim;
}
}
template <> void CanonicalAxis<int, CUDAContext>(
const int count,
const int dim,
int* y,
CUDAContext* ctx) {
_CanonicalAxis<int>
<< < CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >
(count, dim, y);
}
/*! Gather <T = ?, Device = CUDA> */ /*! Gather <T = ?, Device = CUDA> */
template <typename T> template <typename T>
__global__ void _Gather( __global__ void _Gather(
const int count, const int nthreads,
const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int64_t* indices,
const T* x, const T* x,
T* y) { T* y) {
CUDA_1D_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(y_idx, nthreads) {
const int outer_idx = idx / inner_dim / y_slice_dim; const int outer_idx = y_idx / inner_dim / y_slice_dim;
const int slice_idx = idx % inner_dim; const int inner_idx = y_idx % inner_dim;
const int y_idx_offset = (idx / inner_dim) % y_slice_dim; #if __CUDA_ARCH__ >= 350
const int x_idx_offset = indices[y_idx_offset]; int select_idx = __ldg(indices +
const int x_idx = (outer_idx * x_slice_dim + x_idx_offset) ((y_idx / inner_dim) % y_slice_dim));
* inner_dim + slice_idx; #else
y[idx] = x[x_idx]; int select_idx = indices[
(y_idx / inner_dim) % y_slice_dim];
#endif
select_idx = select_idx >= 0 ?
select_idx : select_idx + x_slice_dim;
const int x_idx = (outer_idx * x_slice_dim + select_idx)
* inner_dim + inner_idx;
y[y_idx] = x[x_idx];
} }
} }
/*! Gather <T = float32, Device = CUDA> */
template <> void Gather<float, CUDAContext>(
const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
const int y_slice_dim,
const int* indices,
const float* x,
float* y,
CUDAContext* ctx) {
_Gather<float>
<< < CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >
(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, x, y);
}
/*! Gather <T = int32, Device = CUDA> */
template <> void Gather<int, CUDAContext>(
const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
const int y_slice_dim,
const int* indices,
const int* x,
int* y,
CUDAContext* ctx) {
_Gather<int>
<< <CUDA_BLOCKS(count), CUDA_THREADS,
0, ctx->cuda_stream() >> >
(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, x, y);
}
/*! GatherGrad <T = ?, Device = CUDA> */ /*! GatherGrad <T = ?, Device = CUDA> */
template <typename T> template <typename T>
__global__ void _GatherGrad( __global__ void _GatherGrad(
const int count, const int nthreads,
const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int64_t* indices,
const T* dy, const T* dy,
T* dx) { T* dx) {
CUDA_1D_KERNEL_LOOP(idx, count) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
const int outer_idx = idx / inner_dim / y_slice_dim; const int outer_idx = i / inner_dim;
const int slice_idx = idx % inner_dim; const int inner_idx = i % inner_dim;
const int y_idx_offset = (idx / inner_dim) % y_slice_dim; for (int j = 0; j < y_slice_dim; ++j) {
const int x_idx_offset = indices[y_idx_offset]; #if __CUDA_ARCH__ >= 350
const int x_idx = (outer_idx * x_slice_dim + x_idx_offset) int select_idx = __ldg(indices + j);
* inner_dim + slice_idx; #else
atomicAdd(dx + x_idx, dy[idx]); int select_idx = indices[j];
#endif
select_idx = select_idx >= 0 ?
select_idx : select_idx + x_slice_dim;
const int x_idx = (outer_idx * x_slice_dim + select_idx)
* inner_dim + inner_idx;
const int y_idx = (outer_idx * y_slice_dim + j)
* inner_dim + inner_idx;
dx[x_idx] += dy[y_idx];
}
} }
} }
/*! GatherGrad <T = float32, Device = CUDA> */ /*! GatherGrad <T = float16, Device = CUDA> */
template <> void GatherGrad<float, CUDAContext>( template <> __global__ void _GatherGrad<half>(
const int count, const int nthreads,
const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int64_t* indices,
const float* dy, const half* dy,
float* dx, half* dx) {
CUDAContext* ctx) { CUDA_1D_KERNEL_LOOP(i, nthreads) {
_GatherGrad<float> #if __CUDA_ARCH__ >= 530
<< < CUDA_BLOCKS(count), CUDA_THREADS, const int outer_idx = i / inner_dim;
0, ctx->cuda_stream() >> > const int inner_idx = i % inner_dim;
(count, outer_dim, inner_dim, for (int j = 0; j < y_slice_dim; ++j) {
x_slice_dim, y_slice_dim, int select_idx = __ldg(indices + j);
indices, dy, dx); select_idx = select_idx >= 0 ?
select_idx : select_idx + x_slice_dim;
const int x_idx = (outer_idx * x_slice_dim + select_idx)
* inner_dim + inner_idx;
const int y_idx = (outer_idx * y_slice_dim + j)
* inner_dim + inner_idx;
dx[x_idx] = __hadd(dx[x_idx], dy[y_idx]);
}
#endif
}
} }
/*! GatherGrad <T = int32, Device = CUDA> */ /*! Kernel Launchers */
#define DEFINE_GATHER_KERNEL_LAUNCHER(T) \
template <> void Gather<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int x_slice_dim, \
const int y_slice_dim, \
const int64_t* indices, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
auto nthreads = outer_dim * y_slice_dim * inner_dim; \
_Gather<T> \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> > \
(nthreads, inner_dim, x_slice_dim, \
y_slice_dim, indices, x, y); \
}
#define DEFINE_GATHER_GRAD_KERNEL_LAUNCHER(T) \
template <> void GatherGrad<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int x_slice_dim, \
const int y_slice_dim, \
const int64_t* indices, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
auto nthreads = outer_dim * inner_dim; \
_GatherGrad<T> \
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >> > \
(nthreads, inner_dim, x_slice_dim, \
y_slice_dim, indices, dy, dx); \
}
template <> void GatherGrad<int, CUDAContext>( DEFINE_GATHER_KERNEL_LAUNCHER(bool);
const int count, DEFINE_GATHER_KERNEL_LAUNCHER(int8_t);
DEFINE_GATHER_KERNEL_LAUNCHER(uint8_t);
DEFINE_GATHER_KERNEL_LAUNCHER(int);
DEFINE_GATHER_KERNEL_LAUNCHER(int64_t);
DEFINE_GATHER_KERNEL_LAUNCHER(float16);
DEFINE_GATHER_KERNEL_LAUNCHER(float);
DEFINE_GATHER_KERNEL_LAUNCHER(double);
DEFINE_GATHER_GRAD_KERNEL_LAUNCHER(int8_t);
DEFINE_GATHER_GRAD_KERNEL_LAUNCHER(uint8_t);
DEFINE_GATHER_GRAD_KERNEL_LAUNCHER(int);
DEFINE_GATHER_GRAD_KERNEL_LAUNCHER(int64_t);
DEFINE_GATHER_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GATHER_GRAD_KERNEL_LAUNCHER(double);
template <> void GatherGrad<float16, CUDAContext>(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
const int y_slice_dim, const int y_slice_dim,
const int* indices, const int64_t* indices,
const int* dy, const float16* dy,
int* dx, float16* dx,
CUDAContext* ctx) { CUDAContext* ctx) {
_GatherGrad<int> auto nthreads = outer_dim * inner_dim;
<< < CUDA_BLOCKS(count), CUDA_THREADS, _GatherGrad<half>
<< < CUDA_BLOCKS(nthreads), CUDA_THREADS,
0, ctx->cuda_stream() >> > 0, ctx->cuda_stream() >> >
(count, outer_dim, inner_dim, (nthreads, inner_dim, x_slice_dim,
x_slice_dim, y_slice_dim, y_slice_dim, indices,
indices, dy, dx); reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx));
} }
#undef DEFINE_GATHER_KERNEL_LAUNCHER
#undef DEFINE_GATHER_GRAD_KERNEL_LAUNCHER
} // namespace kernel } // namespace kernel
} // namepsace dragon } // namepsace dragon
......
#include "contrib/onnx/onnx_backend.h" #include "onnx/onnx_backend.h"
namespace dragon { namespace dragon {
......
#include "core/operator_schema.h" #include "core/operator_schema.h"
#include "utils/proto_utils.h" #include "utils/proto_utils.h"
#include "contrib/onnx/onnx_backend.h" #include "onnx/onnx_backend.h"
namespace dragon { namespace dragon {
......
/*! /*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd. * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
* *
* Licensed under the BSD 2-Clause License. * Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License * You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See, * along with the software. If not, See,
* *
* <https://opensource.org/licenses/BSD-2-Clause> * <https://opensource.org/licenses/BSD-2-Clause>
* *
* Codes are based on: * Codes are based on:
* *
* <https://github.com/pytorch/pytorch/blob/master/caffe2/onnx/backend.h> * <https://github.com/pytorch/pytorch/blob/master/caffe2/onnx/backend.h>
* *
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_CONTRIB_ONNX_ONNX_BACKEND_H_ #ifndef DRAGON_ONNX_ONNX_BACKEND_H_
#define DRAGON_CONTRIB_ONNX_ONNX_BACKEND_H_ #define DRAGON_ONNX_ONNX_BACKEND_H_
#include "core/common.h" #include "core/common.h"
#include "proto/onnx.pb.h" #include "proto/onnx.pb.h"
...@@ -228,4 +228,4 @@ class ONNXBackend { ...@@ -228,4 +228,4 @@ class ONNXBackend {
} // namespace dragon } // namespace dragon
#endif // DRAGON_CONTRIB_ONNX_ONNX_BACKEND_H_ #endif // DRAGON_ONNX_ONNX_BACKEND_H_
\ No newline at end of file \ No newline at end of file
#include "utils/map_utils.h" #include "utils/map_utils.h"
#include "contrib/onnx/onnx_backend.h" #include "onnx/onnx_backend.h"
namespace dragon { namespace dragon {
......
#include "contrib/onnx/onnx_backend.h" #include "onnx/onnx_backend.h"
namespace dragon { namespace dragon {
......
...@@ -57,7 +57,7 @@ void MaximumOp<Context>::RunOnDevice() { ...@@ -57,7 +57,7 @@ void MaximumOp<Context>::RunOnDevice() {
else if (XIsType(Input(0), double)) RunWithType<double>(); else if (XIsType(Input(0), double)) RunWithType<double>();
else LOG(FATAL) << DTypeHelper(Input(0), { else LOG(FATAL) << DTypeHelper(Input(0), {
"int8", "uint8", "int32", "int64", "int8", "uint8", "int32", "int64",
"float16", "float32", "float64", "float16", "float32", "float64",
}); });
} }
......
...@@ -13,12 +13,10 @@ namespace dragon { ...@@ -13,12 +13,10 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void GatherOp<Context>::RunWithType() { void GatherOp<Context>::RunWithType() {
auto* Xdata = Input(0).template data<T, Context>(); auto* Xdata = Input(0).template data<T, Context>();
auto* indices = Input(1).template mutable_data<int, Context>(); auto* indices = Input(1).template mutable_data<int64_t, Context>();
auto* Ydata = Output(0)->template mutable_data<T, Context>(); auto* Ydata = Output(0)->template mutable_data<T, Context>();
kernel::CanonicalAxis(Input(1).count(), x_slice_dim, indices, ctx()); kernel::Gather(
kernel::Gather(Output(0)->count(),
outer_dim, inner_dim, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, x_slice_dim, y_slice_dim,
indices, Xdata, Ydata, ctx()); indices, Xdata, Ydata, ctx());
...@@ -28,22 +26,38 @@ template <class Context> ...@@ -28,22 +26,38 @@ template <class Context>
void GatherOp<Context>::RunOnDevice() { void GatherOp<Context>::RunOnDevice() {
DETERMINE_RUNTIME_ARGUMENTS(Input(0)); DETERMINE_RUNTIME_ARGUMENTS(Input(0));
output_dims = Input(0).dims();
x_slice_dim = Input(0).dim(axis); x_slice_dim = Input(0).dim(axis);
output_dims[axis] = y_slice_dim = Input(1).count(); y_slice_dim = Input(1).count();
outer_dim = Input(0).count(0, axis); outer_dim = Input(0).count(0, axis);
inner_dim = Input(0).count(axis + 1); inner_dim = Input(0).count(axis + 1);
CHECK_GT(y_slice_dim, 0) << "\nLength of indices must > 0."; CHECK_GT(y_slice_dim, 0) << "\nLength of indices must > 0.";
const auto& s1 = Input(0).dims().begin();
const auto& e1 = s1 + axis, s3 = e1 + 1;
const auto& e3 = Input(0).dims().end();
const auto& s2 = Input(1).dims().begin();
const auto& e2 = Input(1).dims().end();
output_dims.assign(s1, e1);
output_dims.insert(output_dims.end(), s2, e2);
output_dims.insert(output_dims.end(), s3, e3);
Output(0)->Reshape(output_dims); Output(0)->Reshape(output_dims);
CHECK(Input(1).template IsType<int>()) CHECK(Input(1).template IsType<int64_t>())
<< "\nThe type of indices should be int32."; << "\nThe type of indices should be int64.";
if (XIsType(Input(0), float)) RunWithType<float>(); if (XIsType(Input(0), bool)) RunWithType<bool>();
else if (XIsType(Input(0), int8_t)) RunWithType<int8_t>();
else if (XIsType(Input(0), uint8_t)) RunWithType<uint8_t>();
else if (XIsType(Input(0), int)) RunWithType<int>(); else if (XIsType(Input(0), int)) RunWithType<int>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "int32" }); else if (XIsType(Input(0), int64_t)) RunWithType<int64_t>();
else if (XIsType(Input(0), float16)) RunWithType<float16>();
else if (XIsType(Input(0), float)) RunWithType<float>();
else if (XIsType(Input(0), double)) RunWithType<double>();
else LOG(FATAL) << DTypeHelper(Input(0), {
"bool", "int8", "uint8", "int32", "int64",
"float16", "float32", "float64",
});
} }
DEPLOY_CPU(Gather); DEPLOY_CPU(Gather);
...@@ -54,18 +68,17 @@ OPERATOR_SCHEMA(Gather).NumInputs(2).NumOutputs(1); ...@@ -54,18 +68,17 @@ OPERATOR_SCHEMA(Gather).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T> template <class Context> template <typename T>
void GatherGradientOp<Context>::RunWithType() { void GatherGradientOp<Context>::RunWithType() {
auto* indices = Input(1).template data<int, Context>(); auto* indices = Input(1).template data<int64_t, Context>();
auto* dYdata = Input(-1).template data<T, Context>(); auto* dYdata = Input(-1).template data<T, Context>();
auto* dXdata = Output(0)->template mutable_data<T, Context>();
T* dXdata = nullptr; // Zero the gradients Optionally
if (!acc_grad) { if (zero_grad) {
dXdata = Output(0)->template mutable_data<T, Context>(); math::Set(Output(0)->count(),
math::Set(Output(0)->count(), cast::to<T>(0.f), dXdata, ctx()); cast::to<T>(0.f), dXdata, ctx());
} else {
dXdata = Output(0)->template mutable_data<T, Context>();
} }
kernel::GatherGrad(Input(-1).count(), kernel::GatherGrad(
outer_dim, inner_dim, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, x_slice_dim, y_slice_dim,
indices, dYdata, dXdata, ctx()); indices, dYdata, dXdata, ctx());
...@@ -82,12 +95,20 @@ void GatherGradientOp<Context>::RunOnDevice() { ...@@ -82,12 +95,20 @@ void GatherGradientOp<Context>::RunOnDevice() {
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
CHECK(Input(1).template IsType<int>()) CHECK(Input(1).template IsType<int64_t>())
<< "\nThe type of indices should be int32."; << "\nThe type of indices should be int64.";
if (XIsType(Input(0), float)) RunWithType<float>(); if (XIsType(Input(0), int8_t)) RunWithType<int8_t>();
else if (XIsType(Input(0), uint8_t)) RunWithType<uint8_t>();
else if (XIsType(Input(0), int)) RunWithType<int>(); else if (XIsType(Input(0), int)) RunWithType<int>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "int32" }); else if (XIsType(Input(0), int64_t)) RunWithType<int64_t>();
else if (XIsType(Input(0), float16)) RunWithType<float16>();
else if (XIsType(Input(0), float)) RunWithType<float>();
else if (XIsType(Input(0), double)) RunWithType<double>();
else LOG(FATAL) << DTypeHelper(Input(0), {
"int8", "uint8", "int32", "int64",
"float16", "float32", "float64",
});
} }
DEPLOY_CPU(GatherGradient); DEPLOY_CPU(GatherGradient);
......
...@@ -15,6 +15,27 @@ void DropBlock2dOp<Context>::RunWithType() { ...@@ -15,6 +15,27 @@ void DropBlock2dOp<Context>::RunWithType() {
Output(0)->count(), Ydata, Xdata); Output(0)->count(), Ydata, Xdata);
} }
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
if (data_format == "NCHW") {
n = Input(0).dim(0), c = Input(0).dim(1);
h = Input(0).dim(2), w = Input(0).dim(3);
} else if (data_format == "NHWC") {
n = Input(0).dim(0), c = Input(0).dim(-1);
h = Input(0).dim(1), w = Input(0).dim(2);
}
seed_h = h - block_size + 1;
seed_w = w - block_size + 1;
CHECK(seed_h > 0 && seed_w > 0)
<< "\nExcepted block_size <= feat_size.";
if (decrement > 0 && apply_prob > keep_prob()) {
apply_prob -= decrement;
} else { apply_prob = keep_prob(); }
gamma = (1.f - apply_prob) / (block_size * block_size);
gamma *= (alpha * (h * w) / (seed_h * seed_w));
auto* mask = ws()->CreateTensor(mount_name( auto* mask = ws()->CreateTensor(mount_name(
"drop_block/mask"))->ReshapeLike(Input(0)); "drop_block/mask"))->ReshapeLike(Input(0));
auto* norm = ws()->CreateTensor(mount_name( auto* norm = ws()->CreateTensor(mount_name(
...@@ -58,29 +79,8 @@ void DropBlock2dOp<Context>::RunWithType() { ...@@ -58,29 +79,8 @@ void DropBlock2dOp<Context>::RunWithType() {
template <class Context> template <class Context>
void DropBlock2dOp<Context>::RunOnDevice() { void DropBlock2dOp<Context>::RunOnDevice() {
if (data_format == "NCHW") {
n = Input(0).dim(0), c = Input(0).dim(1);
h = Input(0).dim(2), w = Input(0).dim(3);
} else if (data_format == "NHWC") {
n = Input(0).dim(0), c = Input(0).dim(-1);
h = Input(0).dim(1), w = Input(0).dim(2);
}
seed_h = h - block_size + 1;
seed_w = w - block_size + 1;
CHECK(seed_h > 0 && seed_w > 0)
<< "\nExcepted block_size <= feat_size.";
Output(0)->ReshapeLike(Input(0)); Output(0)->ReshapeLike(Input(0));
if (decrement > 0 && apply_prob > keep_prob()) {
apply_prob -= decrement;
} else { apply_prob = keep_prob(); }
gamma = (1.f - apply_prob) / (block_size * block_size);
gamma *= (alpha * (h * w) / (seed_h * seed_w));
if (XIsType(Input(0), float)) RunWithType<float>(); if (XIsType(Input(0), float)) RunWithType<float>();
else if (XIsType(Input(0), float16)) RunWithType<float16>(); else if (XIsType(Input(0), float16)) RunWithType<float16>();
else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" }); else LOG(FATAL) << DTypeHelper(Input(0), { "float32", "float16" });
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!