Commit 904b59fd by Ting PAN

Add Contrib ops

1 parent 36f27485
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_PROPOSAL_OP_H_
#define DRAGON_OPERATORS_MISC_PROPOSAL_OP_H_
#include "core/operator.h"
namespace dragon {
// Region proposal operator (documented as "Generate Regional Proposals",
// Ren et al. 2015). Every hyper-parameter is read once from the operator
// definition at construction time, after which Setup() is invoked.
template <class Context>
class ProposalOp final : public Operator<Context> {
 public:
    // The initializer list is written in the order C++ actually performs
    // initialization: the base class first, then the members in the order
    // they are declared below (min_size_ precedes base_size_). Listing them
    // differently does not change that order, it only hides it and makes
    // compilers emit -Wreorder.
    ProposalOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          min_size_(OperatorBase::GetSingleArg<int>("min_size", 16)),
          base_size_(OperatorBase::GetSingleArg<int>("base_size", 16)),
          feat_stride_(OperatorBase::GetSingleArg<int>("feat_stride", -1)),
          pre_nms_topn_(OperatorBase::GetSingleArg<int>("pre_nms_topn", 12000)),
          post_nms_topn_(OperatorBase::GetSingleArg<int>("post_nms_topn", 2000)),
          nms_thresh_(OperatorBase::GetSingleArg<float>("nms_thresh", 0.7f)) {
        Setup();
    }

    // One-time initialization, called from the constructor (defined elsewhere).
    void Setup();

    // Operator entry point (defined elsewhere).
    void RunOnDevice() override;

    // Typed implementation (defined elsewhere).
    template <typename T> void RunWithType();

