Commit 2d1b7752 by Ting PAN

Disable sharing gradients on shape ops

1 parent 7bc8fb22
...@@ -34,9 +34,10 @@ class ConcatGradientOp : public Operator<Context> { ...@@ -34,9 +34,10 @@ class ConcatGradientOp : public Operator<Context> {
ConcatGradientOp(const OperatorDef& op_def, Workspace* ws) ConcatGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 1)), axis(OperatorBase::GetSingleArg<int>("axis", 1)),
nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {} nin(OperatorBase::GetSingleArg<int>("num_input", 1)) {
DISABLE_SHARE_GRADIENT;
}
void ShareGradient() override;
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -46,7 +46,9 @@ class CropGradientOp final : public Operator<Context > { ...@@ -46,7 +46,9 @@ class CropGradientOp final : public Operator<Context > {
start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)), start_axis(OperatorBase::GetSingleArg<int>("start_axis", -1)),
offsets(OperatorBase::GetRepeatedArg<int>("offsets")), offsets(OperatorBase::GetRepeatedArg<int>("offsets")),
shape(OperatorBase::GetRepeatedArg<int>("shape")), shape(OperatorBase::GetRepeatedArg<int>("shape")),
shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {} shape_like(OperatorBase::GetSingleArg<string>("shape_like", "")) {
DISABLE_SHARE_GRADIENT;
}
void Setup(); void Setup();
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -28,7 +28,9 @@ template <class Context> ...@@ -28,7 +28,9 @@ template <class Context>
class ExpandDimsGradientOp final : public Operator<Context> { class ExpandDimsGradientOp final : public Operator<Context> {
public: public:
ExpandDimsGradientOp(const OperatorDef& op_def, Workspace* ws) ExpandDimsGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -32,7 +32,9 @@ template <class Context> ...@@ -32,7 +32,9 @@ template <class Context>
class FlattenGradientOp final : public Operator<Context> { class FlattenGradientOp final : public Operator<Context> {
public: public:
FlattenGradientOp(const OperatorDef& op_def, Workspace* ws) FlattenGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -63,6 +63,7 @@ class PadGradientOp final : public Operator<Context> { ...@@ -63,6 +63,7 @@ class PadGradientOp final : public Operator<Context> {
} }
std::sort(process_axes.begin(), process_axes.end()); std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end()); std::reverse(process_axes.begin(), process_axes.end());
DISABLE_SHARE_GRADIENT;
} }
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -34,7 +34,9 @@ class RandomPickGradientOp final : public Operator<Context> { ...@@ -34,7 +34,9 @@ class RandomPickGradientOp final : public Operator<Context> {
public: public:
RandomPickGradientOp(const OperatorDef& op_def, Workspace* ws) RandomPickGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {} axis(OperatorBase::GetSingleArg<int>("axis", 0)) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
......
...@@ -31,7 +31,9 @@ template <class Context> ...@@ -31,7 +31,9 @@ template <class Context>
class ReshapeGradientOp final : public Operator<Context> { class ReshapeGradientOp final : public Operator<Context> {
public: public:
ReshapeGradientOp(const OperatorDef& op_def, Workspace* ws) ReshapeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws) {
DISABLE_SHARE_GRADIENT;
}
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -44,6 +44,7 @@ class TileGradientOp : public Operator<Context> { ...@@ -44,6 +44,7 @@ class TileGradientOp : public Operator<Context> {
process_axes.push_back({ multiples[i], i }); process_axes.push_back({ multiples[i], i });
std::sort(process_axes.begin(), process_axes.end()); std::sort(process_axes.begin(), process_axes.end());
std::reverse(process_axes.begin(), process_axes.end()); std::reverse(process_axes.begin(), process_axes.end());
DISABLE_SHARE_GRADIENT;
} }
void RunOnDevice() override; void RunOnDevice() override;
......
...@@ -29,7 +29,7 @@ class ROIAlignOp : public Operator<Context> { ...@@ -29,7 +29,7 @@ class ROIAlignOp : public Operator<Context> {
protected: protected:
int pool_h, pool_w; int pool_h, pool_w;
float spatial_scale; float spatial_scale;
Tensor* mask; Tensor* mask_h, *mask_w;
}; };
template <class Context> template <class Context>
...@@ -50,7 +50,7 @@ class ROIAlignGradientOp : public Operator<Context> { ...@@ -50,7 +50,7 @@ class ROIAlignGradientOp : public Operator<Context> {
protected: protected:
int pool_h, pool_w; int pool_h, pool_w;
float spatial_scale; float spatial_scale;
Tensor* mask; Tensor* mask_h, *mask_w;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -813,7 +813,8 @@ void ROIAlign(const float spatial_scale, ...@@ -813,7 +813,8 @@ void ROIAlign(const float spatial_scale,
const int pool_w, const int pool_w,
Tensor* x, Tensor* x,
Tensor* roi, Tensor* roi,
Tensor* mask, Tensor* mask_h,
Tensor* mask_w,
Tensor* y); Tensor* y);
template <typename T, class Context> template <typename T, class Context>
...@@ -822,7 +823,8 @@ void ROIAlignGrad(const float spatial_scale, ...@@ -822,7 +823,8 @@ void ROIAlignGrad(const float spatial_scale,
const int pool_w, const int pool_w,
Tensor* dy, Tensor* dy,
Tensor* roi, Tensor* roi,
Tensor* mask, Tensor* mask_h,
Tensor* mask_w,
Tensor* dx); Tensor* dx);
} // namespace kernel } // namespace kernel
......
...@@ -21,6 +21,7 @@ from .neuron import ReLULayer, \ ...@@ -21,6 +21,7 @@ from .neuron import ReLULayer, \
ELULayer, \ ELULayer, \
SELULayer, \ SELULayer, \
DropoutLayer, \ DropoutLayer, \
SigmoidLayer, \
TanHLayer, \ TanHLayer, \
PowerLayer PowerLayer
...@@ -53,6 +54,7 @@ from .common import InnerProductLayer, \ ...@@ -53,6 +54,7 @@ from .common import InnerProductLayer, \
NormalizeLayer, \ NormalizeLayer, \
InstanceNormLayer, \ InstanceNormLayer, \
TileLayer, \ TileLayer, \
ReductionLayer, \
ExpandDimsLayer, \ ExpandDimsLayer, \
ProposalLayer, \ ProposalLayer, \
DenseConcatLayer DenseConcatLayer
\ No newline at end of file
...@@ -553,6 +553,32 @@ class TileLayer(Layer): ...@@ -553,6 +553,32 @@ class TileLayer(Layer):
return ops.Tile(input, **self._param) return ops.Tile(input, **self._param)
class ReductionLayer(Layer):
"""The extended implementation of ``ReductionLayer``.
Parameters
----------
operation : caffe_pb2.ReductionOp
The operation. Refer `ReductionParameter.operation`_.
axis : int
The axis to to reduce. Refer `ReductionParameter.axis`_.
"""
def __init__(self, LayerParameter):
super(ReductionLayer, self).__init__(LayerParameter)
param = LayerParameter.reduction_param
if param.axis < 0:
if param.axis != -1:
raise ValueError('The negative axis can only be -1(reduce all).')
self._param = {'operation': {1: 'SUM', 4: 'MEAN'}[param.operation],
'axis': param.axis}
def Setup(self, bottom):
super(ReductionLayer, self).Setup(bottom)
input = bottom[0] if isinstance(bottom, list) else bottom
return ops.Reduce(input, **self._param)
class ExpandDimsLayer(Layer): class ExpandDimsLayer(Layer):
"""The implementation of ``ExpandDimsLayer``. """The implementation of ``ExpandDimsLayer``.
......
...@@ -27,7 +27,7 @@ class SoftmaxWithLossLayer(Layer): ...@@ -27,7 +27,7 @@ class SoftmaxWithLossLayer(Layer):
super(SoftmaxWithLossLayer, self).__init__(LayerParameter) super(SoftmaxWithLossLayer, self).__init__(LayerParameter)
param = LayerParameter.loss_param param = LayerParameter.loss_param
softmax_param = LayerParameter.softmax_param softmax_param = LayerParameter.softmax_param
norm_mode = {0: 'FULL', 1: 'VALID', 2: 'BATCH_SIZE', 3: 'NONE'} norm_mode = {0: 'FULL', 1: 'VALID', 2: 'BATCH_SIZE', 3: 'NONE', 4: 'UNIT'}
normalization = 'VALID' normalization = 'VALID'
if param.HasField('normalize'): if param.HasField('normalize'):
if not param.normalize: normalization = 'BATCH_SIZE' if not param.normalize: normalization = 'BATCH_SIZE'
...@@ -57,7 +57,7 @@ class SigmoidCrossEntropyLossLayer(Layer): ...@@ -57,7 +57,7 @@ class SigmoidCrossEntropyLossLayer(Layer):
def __init__(self, LayerParameter): def __init__(self, LayerParameter):
super(SigmoidCrossEntropyLossLayer, self).__init__(LayerParameter) super(SigmoidCrossEntropyLossLayer, self).__init__(LayerParameter)
param = LayerParameter.loss_param param = LayerParameter.loss_param
norm_mode = {0: 'FULL', 1: 'BATCH_SIZE', 2: 'BATCH_SIZE', 3: 'NONE'} norm_mode = {0: 'FULL', 1: 'BATCH_SIZE', 2: 'BATCH_SIZE', 3: 'NONE', 4: 'UNIT'}
normalization = 'BATCH_SIZE' normalization = 'BATCH_SIZE'
if param.HasField('normalize'): if param.HasField('normalize'):
if param.normalize: normalization = 'FULL' if param.normalize: normalization = 'FULL'
...@@ -157,7 +157,7 @@ class SoftmaxWithFocalLossLayer(Layer): ...@@ -157,7 +157,7 @@ class SoftmaxWithFocalLossLayer(Layer):
param = LayerParameter.loss_param param = LayerParameter.loss_param
softmax_param = LayerParameter.softmax_param softmax_param = LayerParameter.softmax_param
focal_loss_param = LayerParameter.focal_loss_param focal_loss_param = LayerParameter.focal_loss_param
norm_mode = {0: 'FULL', 1: 'VALID', 2: 'BATCH_SIZE', 3: 'NONE'} norm_mode = {0: 'FULL', 1: 'VALID', 2: 'BATCH_SIZE', 3: 'NONE', 4: 'UNIT'}
normalization = 'VALID' normalization = 'VALID'
if param.HasField('normalize'): if param.HasField('normalize'):
if not param.normalize: normalization = 'BATCH_SIZE' if not param.normalize: normalization = 'BATCH_SIZE'
......
...@@ -217,6 +217,11 @@ class Net(object): ...@@ -217,6 +217,11 @@ class Net(object):
for idx, loss_weight in enumerate(LayerParameter.loss_weight): for idx, loss_weight in enumerate(LayerParameter.loss_weight):
if loss_weight <= 0: continue if loss_weight <= 0: continue
self._costs.append(self.blobs[LayerParameter.top[idx]].data) self._costs.append(self.blobs[LayerParameter.top[idx]].data)
else:
if len(LayerParameter.loss_weight) != 0:
for idx, loss_weight in enumerate(LayerParameter.loss_weight):
if loss_weight <= 0: continue
self._costs.append(self.blobs[LayerParameter.top[idx]].data)
if self._phase != 'TRAIN': return if self._phase != 'TRAIN': return
......
...@@ -473,6 +473,8 @@ message LossParameter { ...@@ -473,6 +473,8 @@ message LossParameter {
BATCH_SIZE = 2; BATCH_SIZE = 2;
// Do not normalize the loss. // Do not normalize the loss.
NONE = 3; NONE = 3;
// Do not reduce the loss.
UNIT = 4;
} }
optional NormalizationMode normalization = 3 [default = VALID]; optional NormalizationMode normalization = 3 [default = VALID];
// Deprecated. Ignored if normalization is specified. If normalization // Deprecated. Ignored if normalization is specified. If normalization
......
...@@ -296,7 +296,10 @@ class Solver(object): ...@@ -296,7 +296,10 @@ class Solver(object):
for i in xrange(self._param.iter_size): for i in xrange(self._param.iter_size):
self.train(return_outputs=False) self.train(return_outputs=False)
if root_solver(): if root_solver():
for cost in self._net._costs: loss += ws.FetchTensor(cost)[0] for cost in self._net._costs:
cost_value = ws.FetchTensor(cost)
if cost_value.size == 1:
loss += cost_value[0]
if root_solver(): if root_solver():
loss /= self._param.iter_size loss /= self._param.iter_size
......
...@@ -279,6 +279,7 @@ def function(inputs=None, outputs=None, givens=None, updater=None): ...@@ -279,6 +279,7 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
external_input_exprs = OrderedDict(external_input_exprs, **new_tensor.expressions) external_input_exprs = OrderedDict(external_input_exprs, **new_tensor.expressions)
else: else:
external_input_exprs = dict(external_input_exprs, **new_tensor.expressions) external_input_exprs = dict(external_input_exprs, **new_tensor.expressions)
external_input_exprs = OrderedDict(sorted(external_input_exprs.items(), lambda x, y: cmp(x[1], y[1])))
elif isinstance(new_tensor, np.ndarray): elif isinstance(new_tensor, np.ndarray):
ws.FeedTensor(new_tensor, GetTensorName()) ws.FeedTensor(new_tensor, GetTensorName())
external_input_ops = [v for k, v in external_input_exprs.items()] external_input_ops = [v for k, v in external_input_exprs.items()]
......
...@@ -104,17 +104,6 @@ void ConcatGradientOp<Context>::RunOnDevice() { ...@@ -104,17 +104,6 @@ void ConcatGradientOp<Context>::RunOnDevice() {
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
template <class Context>
void ConcatGradientOp<Context>::ShareGradient() {
for (int i = 0; i < OutputSize(); i++) {
if (output(i)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
ws()->CreateAvatar(output(i), dX);
break;
}
}
}
DEPLOY_CPU(ConcatGradient); DEPLOY_CPU(ConcatGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(ConcatGradient); DEPLOY_CUDA(ConcatGradient);
......
...@@ -11,17 +11,20 @@ void ROIAlignOp<Context>::RunWithType() { ...@@ -11,17 +11,20 @@ void ROIAlignOp<Context>::RunWithType() {
pool_h, pool_w, pool_h, pool_w,
&input(0), &input(0),
&input(1), &input(1),
mask, mask_h,
mask_w,
output(0)); output(0));
} }
template <class Context> template <class Context>
void ROIAlignOp<Context>::RunOnDevice() { void ROIAlignOp<Context>::RunOnDevice() {
mask = ws()->CreateTensor("/mnt/" + anchor() + "/roi_align_mask"); mask_h = ws()->CreateTensor("/mnt/" + anchor() + "/roi_align_mask_h");
mask_w = ws()->CreateTensor("/mnt/" + anchor() + "/roi_align_mask_w");
vector<TIndex> dims({input(1).dim(0), input(0).dim(1), pool_h, pool_w}); vector<TIndex> dims({input(1).dim(0), input(0).dim(1), pool_h, pool_w});
output(0)->Reshape(dims); output(0)->Reshape(dims);
mask->Reshape(dims); mask_h->Reshape(dims);
mask_w->Reshape(dims);
if (input(0).template IsType<float>()) return RunWithType<float>(); if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
...@@ -39,13 +42,15 @@ void ROIAlignGradientOp<Context>::RunWithType() { ...@@ -39,13 +42,15 @@ void ROIAlignGradientOp<Context>::RunWithType() {
pool_h, pool_w, pool_h, pool_w,
&input(-1), &input(-1),
&input(1), &input(1),
mask, mask_h,
mask_w,
output(0)); output(0));
} }
template <class Context> template <class Context>
void ROIAlignGradientOp<Context>::RunOnDevice() { void ROIAlignGradientOp<Context>::RunOnDevice() {
mask = ws()->GetTensor("/mnt/" + anchor() + "/roi_align_mask"); mask_h = ws()->GetTensor("/mnt/" + anchor() + "/roi_align_mask_h");
mask_w = ws()->GetTensor("/mnt/" + anchor() + "/roi_align_mask_w");
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
......
...@@ -2640,7 +2640,8 @@ template<> void ROIAlign<float, CPUContext>(const float spatial_scale, ...@@ -2640,7 +2640,8 @@ template<> void ROIAlign<float, CPUContext>(const float spatial_scale,
const int pool_h, const int pool_w, const int pool_h, const int pool_w,
Tensor* x, Tensor* x,
Tensor* roi, Tensor* roi,
Tensor* mask, Tensor* mask_h,
Tensor* mask_w,
Tensor* y) { Tensor* y) {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
...@@ -2649,7 +2650,8 @@ template<> void ROIAlignGrad<float, CPUContext>(const float spatial_scale, ...@@ -2649,7 +2650,8 @@ template<> void ROIAlignGrad<float, CPUContext>(const float spatial_scale,
const int pool_h, const int pool_w, const int pool_h, const int pool_w,
Tensor* dy, Tensor* dy,
Tensor* roi, Tensor* roi,
Tensor* mask, Tensor* mask_h,
Tensor* mask_w,
Tensor* dx) { Tensor* dx) {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
......
...@@ -3937,7 +3937,8 @@ __global__ void _ROIAlign(const int count, ...@@ -3937,7 +3937,8 @@ __global__ void _ROIAlign(const int count,
const int pool_h, const int pool_w, const int pool_h, const int pool_w,
const T* x, const T* x,
const T* roi, const T* roi,
T* mask, T* mask_h,
T* mask_w,
T* y) { T* y) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_KERNEL_LOOP(idx, count) {
int pw = idx % pool_w; int pw = idx % pool_w;
...@@ -3970,18 +3971,17 @@ __global__ void _ROIAlign(const int count, ...@@ -3970,18 +3971,17 @@ __global__ void _ROIAlign(const int count,
bool is_empty = (hend <= hstart) || (wend <= wstart); bool is_empty = (hend <= hstart) || (wend <= wstart);
T maxval = is_empty ? 0 : -FLT_MAX; T maxval = is_empty ? 0 : -FLT_MAX;
int maxidx = -1; T max_h_idx = -1;
int x_idx = 0; T max_w_idx = -1;
x += (roi_batch_ind * channels + c) * height * width; x += (roi_batch_ind * channels + c) * height * width;
T h_stride = (hend - hstart) / 3.0; T h_stride = (hend - hstart) / 3.0;
T w_stride = (wend - wstart) / 3.0; T w_stride = (wend - wstart) / 3.0;
for (T h = hstart + h_stride; h <= hend - h_stride + 0.01; h += max(h_stride, 0.01)) { for (T h = hstart + h_stride; h <= hend - h_stride + 0.01; h += max(h_stride, 0.01)) {
for (T w = wstart + w_stride; w <= wend - w_stride + 0.01; w += max(w_stride, 0.01)) { for (T w = wstart + w_stride; w <= wend - w_stride + 0.01; w += max(w_stride, 0.01)) {
x_idx++;
int hlow = min(max(static_cast<int>(floor(h)), 0), height - 1); int hlow = min(max(static_cast<int>(floor(h)), 0), height - 1);
int hhigh = min(hlow + 1, height - 1); int hhigh = min(max(static_cast<int>(ceil(h)), 0), height - 1);
int wleft = min(max(static_cast<int>(floor(w)), 0), width - 1); int wleft = min(max(static_cast<int>(floor(w)), 0), width - 1);
int wright = min(wleft + 1, width - 1); int wright = min(max(static_cast<int>(ceil(w)), 0), width - 1);
int topleft = hlow * width + wleft; int topleft = hlow * width + wleft;
int topright = hlow * width + wright; int topright = hlow * width + wright;
int bottomleft = hhigh * width + wleft; int bottomleft = hhigh * width + wleft;
...@@ -3994,12 +3994,14 @@ __global__ void _ROIAlign(const int count, ...@@ -3994,12 +3994,14 @@ __global__ void _ROIAlign(const int count,
if (value > maxval) { if (value > maxval) {
maxval = value; maxval = value;
maxidx = x_idx; max_h_idx = h;
max_w_idx = w;
} }
} }
} }
y[idx] = maxval; y[idx] = maxval;
mask[idx] = maxidx; mask_h[idx] = max_h_idx;
mask_w[idx] = max_w_idx;
} }
} }
...@@ -4007,12 +4009,14 @@ template<> void ROIAlign<float, CUDAContext>(const float spatial_scale, ...@@ -4007,12 +4009,14 @@ template<> void ROIAlign<float, CUDAContext>(const float spatial_scale,
const int pool_h, const int pool_w, const int pool_h, const int pool_w,
Tensor* x, Tensor* x,
Tensor* roi, Tensor* roi,
Tensor* mask, Tensor* mask_h,
Tensor* mask_w,
Tensor* y) { Tensor* y) {
auto* Xdata = x->data<float, CUDAContext>(); auto* Xdata = x->data<float, CUDAContext>();
auto* Rdata = roi->data<float, CUDAContext>(); auto* Rdata = roi->data<float, CUDAContext>();
auto* Ydata = y->mutable_data<float, CUDAContext>(); auto* Ydata = y->mutable_data<float, CUDAContext>();
auto* Mdata = mask->mutable_data<float, CUDAContext>(); auto* MHdata = mask_h->mutable_data<float, CUDAContext>();
auto* MWdata = mask_w->mutable_data<float, CUDAContext>();
TIndex channels = x->dim(1), count = y->count(); TIndex channels = x->dim(1), count = y->count();
TIndex height = x->dim(2), width = x->dim(3); TIndex height = x->dim(2), width = x->dim(3);
_ROIAlign<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count, _ROIAlign<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
...@@ -4022,7 +4026,8 @@ template<> void ROIAlign<float, CUDAContext>(const float spatial_scale, ...@@ -4022,7 +4026,8 @@ template<> void ROIAlign<float, CUDAContext>(const float spatial_scale,
pool_h, pool_w, pool_h, pool_w,
Xdata, Xdata,
Rdata, Rdata,
Mdata, MHdata,
MWdata,
Ydata); Ydata);
CUDA_POST_KERNEL_CHECK; CUDA_POST_KERNEL_CHECK;
} }
...@@ -4036,7 +4041,8 @@ __global__ void _ROIAlignGrad(const int count, ...@@ -4036,7 +4041,8 @@ __global__ void _ROIAlignGrad(const int count,
const int pool_h, const int pool_w, const int pool_h, const int pool_w,
const T* dy, const T* dy,
const T* roi, const T* roi,
const T* mask, const T* mask_h,
const T* mask_w,
T* dx) { T* dx) {
CUDA_KERNEL_LOOP(idx, count) { CUDA_KERNEL_LOOP(idx, count) {
int w = idx % width; int w = idx % width;
...@@ -4063,47 +4069,24 @@ __global__ void _ROIAlignGrad(const int count, ...@@ -4063,47 +4069,24 @@ __global__ void _ROIAlignGrad(const int count,
int offset = (roi_n * channels + c) * pool_h * pool_w; int offset = (roi_n * channels + c) * pool_h * pool_w;
const T* offset_dy = dy + offset; const T* offset_dy = dy + offset;
const T* offset_mask = mask + offset; const T* offset_mask_h = mask_h + offset;
const T* offset_mask_w = mask_w + offset;
T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(1)); T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(1));
T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(1)); T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pool_h);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pool_w);
for (int ph = 0; ph < pool_h; ++ph) { for (int ph = 0; ph < pool_h; ++ph) {
for (int pw = 0; pw < pool_w; ++pw) { for (int pw = 0; pw < pool_w; ++pw) {
T hstart = static_cast<T>((ph)* bin_size_h);
T wstart = static_cast<T>((pw)* bin_size_w);
T hend = static_cast<T>((ph + 1) * bin_size_h);
T wend = static_cast<T>((pw + 1) * bin_size_w);
hstart = min(max(hstart + roi_start_h, static_cast<T>(0)), static_cast<T>(height));
hend = min(max(hend + roi_start_h, static_cast<T>(0)), static_cast<T>(height));
wstart = min(max(wstart + roi_start_w, static_cast<T>(0)), static_cast<T>(width));
wend = min(max(wend + roi_start_w, static_cast<T>(0)), static_cast<T>(width));
bool in_bin = (w > wstart - 1.0 &&
w < wend + 1.0 &&
h > hstart - 1.0
&& h < hend + 1.0);
if (!in_bin) continue;
const int pool_idx = ph * pool_w + pw; const int pool_idx = ph * pool_w + pw;
int x_idx = 0; T a_h = offset_mask_h[pool_idx];
T h_stride = (hend - hstart) / 3.0; T a_w = offset_mask_w[pool_idx];
T w_stride = (wend - wstart) / 3.0; int hlow = min(max(static_cast<int>(floor(a_h)), 0), height - 1);
for (T rh = hstart + h_stride; rh <= hend - h_stride + 0.01; rh += max(h_stride, 0.01)) { int hhigh = min(max(static_cast<int>(ceil(a_h)), 0), height - 1);
for (T rw = wstart + w_stride; rw <= wend - w_stride + 0.01; rw += max(w_stride, 0.01)) { int wleft = min(max(static_cast<int>(floor(a_w)), 0), width - 1);
x_idx++; int wright = min(max(static_cast<int>(ceil(a_w)), 0), width - 1);
if (offset_mask[pool_idx] != x_idx) continue;
int hlow = min(max(static_cast<int>(floor(rh)), 0), height - 1);
int hhigh = min(hlow + 1, height - 1);
int wleft = min(max(static_cast<int>(floor(rw)), 0), width - 1);
int wright = min(wleft + 1, width - 1);
if (h != hlow && h != hhigh && w != wleft && w != wright) continue; if (h != hlow && h != hhigh && w != wleft && w != wright) continue;
T alpha = (hlow == hhigh) ? static_cast<T>(0.5) : (rh - hlow) / (hhigh - hlow); T alpha = (hlow == hhigh) ? static_cast<T>(0.5) : (a_h - hlow) / (hhigh - hlow);
T beta = (wleft == wright) ? static_cast<T>(0.5) : (rw - wleft) / (wright - wleft); T beta = (wleft == wright) ? static_cast<T>(0.5) : (a_w - wleft) / (wright - wleft);
if (h == hlow && w == wleft) gradient += offset_dy[pool_idx] * (1 - alpha) * (1 - beta); if (h == hlow && w == wleft) gradient += offset_dy[pool_idx] * (1 - alpha) * (1 - beta);
else if (h == hlow && w == wright) gradient += offset_dy[pool_idx] * (1 - alpha) * beta; else if (h == hlow && w == wright) gradient += offset_dy[pool_idx] * (1 - alpha) * beta;
else if (h == hhigh && w == wleft) gradient += offset_dy[pool_idx] * alpha * (1 - beta); else if (h == hhigh && w == wleft) gradient += offset_dy[pool_idx] * alpha * (1 - beta);
...@@ -4111,8 +4094,6 @@ __global__ void _ROIAlignGrad(const int count, ...@@ -4111,8 +4094,6 @@ __global__ void _ROIAlignGrad(const int count,
} }
} }
} }
}
}
dx[idx] = gradient; dx[idx] = gradient;
} }
} }
...@@ -4121,11 +4102,13 @@ template<> void ROIAlignGrad<float, CUDAContext>(const float spatial_scale, ...@@ -4121,11 +4102,13 @@ template<> void ROIAlignGrad<float, CUDAContext>(const float spatial_scale,
const int pool_h, const int pool_w, const int pool_h, const int pool_w,
Tensor* dy, Tensor* dy,
Tensor* roi, Tensor* roi,
Tensor* mask, Tensor* mask_h,
Tensor* mask_w,
Tensor* dx) { Tensor* dx) {
auto* dYdata = dy->data<float, CUDAContext>(); auto* dYdata = dy->data<float, CUDAContext>();
auto* Rdata = roi->data<float, CUDAContext>(); auto* Rdata = roi->data<float, CUDAContext>();
auto* Mdata = mask->data<float, CUDAContext>(); auto* MHdata = mask_h->data<float, CUDAContext>();
auto* MWdata = mask_w->data<float, CUDAContext>();
auto* dXdata = dx->mutable_data<float, CUDAContext>(); auto* dXdata = dx->mutable_data<float, CUDAContext>();
TIndex channels = dx->dim(1), count = dx->count(); TIndex channels = dx->dim(1), count = dx->count();
TIndex height = dx->dim(2), width = dx->dim(3); TIndex height = dx->dim(2), width = dx->dim(3);
...@@ -4137,7 +4120,8 @@ template<> void ROIAlignGrad<float, CUDAContext>(const float spatial_scale, ...@@ -4137,7 +4120,8 @@ template<> void ROIAlignGrad<float, CUDAContext>(const float spatial_scale,
pool_h, pool_w, pool_h, pool_w,
dYdata, dYdata,
Rdata, Rdata,
Mdata, MHdata,
MWdata,
dXdata); dXdata);
CUDA_POST_KERNEL_CHECK; CUDA_POST_KERNEL_CHECK;
} }
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!