 protected:
    // Anchor / box size arguments; feat_stride_ falls back to -1 when the
    // "feat_stride" argument is absent.
    int min_size_, base_size_, feat_stride_;
    // Box budget before and after non-maximum suppression.
    int pre_nms_topn_, post_nms_topn_;
    // Threshold argument for non-maximum suppression.
    float nms_thresh_;
    // Tensors owned by the operator.
    Tensor anchors_, roi_indices_, proposals_, nms_mask_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MISC_PROPOSAL_OP_H_
\ No newline at end of file
...@@ -64,6 +64,7 @@ Custom ...@@ -64,6 +64,7 @@ Custom
operators/custom/data_process operators/custom/data_process
operators/custom/vec_mult operators/custom/vec_mult
========================================= ===================================================================== ========================================= =====================================================================
List Brief List Brief
========================================= ===================================================================== ========================================= =====================================================================
...@@ -73,6 +74,19 @@ List Brief ...@@ -73,6 +74,19 @@ List Brief
========================================= ===================================================================== ========================================= =====================================================================
Contrib
-------
.. toctree::
:hidden:
operators/contrib/rcnn
========================================= =====================================================================
List Brief
========================================= =====================================================================
`dragon.operators.contrib.rcnn`_ Contrib ops for R-CNN.
========================================= =====================================================================
.. _dragon.operators.data: operators/data.html .. _dragon.operators.data: operators/data.html
...@@ -91,4 +105,6 @@ List Brief ...@@ -91,4 +105,6 @@ List Brief
.. _dragon.io: io.html .. _dragon.io: io.html
.. _dragon.operators.custom.minibatch: operators/custom/minibatch.html .. _dragon.operators.custom.minibatch: operators/custom/minibatch.html
.. _dragon.operators.custom.data_process: operators/custom/data_process.html .. _dragon.operators.custom.data_process: operators/custom/data_process.html
.. _dragon.operators.custom.vec_mult: operators/custom/vec_mult.html .. _dragon.operators.custom.vec_mult: operators/custom/vec_mult.html
\ No newline at end of file .. _dragon.operators.contrib.rcnn: operators/contrib/rcnn.html
============
:mod:`R-CNN`
============
.. toctree::
:hidden:
.. automodule:: dragon.operators.contrib.rcnn.ops
:members:
\ No newline at end of file
...@@ -35,8 +35,8 @@ List Brief ...@@ -35,8 +35,8 @@ List Brief
`Conv2d`_ 2d Convolution. `Conv2d`_ 2d Convolution.
`Conv2dTranspose`_ 2d Deconvolution. `Conv2dTranspose`_ 2d Deconvolution.
`Pool2d`_ 2d Pooling, MAX or AVG. `Pool2d`_ 2d Pooling, MAX or AVG.
`ROIPooling`_ ROIPoolin(MAX), introduced by `[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_. `ROIPooling`_ ROIPooling(MAX), introduced by `[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
`ROIAlign`_ ROIAlign(MAX), introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_. `ROIAlign`_ ROIAlign(AVG), introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
`LRN`_ Local Response Normalization, introduced by `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_. `LRN`_ Local Response Normalization, introduced by `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_.
`NNResize`_ Resize the image with Nearest-Neighbor method. `NNResize`_ Resize the image with Nearest-Neighbor method.
`BilinearResize`_ Resize the image with Bi-linear method. `BilinearResize`_ Resize the image with Bi-linear method.
...@@ -167,6 +167,14 @@ List Brief ...@@ -167,6 +167,14 @@ List Brief
`Proposal`_ Generate Regional Proposals, introduced by `[Ren et.al, 2015] <https://arxiv.org/abs/1506.01497>`_. `Proposal`_ Generate Regional Proposals, introduced by `[Ren et.al, 2015] <https://arxiv.org/abs/1506.01497>`_.
================= ====================================================================== ================= ======================================================================
Contrib
-------
================= ======================================================================
List Brief
================= ======================================================================
`Proposal`_ Generate Regional Proposals, introduced by `[Ren et.al, 2015] <https://arxiv.org/abs/1506.01497>`_.
================= ======================================================================
Cast Cast
---- ----
================= ====================================================================== ================= ======================================================================
...@@ -279,7 +287,8 @@ List Brief ...@@ -279,7 +287,8 @@ List Brief
.. _Accuracy: operators/misc.html#dragon.operators.misc.Accuracy .. _Accuracy: operators/misc.html#dragon.operators.misc.Accuracy
.. _StopGradient: operators/misc.html#dragon.operators.misc.StopGradient .. _StopGradient: operators/misc.html#dragon.operators.misc.StopGradient
.. _MovingAverage: operators/misc.html#dragon.operators.misc.MovingAverage .. _MovingAverage: operators/misc.html#dragon.operators.misc.MovingAverage
.. _Proposal: operators/misc.html#dragon.operators.misc.Proposal
.. _Proposal: operators/contrib/rcnn.html#dragon.operators.contrib.rcnn.ops.Proposal
.. _FloatToHalf: operators/cast.html#dragon.operators.misc.FloatToHalf .. _FloatToHalf: operators/cast.html#dragon.operators.misc.FloatToHalf
......
...@@ -282,6 +282,8 @@ API Reference ...@@ -282,6 +282,8 @@ API Reference
.. _NormalizeParameter.scale_filler: https://github.com/weiliu89/caffe/blob/f5eac041aafbc8b86954bd161710f65e70042ce6/src/caffe/proto/caffe.proto#L1332 .. _NormalizeParameter.scale_filler: https://github.com/weiliu89/caffe/blob/f5eac041aafbc8b86954bd161710f65e70042ce6/src/caffe/proto/caffe.proto#L1332
.. _NormalizeParameter.channel_shared: https://github.com/weiliu89/caffe/blob/f5eac041aafbc8b86954bd161710f65e70042ce6/src/caffe/proto/caffe.proto#L1334 .. _NormalizeParameter.channel_shared: https://github.com/weiliu89/caffe/blob/f5eac041aafbc8b86954bd161710f65e70042ce6/src/caffe/proto/caffe.proto#L1334
.. _NormalizeParameter.eps: https://github.com/weiliu89/caffe/blob/f5eac041aafbc8b86954bd161710f65e70042ce6/src/caffe/proto/caffe.proto#L1336 .. _NormalizeParameter.eps: https://github.com/weiliu89/caffe/blob/f5eac041aafbc8b86954bd161710f65e70042ce6/src/caffe/proto/caffe.proto#L1336
.. _ReductionParameter.operation: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L973
.. _ReductionParameter.axis: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L988
.. _TileParameter.multiples: https://github.com/neopenx/Dragon/blob/6eeac5fec58ed3d0d79f0b4003471e4a641c72f4/Dragon/python/dragon/vm/caffe/proto/caffe.proto#L1173 .. _TileParameter.multiples: https://github.com/neopenx/Dragon/blob/6eeac5fec58ed3d0d79f0b4003471e4a641c72f4/Dragon/python/dragon/vm/caffe/proto/caffe.proto#L1173
.. _ExpandDimsParameter.axis: https://github.com/neopenx/Dragon/blob/6eeac5fec58ed3d0d79f0b4003471e4a641c72f4/Dragon/python/dragon/vm/caffe/proto/caffe.proto#L1480 .. _ExpandDimsParameter.axis: https://github.com/neopenx/Dragon/blob/6eeac5fec58ed3d0d79f0b4003471e4a641c72f4/Dragon/python/dragon/vm/caffe/proto/caffe.proto#L1480
.. _ProposalParameter.feat_stride: https://github.com/sanghoon/caffe/blob/6068dd04ea93cca9fcee036628fdb3ea95b4ebcd/src/caffe/proto/caffe.proto#L431 .. _ProposalParameter.feat_stride: https://github.com/sanghoon/caffe/blob/6068dd04ea93cca9fcee036628fdb3ea95b4ebcd/src/caffe/proto/caffe.proto#L431
......
# --------------------------------------------------------
# Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# R-CNN ops
from dragon.operators.contrib.rcnn.ops import *
# --------------------------------------------------------
# Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
\ No newline at end of file
# --------------------------------------------------------
# Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.operators import *
def Proposal(inputs, strides, ratios, scales,
             pre_nms_top_n=6000, post_nms_top_n=300,
             nms_thresh=0.7, min_size=16,
             min_level=2, max_level=5,
             canonical_scale=224, canonical_level=4, **kwargs):
    """Generate Regional Proposals, introduced by `[Ren et.al, 2015] <https://arxiv.org/abs/1506.01497>`_.

    Multi-Level proposals were introduced by `[Lin et.al, 2017] <https://arxiv.org/abs/1612.03144>`_.

    For single level proposals(e.g. C4), the inputs should be: [cls_probs, bbox_deltas, im_info].

    For multiple level proposals(e.g. FPN), the inputs should be: [cls_score/Px, ...] + [cls_probs, bbox_deltas, im_info].

    Exactly three inputs yield a single output; any other input count
    yields ``max_level - min_level + 1`` outputs.

    Parameters
    ----------
    inputs : list of Tensor
        The inputs.
    strides : list of int
        The strides of anchors.
    ratios : list of float
        The ratios of anchors.
    scales : list of float
        The scales of anchors.
    pre_nms_top_n : int
        The number of anchors before nms.
    post_nms_top_n : int
        The number of anchors after nms.
    nms_thresh : float
        The threshold of nms.
    min_size : int
        The min size of anchors.
    min_level : int
        Finest level of the FPN pyramid.
    max_level : int
        Coarsest level of the FPN pyramid.
    canonical_scale : int
        The baseline scale of mapping policy.
    canonical_level : int
        Heuristic level of the canonical scale.

    Returns
    -------
    Tensor
        The proposals.

    """
    CheckInputs(inputs, 3, INT_MAX)
    # locals() must be captured before any helper variable is created,
    # so that only the call arguments are forwarded to the operator.
    op_arguments = ParseArguments(locals())
    if len(inputs) == 3:
        nout = 1
    else:
        nout = (max_level - min_level) + 1
    return Tensor.CreateOperator(nout=nout, op_type='Proposal', **op_arguments)
\ No newline at end of file
...@@ -157,44 +157,4 @@ def MovingAverage(inputs, decay, **kwargs): ...@@ -157,44 +157,4 @@ def MovingAverage(inputs, decay, **kwargs):
output = Tensor.CreateOperator(op_type='MovingAverage', output = Tensor.CreateOperator(op_type='MovingAverage',
existing_outputs=variable, **arguments) existing_outputs=variable, **arguments)
return output
def Proposal(inputs, ratios, scales,
base_size=16, min_size=16, feat_stride=16,
pre_nms_topn=12000, post_nms_topn=2000, nms_thresh=0.7, **kwargs):
"""Generate Regional Proposals, introduced by `[Ren et.al, 2015] <https://arxiv.org/abs/1506.01497>`_.
Parameters
----------
inputs : list of Tensor
The inputs, represent [input, anchors, im_info].
ratios : list of float
The ratios of anchors.
scales : list of float
The scales of anchors.
base_size : int
The base size of anchors.
min_size : int
The min size of anchors.
feat_stride : int
The stride of input. Default is ``16`` (The 4th down-samples).
pre_nms_topn : int
The number of anchors before nms.
post_nms_topn : int
The number of anchors after nms.
nms_thresh : float
The threshold of nms.
Returns
-------
Tensor
The proposals.
"""
CheckInputs(inputs, 3)
arguments = ParseArguments(locals())
output = Tensor.CreateOperator(nout=1, op_type='Proposal', **arguments)
return output return output
\ No newline at end of file
...@@ -272,9 +272,14 @@ def Reduce(inputs, axis=-1, operation='NONE', keep_dims=False, **kwargs): ...@@ -272,9 +272,14 @@ def Reduce(inputs, axis=-1, operation='NONE', keep_dims=False, **kwargs):
output = Tensor.CreateOperator(nout=1, op_type='Reduce', **arguments) output = Tensor.CreateOperator(nout=1, op_type='Reduce', **arguments)
if inputs.shape is not None: if inputs.shape is not None:
if axis == -1: output.shape = [1] output.shape = inputs.shape[:]
if axis == -1:
if keep_dims:
for i in xrange(len(output.shape)):
output.shape[i] = 1
else: output.shape = [1]
else: else:
output.shape = inputs.shape[:]
if keep_dims: output.shape[axis] = 1 if keep_dims: output.shape[axis] = 1
else: del output.shape[axis] else: del output.shape[axis]
......
...@@ -329,7 +329,7 @@ def ROIPooling(inputs, pool_h, pool_w, spatial_scale, **kwargs): ...@@ -329,7 +329,7 @@ def ROIPooling(inputs, pool_h, pool_w, spatial_scale, **kwargs):
def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, sampling_ratio=2, **kwargs): def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, sampling_ratio=2, **kwargs):
"""Max ROIAlign, introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_. """AVG ROIAlign, introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
The first dimension of input must be ``1``. The first dimension of input must be ``1``.
......
...@@ -21,6 +21,7 @@ from .operators import mpi ...@@ -21,6 +21,7 @@ from .operators import mpi
from .operators import ndarray from .operators import ndarray
from .operators import norm from .operators import norm
from .operators import recurrent from .operators import recurrent
from .operators import contrib
# data # data
LMDBData = data.LMDBData LMDBData = data.LMDBData
...@@ -128,11 +129,13 @@ Template = misc.Template ...@@ -128,11 +129,13 @@ Template = misc.Template
Accuracy = misc.Accuracy Accuracy = misc.Accuracy
StopGradient = misc.StopGradient StopGradient = misc.StopGradient
MovingAverage = misc.MovingAverage MovingAverage = misc.MovingAverage
Proposal = misc.Proposal
# cast # cast
FloatToHalf = cast.FloatToHalf FloatToHalf = cast.FloatToHalf
# mpi # mpi
MPIBroadcast = mpi.MPIBroadcast MPIBroadcast = mpi.MPIBroadcast
MPIGather = mpi.MPIGather MPIGather = mpi.MPIGather
\ No newline at end of file
# contrib
Proposal = contrib.Proposal # R-CNN
\ No newline at end of file
...@@ -272,7 +272,7 @@ class GatherLayer(Layer): ...@@ -272,7 +272,7 @@ class GatherLayer(Layer):
Parameters Parameters
---------- ----------
axis : int axis : int
The axis for gathering. Refer `GatherParameter.axis`_. The axis for gathering. Refer ``GatherParameter.axis``.
""" """
def __init__(self, LayerParameter): def __init__(self, LayerParameter):
...@@ -623,35 +623,44 @@ class ProposalLayer(Layer): ...@@ -623,35 +623,44 @@ class ProposalLayer(Layer):
Parameters Parameters
---------- ----------
feat_stride : int stride : list of int
The stride of input. Refer `ProposalParameter.feat_stride`_. The stride of anchors. Refer ``ProposalParameter.stride``.
base_size : int
The base size of anchors. Refer `ProposalParameter.base_size`_.
min_size : int
The min size of anchors. Refer `ProposalParameter.min_size`_.
ratio : list of float
The ratios of anchors. Refer `ProposalParameter.ratio`_.
scale : list of float scale : list of float
The scales of anchors. Refer `ProposalParameter.scale`_. The scales of anchors. Refer `ProposalParameter.scale`_.
pre_nms_topn : int ratio : list of float
The num of anchors before nms. Refer `ProposalParameter.pre_nms_topn`_. The ratios of anchors. Refer `ProposalParameter.ratio`_.
post_nms_topn : int pre_nms_top_n : int
The num of anchors before nms. Refer `ProposalParameter.pre_nms_topn`_.
post_nms_top_n : int
The num of anchors after nms. Refer `ProposalParameter.post_nms_topn`_. The num of anchors after nms. Refer `ProposalParameter.post_nms_topn`_.
nms_thresh : float nms_thresh : float
The threshold of nms. Refer `ProposalParameter.nms_thresh`_. The threshold of nms. Refer `ProposalParameter.nms_thresh`_.
min_size : int
The min size of anchors. Refer `ProposalParameter.min_size`_.
min_level : int
Finest level of the FPN pyramid. Refer ``ProposalParameter.min_level``.
max_level : int
Coarsest level of the FPN pyramid. Refer ``ProposalParameter.max_level``.
canonical_scale : int
The baseline scale of mapping policy. Refer ``ProposalParameter.canonical_scale``.
canonical_level : int
Heuristic level of the canonical scale. Refer ``ProposalParameter.canonical_level``.
""" """
def __init__(self, LayerParameter): def __init__(self, LayerParameter):
super(ProposalLayer, self).__init__(LayerParameter) super(ProposalLayer, self).__init__(LayerParameter)
param = LayerParameter.proposal_param param = LayerParameter.proposal_param
self._param = {'base_size': param.base_size, self._param = {'strides': param.stride,
'min_size': param.min_size,
'feat_stride': param.feat_stride,
'pre_nms_topn': param.pre_nms_topn,
'post_nms_topn': param.post_nms_topn,
'nms_thresh': param.nms_thresh,
'ratios': param.ratio, 'ratios': param.ratio,
'scales': param.scale} 'scales': param.scale,
'pre_nms_top_n': param.pre_nms_top_n,
'post_nms_top_n': param.post_nms_top_n,
'nms_thresh': param.nms_thresh,
'min_size': param.min_size,
'min_level': param.min_level,
'max_level': param.max_level,
'canonical_scale': param.canonical_scale,
'canonical_level': param.canonical_level}
def Setup(self, bottom): def Setup(self, bottom):
super(ProposalLayer, self).Setup(bottom) super(ProposalLayer, self).Setup(bottom)
......
...@@ -1474,14 +1474,17 @@ message ExpandDimsParameter { ...@@ -1474,14 +1474,17 @@ message ExpandDimsParameter {
} }
message ProposalParameter { message ProposalParameter {
optional uint32 feat_stride = 1 [default = 16]; repeated int32 stride = 1;
optional uint32 base_size = 2 [default = 16]; repeated float ratio = 2;
optional uint32 min_size = 3 [default = 16]; repeated float scale = 3;
repeated float ratio = 4; optional uint32 pre_nms_top_n = 4 [default = 6000];
repeated float scale = 5; optional uint32 post_nms_top_n = 5 [default = 300];
optional uint32 pre_nms_topn = 6 [default = 6000]; optional float nms_thresh = 6 [default = 0.7];
optional uint32 post_nms_topn = 7 [default = 300]; optional uint32 min_size = 7 [default = 16];
optional float nms_thresh = 8 [default = 0.7]; optional int32 min_level = 8 [default = 2];
optional int32 max_level = 9 [default = 5];
optional int32 canonical_scale = 10 [default = 224];
optional int32 canonical_level = 11 [default = 4];
} }
message BatchRenormParameter { message BatchRenormParameter {
......
...@@ -36,7 +36,7 @@ find_packages('dragon') ...@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules() find_modules()
setup(name = 'dragon', setup(name = 'dragon',
version='0.2.1.3', version='0.2.1.4',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework', description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/neopenx/Dragon', url='https://github.com/neopenx/Dragon',
author='Ting Pan', author='Ting Pan',
......
#include "core/context.h"
#include "contrib/rcnn/bbox_utils.h"
namespace dragon {
namespace rcnn {
/******************** Proposal ********************/
// CPU/float specialization: decode the anchors of a single stride.
//
// For every feature-map cell (row-major over h, w) and every anchor a, the
// shifted anchor box is written to the output, refined in place by
// BBoxTransform, and given the score `keep * scores[a * K + cell]`, where
// `keep` is the 0/1 value returned by BBoxTransform.
// Layouts, as noted in the original comments: bbox_deltas is [1, A, 4, K],
// scores is [1, A, K], with K = feat_h * feat_w.
// Output records are 5 floats each and are ordered by (h, w, a).
template <> void GenerateProposals<float, CPUContext>(const int A,
                                                      const int feat_h,
                                                      const int feat_w,
                                                      const int stride,
                                                      const float im_h, const float im_w,
                                                      const float min_box_h, const float min_box_w,
                                                      const float* scores,
                                                      const float* bbox_deltas,
                                                      const float* anchors,
                                                      float* proposals) {
    const int K = feat_h * feat_w;
    for (int h = 0; h < feat_h; ++h) {
        const float shift_y = h * stride;
        for (int w = 0; w < feat_w; ++w) {
            const float shift_x = w * stride;
            // flat index of the current cell inside one [feat_h, feat_w] plane
            const int cell = h * feat_w + w;
            for (int a = 0; a < A; ++a) {
                float* box = proposals + (cell * A + a) * 5;
                const float* anchor = anchors + a * 4;
                const float* delta = bbox_deltas + (a * 4) * K + cell;
                box[0] = shift_x + anchor[0];
                box[1] = shift_y + anchor[1];
                box[2] = shift_x + anchor[2];
                box[3] = shift_y + anchor[3];
                const int keep = BBoxTransform<float>(delta[0], delta[K],
                                                      delta[2 * K], delta[3 * K],
                                                      im_w, im_h,
                                                      min_box_w, min_box_h,
                                                      box);
                box[4] = keep * scores[a * K + cell];
            }
        }
    }
}
// CPU/float specialization: refine boxes that are already stored in
// `proposals` (5 floats per record; the first four are the box).
//
// Record i is transformed in place by BBoxTransform using the deltas found at
// bbox_deltas[c * total_anchors + i] for c = 0..3, and its fifth slot is set
// to `keep * scores[i]`, where `keep` is BBoxTransform's 0/1 return value.
// Layouts, as noted in the original comments: bbox_deltas is
// [1, 4, total_anchors] and scores is [1, total_anchors].
template <> void GenerateProposals_v2<float, CPUContext>(const int total_anchors,
                                                         const float im_h, const float im_w,
                                                         const float min_box_h, const float min_box_w,
                                                         const float* scores,
                                                         const float* bbox_deltas,
                                                         float* proposals) {
    const float* dx = bbox_deltas;
    const float* dy = bbox_deltas + total_anchors;
    const float* d_log_w = bbox_deltas + 2 * total_anchors;
    const float* d_log_h = bbox_deltas + 3 * total_anchors;
    for (int i = 0; i < total_anchors; ++i) {
        float* box = proposals + i * 5;
        const int keep = BBoxTransform<float>(dx[i], dy[i],
                                              d_log_w[i], d_log_h[i],
                                              im_w, im_h,
                                              min_box_w, min_box_h,
                                              box);
        box[4] = keep * scores[i];
    }
}
/******************** NMS ********************/
// Intersection-over-union of two boxes given as {x1, y1, x2, y2}.
// Coordinates are treated as inclusive (every extent gets a +1).
// Returns 0 as soon as the boxes are separated along either axis.
template <typename T>
T iou(const T A[], const T B[]) {
    const bool separated =
        A[0] > B[2] || A[1] > B[3] || A[2] < B[0] || A[3] < B[1];
    if (separated) return 0;

    const T one = (T)1;
    const T zero = (T)0;

    // overlap rectangle
    const T left = std::max(A[0], B[0]);
    const T top = std::max(A[1], B[1]);
    const T right = std::min(A[2], B[2]);
    const T bottom = std::min(A[3], B[3]);
    const T inter_w = std::max(zero, right - left + one);
    const T inter_h = std::max(zero, bottom - top + one);
    const T inter = inter_w * inter_h;

    const T area_a = (A[2] - A[0] + one) * (A[3] - A[1] + one);
    const T area_b = (B[2] - B[0] + one) * (B[3] - B[1] + one);
    return inter / (area_a + area_b - inter);
}
// CPU/float specialization of greedy non-maximum suppression.
//
// Boxes are visited in the order given (5 floats per record). Each box that
// has not been suppressed is kept, and every later box whose IoU with it
// exceeds `thresh` is suppressed. The scan stops once the number of kept
// boxes equals `max_keeps`.
// Kept indices are written to `roi_indices`; their count is returned through
// `num_rois`. `mask` is not touched by this specialization.
template <> void NMS<float, CPUContext>(const int num_boxes,
                                        const int max_keeps,
                                        const float thresh,
                                        const float* proposals,
                                        int* roi_indices,
                                        int& num_rois,
                                        Tensor* mask) {
    std::vector<char> suppressed(num_boxes, 0);
    int kept = 0;
    for (int i = 0; i < num_boxes; ++i) {
        if (suppressed[i]) continue;
        roi_indices[kept++] = i;
        if (kept == max_keeps) break;
        const float* keeper = proposals + i * 5;
        for (int j = i + 1; j < num_boxes; ++j) {
            if (suppressed[j]) continue;
            if (iou(keeper, proposals + j * 5) > thresh) suppressed[j] = 1;
        }
    }
    num_rois = kept;
}
} // namespace rcnn
} // namespace dragon
\ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CONTRIB_RCNN_BBOX_UTILS_H_
#define DRAGON_CONTRIB_RCNN_BBOX_UTILS_H_
#include "core/context.h"
#include "core/operator.h"
namespace dragon {
namespace rcnn {
// Adds 0.5 to x and truncates to int, i.e. rounds to nearest for
// non-negative x. Expands to a cast to `T`, so it can only be used where a
// type named T is in scope (e.g. inside the templates below).
#define ROUND(x) ((int)((x) + (T)0.5))
/******************** BBox ********************/
// Apply regression deltas to `bbox` ({x1, y1, x2, y2}) in place.
//
// The box centre is shifted by (dx, dy) scaled by the box size, and the size
// is scaled by exp(d_log_w) / exp(d_log_h). Each resulting coordinate is then
// clipped to [0, im_w - 1] or [0, im_h - 1].
// Returns 1 when the clipped box is at least min_box_w wide and min_box_h
// tall (inclusive extents), otherwise 0.
template <typename T>
int BBoxTransform(const T dx, const T dy,
                  const T d_log_w, const T d_log_h,
                  const T im_w, const T im_h,
                  const T min_box_w, const T min_box_h,
                  T* bbox) {
    const T one = (T)1;
    const T half = (T)0.5;

    // geometry of the incoming box
    const T src_w = bbox[2] - bbox[0] + one;
    const T src_h = bbox[3] - bbox[1] + one;
    const T src_cx = bbox[0] + half * src_w;
    const T src_cy = bbox[1] + half * src_h;

    // predicted centre and size
    const T new_cx = dx * src_w + src_cx;
    const T new_cy = dy * src_h + src_cy;
    const T new_w = exp(d_log_w) * src_w;
    const T new_h = exp(d_log_h) * src_h;
    const T half_w = half * new_w;
    const T half_h = half * new_h;

    // clip a coordinate into [0, upper]
    auto clip = [](const T value, const T upper) {
        return std::max((T)0, std::min(value, upper));
    };
    const T max_x = im_w - one;
    const T max_y = im_h - one;
    bbox[0] = clip(new_cx - half_w, max_x);
    bbox[1] = clip(new_cy - half_h, max_y);
    bbox[2] = clip(new_cx + half_w, max_x);
    bbox[3] = clip(new_cy + half_h, max_y);

    const T out_w = bbox[2] - bbox[0] + one;
    const T out_h = bbox[3] - bbox[1] + one;
    return (out_w >= min_box_w && out_h >= min_box_h) ? 1 : 0;
}
/******************** Anchor ********************/
// Fill `anchors` with num_ratios * num_scales boxes ({x1, y1, x2, y2} each),
// ordered ratio-major, scale-minor.
//
// For every ratio a rounded base width/height is derived from the area
// base_size^2; each scale then stretches that base shape around the common
// centre 0.5 * (base_size - 1).
template <typename T>
void GenerateAnchors(int base_size,
                     const int num_ratios,
                     const int num_scales,
                     const T* ratios,
                     const T* scales,
                     T* anchors) {
    const T base_area = (T)(base_size * base_size);
    const T center = (T)0.5 * (base_size - (T)1);
    for (int r = 0; r < num_ratios; ++r) {
        // rounded width / height for this aspect ratio
        const T ratio_w = (T)ROUND(sqrt(base_area / ratios[r]));
        const T ratio_h = (T)ROUND(ratio_w * ratios[r]);
        for (int s = 0; s < num_scales; ++s) {
            // half extents of the scaled anchor
            const T half_w = (T)0.5 * (ratio_w * scales[s] - (T)1);
            const T half_h = (T)0.5 * (ratio_h * scales[s] - (T)1);
            T* box = anchors + (r * num_scales + s) * 4;
            box[0] = center - half_w;
            box[1] = center - half_h;
            box[2] = center + half_w;
            box[3] = center + half_h;
        }
    }
}
// Tile A base anchors over a feat_h x feat_w grid.
//
// Output records are 5 values wide and ordered by (a, h, w). Only the first
// four slots (the box shifted by w * stride, h * stride) are written; the
// fifth slot is left untouched.
template <typename T>
void GenerateGridAnchors(const int A, const int feat_h, const int feat_w,
                         const int stride,
                         const T* anchors,
                         T* proposals) {
    for (int a = 0; a < A; ++a) {
        const T* base = anchors + a * 4;
        for (int h = 0; h < feat_h; ++h) {
            const T shift_y = h * stride;
            for (int w = 0; w < feat_w; ++w) {
                const T shift_x = w * stride;
                T* box = proposals + ((a * feat_h + h) * feat_w + w) * 5;
                box[0] = shift_x + base[0];
                box[1] = shift_y + base[1];
                box[2] = shift_x + base[2];
                box[3] = shift_y + base[3];
            }
        }
    }
}
/******************** Proposal ********************/
// Declaration only; implementations are provided as explicit specializations.
// In the float/CPUContext specialization, the A anchors are shifted to every
// cell of a feat_h x feat_w map (step `stride`), refined with `bbox_deltas`
// via BBoxTransform, and written to `proposals` as 5-value records whose
// last slot is a score taken from `scores` (zeroed for boxes that fail the
// min_box_w / min_box_h test).
// NOTE(review): behaviour of other specializations is not visible here.
template <typename T, class Context>
void GenerateProposals(const int A, const int feat_h, const int feat_w,
const int stride,
const float im_h, const float im_w,
const float min_box_h, const float min_box_w,
const T* scores,
const T* bbox_deltas,
const T* anchors,
T* proposals);
// Declaration only; implementations are provided as explicit specializations.
// In the float/CPUContext specialization, `proposals` already holds
// `total_anchors` boxes (5 values per record); each one is refined in place
// with `bbox_deltas` via BBoxTransform and its last slot is set to a score
// taken from `scores` (zeroed for boxes that fail the min-size test).
// NOTE(review): behaviour of other specializations is not visible here.
template <typename T, class Context>
void GenerateProposals_v2(const int total_anchors,
const float im_h, const float im_w,
const float min_box_h, const float min_box_w,
const T* scores,
const T* bbox_deltas,
T* proposals);
// Quicksort-style partial ordering of proposals by descending score.
//
// `proposals` holds 5-value records whose last slot is the score. The
// records in the inclusive range [start, end] are partitioned around the
// score of record `start`; the higher-scored side is always recursed into,
// while the lower-scored side is only visited if it starts before index
// `num_top`. Consequently only the leading `num_top` records are guaranteed
// to end up in descending order.
//
// An empty or single-record range returns immediately, so a call such as
// SortProposals(0, -1, ...) for zero proposals never reads the buffer.
template <typename T>
void SortProposals(const int start,
                   const int end,
                   const int num_top,
                   T* proposals) {
    // Guard first: the pivot read below would be out of bounds for an
    // empty range (end < start), and a single record needs no work.
    if (start >= end) return;

    const T pivot_score = proposals[start * 5 + 4];
    int left = start + 1, right = end;
    while (left <= right) {
        // skip records that are already on the correct side of the pivot
        while (left <= end && proposals[left * 5 + 4] >= pivot_score) ++left;
        while (right > start && proposals[right * 5 + 4] <= pivot_score) --right;
        if (left <= right) {
            for (int i = 0; i < 5; ++i)
                std::swap(proposals[left * 5 + i], proposals[right * 5 + i]);
            ++left;
            --right;
        }
    }
    // move the pivot record to its final position
    if (right > start) {
        for (int i = 0; i < 5; ++i)
            std::swap(proposals[start * 5 + i], proposals[right * 5 + i]);
    }
    if (start < right - 1) SortProposals(start, right - 1, num_top, proposals);
    // the tail only matters if it overlaps the first num_top positions
    if (right + 1 < num_top && right + 1 < end) SortProposals(right + 1, end, num_top, proposals);
}
// Gather the selected proposals into RoI records.
//
// For each of the `num_rois` entries of `roi_indices`, the box (first four
// values) of the referenced 5-value proposal record is copied into `rois`
// as {roi_batch_ind, x1, y1, x2, y2}.
template <typename T>
void RetrieveRoIs(const int num_rois,
                  const int roi_batch_ind,
                  const T* proposals,
                  const int* roi_indices,
                  T* rois) {
    for (int n = 0; n < num_rois; ++n) {
        const T* src = proposals + roi_indices[n] * 5;
        T* dst = rois + n * 5;
        dst[0] = roi_batch_ind;
        for (int c = 0; c < 4; ++c) dst[c + 1] = src[c];
    }
}
// Map an RoI to a pyramid level following Eq. (1) of the FPN paper:
//     k = floor(k0 + log2(sqrt(w * h) / s0))
// with k0 = canonical_level and s0 = canonical_scale, clamped to
// [min_level, max_level].
//
// `roi` is a 5-value record {batch_ind, x1, y1, x2, y2}; extents are
// inclusive (+1). sqrt(w * h) is floored at 1 before taking the logarithm.
// The logarithm is base 2: each doubling of the RoI side moves it exactly one
// level up, which a natural logarithm would not do.
template <typename T>
int roi_level(const int min_level,        // e.g. 2
              const int max_level,        // e.g. 5
              const int canonical_level,  // e.g. 4
              const int canonical_scale,  // e.g. 224
              T* roi) {
    T w = roi[3] - roi[1] + 1;
    T h = roi[4] - roi[2] + 1;
    // reference the settings of paper
    const T side = std::max(std::sqrt(w * h), T(1));
    const int level = (int)std::floor(
        canonical_level + std::log2(side / T(canonical_scale)));
    return std::min(max_level, std::max(min_level, level));
}
// Bucket RoIs by pyramid level.
//
// `rois` holds `num_rois` 5-value records. For each one the level is computed
// with roi_level(), converted to a zero-based bin (level - min_level, never
// below 0), and the RoI's index is appended to that bin of `roi_bins`.
// `roi_bins` must already contain a bin for every reachable level.
template <typename T>
void CollectRoIs(const int num_rois,
                 const int min_level,
                 const int max_level,
                 const int canonical_level,
                 const int canonical_scale,
                 const T* rois,
                 vector< vector<TIndex> >& roi_bins) {
    for (int n = 0; n < num_rois; ++n) {
        const T* current = rois + n * 5;
        const int level = roi_level(min_level, max_level,
                                    canonical_level, canonical_scale,
                                    current);
        const int bin = std::max(level - min_level, 0);
        roi_bins[bin].push_back(n);
    }
}
template <typename T>
void DistributeRoIs(const vector< vector<TIndex> >& roi_bins,
const T* rois,
vector<T*> outputs) {
for (int i = 0; i < roi_bins.size(); i++) {
auto* y = outputs[i];
if (roi_bins[i].size() == 0) {
// fake a tiny roi to avoid empty roi pooling
y[0] = 0, y[1] = 0, y[2] = 0, y[3] = 1, y[4] = 1;
} else {
for (int j = 0; j < roi_bins[i].size(); ++j) {
const T* roi = rois + roi_bins[i][j] * 5;
for (int k = 0; k < 5; ++k) y[k] = roi[k];
y += 5;
}
}
}
}
/******************** NMS ********************/
// Declaration only; implementations are provided as explicit specializations.
// In the float/CPUContext specialization this is greedy non-maximum
// suppression over the first `num_boxes` 5-value records of `proposals`:
// indices of kept boxes go to `roi_indices`, their count to `num_rois`,
// and the scan stops once `max_keeps` boxes have been kept.
// NOTE(review): `mask` is unused by that specialization; presumably scratch
// space for other backends - confirm against their implementations.
template <typename T, class Context>
void NMS(const int num_boxes,
const int max_keeps,
const T thresh,
const T* proposals,
int* roi_indices,
int& num_rois,
Tensor* mask);
} // namespace rcnn
} // namespace dragon
#endif // DRAGON_CONTRIB_RCNN_BBOX_UTILS_H_
\ No newline at end of file
#include "contrib/rcnn/proposal_op.h"
#include "contrib/rcnn/bbox_utils.h"
namespace dragon {
template <class Context> template <typename T>
void ProposalOp<Context>::RunWithType() {
TIndex total_rois = 0;
auto* im_info = input(-1).template data<T, CPUContext>();
auto* Ydata = output(0)->template mutable_data<T, CPUContext>();
for (int n = 0; n < num_images; ++n) {
const T im_height = im_info[0];
const T im_width = im_info[1];
const T scale = im_info[2];
const T min_box_h = min_size * scale;
const T min_box_w = min_size * scale;
int num_rois = 0;
if (strides.size() == 1) {
// case 1: single stride (Faster R-CNN)
const TIndex feat_height = input(0).dim(2);
const TIndex feat_width = input(0).dim(3);
const TIndex K = feat_height * feat_width;
const TIndex A = ratios.size() * scales.size();
const TIndex num_proposals = K * A;
const TIndex pre_nms_topn = std::min(num_proposals, pre_nms_top_n);
anchors_.Reshape(vector<TIndex>({ A, 4 }));
proposals_.Reshape(vector<TIndex>({ num_proposals, 5 }));
rcnn::GenerateAnchors<T>(strides[0], (int)ratios.size(), (int)scales.size(),
&ratios[0], &scales[0],
anchors_.template mutable_data<T, CPUContext>());
rcnn::GenerateProposals<T, Context>(A, feat_height, feat_width, strides[0],
im_height, im_width, min_box_h, min_box_w,
input(0).template data<T, Context>() + num_proposals,
input(1).template data<T, Context>(),
anchors_.template mutable_data<T, Context>(),
proposals_.template mutable_data<T, Context>());
rcnn::SortProposals(0, num_proposals - 1, pre_nms_top_n,
proposals_.template mutable_data<T, CPUContext>());
rcnn::NMS<T, Context>(pre_nms_topn, post_nms_top_n, nms_thresh,
proposals_.template mutable_data<T, Context>(),
roi_indices_.template mutable_data<int, CPUContext>(),
num_rois,
&nms_mask_);
rcnn::RetrieveRoIs<T>(num_rois, n, proposals_.template mutable_data<T, CPUContext>(),
roi_indices_.template mutable_data<int, CPUContext>(),
Ydata);
} else if (strides.size() > 1) {
// case 2: multiple stride (FPN / Mask R-CNN / RetinaNet)
CHECK_EQ(strides.size(), (int)InputSize() - 3)
<< "\nGiven " << strides.size() << " strides and "
<< InputSize() - 3 << " feature inputs";
CHECK_EQ(strides.size(), scales.size())
<< "\nGiven " << strides.size() << " strides and "
<< scales.size() << " scales";
// cls_probs: [1, 2, total_proposals]
// bbox_deltas: [1, 4, total_proposals]
TIndex total_proposals = input(-3).dim(2), acc_proposals = 0;
const TIndex pre_nms_topn = std::min(total_proposals, pre_nms_top_n);;
proposals_.Reshape(vector<TIndex>({ total_proposals, 5 }));
auto* proposals = proposals_.template mutable_data<T, CPUContext>();
for (int i = 0; i < strides.size(); i++) {
const TIndex feat_height = input(i).dim(2);
const TIndex feat_width = input(i).dim(3);
const TIndex K = feat_height * feat_width;
const TIndex A = ratios.size();
const TIndex num_proposals = K * A;
anchors_.Reshape(vector<TIndex>({ A, 4 }));
rcnn::GenerateAnchors<T>(strides[i], (int)ratios.size(), 1,
&ratios[0], &scales[0],
anchors_.template mutable_data<T, CPUContext>());
rcnn::GenerateGridAnchors<T>(A, feat_height, feat_width, strides[i],
anchors_.template mutable_data<T, CPUContext>(),
proposals);
acc_proposals += num_proposals;
proposals += (num_proposals * 5);
}
CHECK_EQ(acc_proposals, total_proposals)
<< "\nExcepted " << total_proposals << " proposals from the network, "
<< "but generated " << acc_proposals << " proposals.";
rcnn::GenerateProposals_v2<T, Context>(total_proposals, im_height, im_width,
min_box_h, min_box_w,
input(-3).template data<T, Context>() + total_proposals,
input(-2).template data<T, Context>(),
proposals_.template mutable_data<T, Context>());
rcnn::SortProposals(0, total_proposals - 1, pre_nms_top_n,
proposals_.template mutable_data<T, CPUContext>());
rcnn::NMS<T, Context>(pre_nms_topn, post_nms_top_n, nms_thresh,
proposals_.template mutable_data<T, Context>(),
roi_indices_.template mutable_data<int, CPUContext>(),
num_rois,
&nms_mask_);
rcnn::RetrieveRoIs<T>(num_rois, n, proposals_.template mutable_data<T, CPUContext>(),
roi_indices_.template mutable_data<int, CPUContext>(),
Ydata);
} else {
LOG(FATAL) << "There should be given at least one stride for proposals.";
}
total_rois += num_rois;
Ydata += (num_rois * 5);
im_info += 3;
}
output(0)->Reshape(vector<TIndex>({ total_rois, 5 }));
// distribute rois into K bins
if (OutputSize() > 1) {
CHECK_EQ(max_level - min_level + 1, (int)OutputSize())
<< "Excepted " << OutputSize() << " outputs for levels between "
<< "[" << min_level << ", " << max_level << "].";
vector< vector<TIndex> > roi_bins(OutputSize(), vector<TIndex>());
vector<T*> outputs;
Tensor collective_rois;
collective_rois.ReshapeLike(*output(0));
auto* rois = collective_rois.template mutable_data<T, CPUContext>();
CPUContext::template Copy<T, CPUContext, CPUContext>(collective_rois.count(),
rois,
output(0)->template data<T, CPUContext>());
rcnn::CollectRoIs<T>(total_rois, min_level, max_level,
canonical_level, canonical_scale,
rois,
roi_bins);
for (int i = 0; i < OutputSize(); i++) {
output(i)->Reshape(vector<TIndex>({ std::max((int)roi_bins[i].size(), 1), 5 }));
outputs.push_back(output(i)->template mutable_data<T, CPUContext>());
}
rcnn::DistributeRoIs(roi_bins, rois, outputs);
}
}
// Entry point of the operator.
//
// Records the batch size, validates the image-info input, sizes the RoI
// buffers for the worst case and then dispatches on the input data type.
// Only float inputs are supported; any other type aborts via LOG(FATAL).
template <class Context>
void ProposalOp<Context>::RunOnDevice() {
    // The leading dimension of the first input is used as the batch size.
    num_images = input(0).dim(0);

    // The last input must hold exactly 3 values per image. Both numbers in
    // the message are reported in the same unit (values) so that they can be
    // compared directly when the check fails.
    CHECK_EQ(input(-1).count(), num_images * 3)
        << "Expected " << num_images * 3
        << " image info values (3 per image), "
        << "but got " << input(-1).count() << ".";

    // One kept index per RoI that can survive NMS for a single image.
    roi_indices_.Reshape(vector<TIndex>(1, post_nms_top_n));
    // Upper bound on the output: every image yields at most post_nms_top_n
    // RoIs, each stored as 5 values.
    output(0)->Reshape(vector<TIndex>({ num_images * post_nms_top_n, 5 }));

    // The CPU and CUDA contexts share the same dispatch; any other context
    // falls through without running, exactly as before.
    const bool known_context =
        TypeMeta::Id<Context>() == TypeMeta::Id<CPUContext>() ||
        TypeMeta::Id<Context>() == TypeMeta::Id<CUDAContext>();
    if (!known_context) return;

    if (input(0).template IsType<float>()) RunWithType<float>();
    else LOG(FATAL) << "Unsupported input types.";
}
// Deploy the Proposal operator for the CPU context.
// NOTE(review): presumably this registers ProposalOp<CPUContext> under the
// name "Proposal" - confirm against the DEPLOY_CPU macro definition.
DEPLOY_CPU(Proposal);
#ifdef WITH_CUDA
// The CUDA deployment is compiled only when the build defines WITH_CUDA.
DEPLOY_CUDA(Proposal);
#endif
// Schema bounds: (3, INT_MAX) inputs and (1, INT_MAX) outputs.
// NOTE(review): presumably these are (min, max) counts, i.e. at least three
// inputs and at least one output with no practical upper limit - confirm
// against the OPERATOR_SCHEMA definition.
OPERATOR_SCHEMA(Proposal).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
} // namespace dragon
\ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_CONTRIB_RCNN_PROPOSAL_OP_H_
#define DRAGON_CONTRIB_RCNN_PROPOSAL_OP_H_
#include "core/operator.h"
namespace dragon {
// Operator that turns network predictions into regions of interest (RoIs).
//
// All hyper-parameters are read from the operator definition in the
// constructor; the per-run work is done in RunOnDevice() / RunWithType().
template <class Context>
class ProposalOp final : public Operator<Context> {
 public:
    // Reads every argument from `op_def`, falling back to the defaults
    // spelled out below. The mem-initializers are listed in member
    // declaration order, which is the order C++ actually initializes them in.
    ProposalOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          strides(OperatorBase::GetRepeatedArg<int>("strides")),
          ratios(OperatorBase::GetRepeatedArg<float>("ratios")),
          scales(OperatorBase::GetRepeatedArg<float>("scales")),
          pre_nms_top_n(OperatorBase::GetSingleArg<int>("pre_nms_top_n", 6000)),
          post_nms_top_n(OperatorBase::GetSingleArg<int>("post_nms_top_n", 300)),
          min_size(OperatorBase::GetSingleArg<int>("min_size", 16)),
          num_images(0),  // overwritten on every RunOnDevice() call
          min_level(OperatorBase::GetSingleArg<int>("min_level", 2)),
          max_level(OperatorBase::GetSingleArg<int>("max_level", 5)),
          canonical_level(OperatorBase::GetSingleArg<int>("canonical_level", 4)),
          canonical_scale(OperatorBase::GetSingleArg<int>("canonical_scale", 224)),
          nms_thresh(OperatorBase::GetSingleArg<float>("nms_thresh", 0.7f)) {}

    // Validates inputs, sizes the buffers and dispatches on the data type.
    void RunOnDevice() override;

    // Type-specific implementation; T is the element type of the inputs.
    template <typename T> void RunWithType();

 protected:
    // Feature strides; a single entry and several entries are handled by
    // different code paths in the implementation.
    vector<int> strides;
    // Anchor aspect ratios and scales handed to the anchor generator.
    vector<float> ratios, scales;
    // pre_nms_top_n : cap on the number of proposals passed to NMS.
    // post_nms_top_n: number of RoI index slots reserved per image.
    // min_size      : NOTE(review): presumably the minimum box size used to
    //                 filter proposals - confirm in RunWithType().
    // num_images    : batch size, taken from input(0).dim(0) at run time.
    TIndex pre_nms_top_n, post_nms_top_n, min_size, num_images;
    // min_level / max_level bound the levels RoIs are distributed over when
    // the operator has more than one output; canonical_level and
    // canonical_scale are forwarded to the RoI collection routine.
    TIndex min_level, max_level, canonical_level, canonical_scale;
    // Threshold forwarded to the NMS routine.
    float nms_thresh;
    // Scratch tensors reused across runs.
    Tensor anchors_, proposals_, roi_indices_, nms_mask_;
};
} // namespace dragon
#endif // DRAGON_CONTRIB_RCNN_PROPOSAL_OP_H_
\ No newline at end of file
...@@ -244,23 +244,46 @@ void CuDNNBatchNormGradientOp<Context>::TrainingRunWithType() { ...@@ -244,23 +244,46 @@ void CuDNNBatchNormGradientOp<Context>::TrainingRunWithType() {
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNBatchNormGradientOp<Context>::InferenceRunWithType() { void CuDNNBatchNormGradientOp<Context>::InferenceRunWithType() {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
auto* dYdata = input(-1).template data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* hVar_data = input(2).template data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
// gradient w.r.t. scale
if (output(1)->name() != "ignore")
LOG(FATAL) << "The gamma should be fixed if using global stats.";
// gradient w.r.t. bias
if (output(2)->name() != "ignore") {
auto* dBdata = output(2)->template mutable_data<T, Context>();
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NC, S,
1.0, dYdata, SMul_data,
0.0, NC_data);
math::Gemv<T, Context>(CblasTrans, N, C,
1.0, NC_data, NMul_data,
1.0, dBdata);
} else if (data_format == "NHWC") {
math::Gemv<T, Context>(CblasTrans, NS, C,
1.0, dYdata, NSMul_data,
1.0, dBdata);
}
}
// gradient w.r.t. x
if (output(0)->name() != "ignore") { if (output(0)->name() != "ignore") {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
stddev = ws()->GetBuffer(); stddev = ws()->GetBuffer();
stddev->ReshapeLike(input(0)); stddev->ReshapeLike(input(0));
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>(); auto* Std_data = stddev->template mutable_data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* hVar_data = input(2).template data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
// compute stddev // compute stddev
ctx().template Copy<T, Context, Context>(var->count(), tVar_data, hVar_data); ctx().template Copy<T, Context, Context>(var->count(), tVar_data, hVar_data);
......
...@@ -429,21 +429,43 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() { ...@@ -429,21 +429,43 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() {
template <class Context> template <typename T> template <class Context> template <typename T>
void FusedBatchNormGradientOp<Context>::InferenceRunWithType() { void FusedBatchNormGradientOp<Context>::InferenceRunWithType() {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
auto* dYdata = input(-1).template data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
// gradient w.r.t. scale
if (output(1)->name() != "ignore")
LOG(FATAL) << "The gamma should be fixed if using global stats.";
// gradient w.r.t. bias
if (output(2)->name() != "ignore") {
auto* dBdata = output(2)->template mutable_data<T, Context>();
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NC, S,
1.0, dYdata, SMul_data,
0.0, NC_data);
math::Gemv<T, Context>(CblasTrans, N, C,
1.0, NC_data, NMul_data,
1.0, dBdata);
} else if (data_format == "NHWC") {
math::Gemv<T, Context>(CblasTrans, NS, C,
1.0, dYdata, NSMul_data,
1.0, dBdata);
}
}
// gradient w.r.t. x
if (output(0)->name() != "ignore") { if (output(0)->name() != "ignore") {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>(); auto* Std_data = stddev->template mutable_data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* hVar_data = input(2).template data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
// divide scale by stddev // divide scale by stddev
math::Div<T, Context>(var->count(), Sdata, tVar_data, tVar_data); math::Div<T, Context>(var->count(), Sdata, tVar_data, tVar_data);
...@@ -492,7 +514,9 @@ void FusedBatchNormGradientOp<Context>::Setup() { ...@@ -492,7 +514,9 @@ void FusedBatchNormGradientOp<Context>::Setup() {
// reshape // reshape
num_by_chans.Reshape(vector<TIndex>(1, NC)); num_by_chans.Reshape(vector<TIndex>(1, NC));
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0)); // dX
output(1)->ReshapeLike(input(3)); // dScale
output(2)->ReshapeLike(input(3)); // dBias
} }
template <class Context> template <class Context>
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!