Commit 771e3d5a by Ting PAN

Add SELU & PReLU support

1 parent 4bef6a6b
......@@ -17,6 +17,7 @@
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <algorithm>
#include <mutex>
#include "core/types.h"
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
#include "core/operator.h"
namespace dragon {
// Forward operator of the Parametric ReLU activation.
//
// Arguments read from the operator definition:
//   channel_shared : whether a single slope is shared by all channels
//                    (defaults to false).
//   data_format    : layout of the input, "NCHW" by default.
template <class Context>
class PReluOp : public Operator<Context> {
 public:
    PReluOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
          data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}

    // Resolve the geometry of the input and dispatch on its element type.
    void RunOnDevice() override;

    // Typed implementation of the forward pass.
    template <typename T> void RunWithType();

 protected:
    bool channel_shared;
    string data_format;
    // Geometry cached by RunOnDevice(); value-initialized so that the
    // members never hold indeterminate values before the first run.
    TIndex channels = 0, dim = 0;
};
// Backward operator of the Parametric ReLU activation.
//
// Arguments read from the operator definition:
//   channel_shared : whether a single slope is shared by all channels
//                    (defaults to false).
//   data_format    : layout of the input, "NCHW" by default.
template <class Context>
class PReluGradientOp : public Operator<Context> {
 public:
    PReluGradientOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
          data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}

    // Resolve the geometry of the input and dispatch on its element type.
    void RunOnDevice() override;

    // Typed implementation of the backward pass.
    template <typename T> void RunWithType();

 protected:
    bool channel_shared;
    string data_format;
    // Geometry cached by RunOnDevice(); value-initialized so that the
    // members never hold indeterminate values before the first run.
    TIndex channels = 0, dim = 0;
    // Scratch tensors that are (re)acquired inside RunWithType(); they
    // start as nullptr instead of dangling garbage.
    Tensor* bcast_dw = nullptr;
    Tensor* multiplier = nullptr;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
\ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
#include "core/operator.h"
namespace dragon {
// Forward operator of the Scaled Exponential Linear Unit activation.
// It takes no extra arguments from the operator definition.
template <class Context>
class SEluOp : public Operator<Context> {
 public:
    SEluOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws) {}

    // Shape the output and dispatch on the element type of the input.
    void RunOnDevice() override;

    // Typed implementation of the forward pass.
    template <typename T>
    void RunWithType();
};
// Backward operator of the Scaled Exponential Linear Unit activation.
template <class Context>
class SEluGradientOp : public Operator<Context> {
 public:
    SEluGradientOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws) {
        // Project macro, invoked once at construction time.
        DISABLE_SHARE_GRADIENT;
    }

    // Shape the output and dispatch on the element type of the input.
    void RunOnDevice() override;

    // Typed implementation of the backward pass.
    template <typename T>
    void RunWithType();
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
\ No newline at end of file
......@@ -29,7 +29,7 @@ class ROIAlignOp : public Operator<Context> {
protected:
int pool_h, pool_w;
float spatial_scale;
Tensor* mask_h, *mask_w;
Tensor* mask;
};
template <class Context>
......@@ -51,7 +51,7 @@ class ROIAlignGradientOp : public Operator<Context> {
protected:
int pool_h, pool_w;
float spatial_scale;
Tensor* mask_h, *mask_w;
Tensor* mask;
};
} // namespace dragon
......
......@@ -49,6 +49,42 @@ void EluGrad(const int count,
const float alpha,
T* dx);
/******************** activation.prelu ********************/
// PReLU forward kernel.
//
// For each of the "count" elements the specializations in this change
// compute the identity for positive inputs and the input scaled by a
// slope for negative inputs (CPU form: max(x, 0) + w * min(x, 0)).
// The slope is w[0] when "channel_shared" is true; otherwise it is
// w[c] with c = (i / dim) % channels for "NCHW" and c = i % channels
// for "NHWC". Any other data_format is reported through LOG(FATAL).
//
//   count          : number of elements in x and y.
//   channels       : number of channels of the input.
//   dim            : number of elements per channel within one sample.
//   channel_shared : use a single slope for every channel.
//   data_format    : "NCHW" or "NHWC".
//   x, w           : input values and slopes.
//   y              : output values.
template <typename T, class Context>
void PRelu(const int count,
const int channels,
const int dim,
const bool channel_shared,
const string& data_format,
const T* x,
const T* w,
T* y);
// PReLU backward kernel with respect to the input.
//
// The CPU specialization in this change computes, for each of the
// "count" elements, dx = dy * ((x > 0) + w * (x <= 0)). The slope is
// w[0] when "channel_shared" is true; otherwise it is w[c] with
// c = (i / dim) % channels for "NCHW" and c = i % channels for "NHWC".
// Any other data_format is reported through LOG(FATAL).
//
//   count          : number of elements in dy, x and dx.
//   channels       : number of channels of the input.
//   dim            : number of elements per channel within one sample.
//   channel_shared : use a single slope for every channel.
//   data_format    : "NCHW" or "NHWC".
//   dy, x, w       : output gradient, forward input and slopes.
//   dx             : gradient with respect to the input.
template <typename T, class Context>
void PReluGrad(const int count,
const int channels,
const int dim,
const bool channel_shared,
const string& data_format,
const T* dy,
const T* x,
const T* w,
T* dx);
// PReLU backward kernel with respect to the slopes.
//
// The CPU specialization in this change works in two steps:
//   1. For each of the channels * dim positions it sums
//      dy * x * (x <= 0) over "rows" rows that are "row_offset"
//      elements apart, and stores the sum in bcast_dw.
//   2. It reduces bcast_dw into dw: a dot product with "multiplier"
//      added to dw[0] when "channel_shared" is true, otherwise a Gemv
//      against "multiplier" whose transpose flag and matrix shape
//      depend on data_format ("NCHW" or "NHWC").
// Both reductions add into dw instead of overwriting it.
// NOTE(review): presumably dw is zeroed by the caller or accumulation
// is intended -- confirm.
//
//   rows           : number of rows (samples) to sum over.
//   row_offset     : distance in elements between consecutive rows.
//   channels, dim  : channels and per-channel element count per row.
//   channel_shared : use a single slope for every channel.
//   data_format    : "NCHW" or "NHWC".
//   dy, x          : output gradient and forward input.
//   multiplier     : vector used by the dot product / Gemv reduction.
//   bcast_dw       : scratch buffer of channels * dim elements.
//   dw             : gradient with respect to the slopes.
template <typename T, class Context>
void PReluWGrad(const int rows,
const int row_offset,
const int channels,
const int dim,
const bool channel_shared,
const string& data_format,
const T* dy,
const T* x,
const T* multiplier,
T* bcast_dw,
T* dw);
/******************** activation.relu ********************/
template <typename T, class Context>
......@@ -61,6 +97,14 @@ void ReluGrad(const int count,
const float slope,
T* dx);
/******************** activation.selu ********************/
// SELU forward kernel over "count" elements.
// The CPU specialization in this change computes
// y = 1.0507 * max(x, 0) + 1.7581 * (exp(min(x, 0)) - 1).
template <typename T, class Context>
void SElu(const int count, const T* x, T* y);
// SELU backward kernel over "count" elements, expressed through the
// forward output y rather than the forward input.
// The CPU specialization in this change computes
// dx = 1.0507 * dy when y > 0, and dx = (1.7581 + y) * dy otherwise.
template <typename T, class Context>
void SEluGrad(const int count, const T* dy, const T* y, T* dx);
/******************** activation.sigmoid ********************/
template <typename T, class Context>
......@@ -745,8 +789,7 @@ void ROIAlign(const float spatial_scale,
const int pool_w,
Tensor* x,
Tensor* roi,
Tensor* mask_h,
Tensor* mask_y,
Tensor* mask,
Tensor* y);
template <typename T, class Context>
......@@ -755,8 +798,7 @@ void ROIAlignGrad(const float spatial_scale,
const int pool_w,
Tensor* dy,
Tensor* roi,
Tensor* mask_h,
Tensor* mask_y,
Tensor* mask,
Tensor* dx);
} // namespace kernel
......
......@@ -243,3 +243,24 @@ def SetLoggingLevel(level):
'ERROR': logging.ERROR,
'FATAL': logging.CRITICAL
}[level])
def SetLoggingFile(log_file):
    """Redirect the logging into the specific file.

    Parameters
    ----------
    log_file : str
        The file for logging.

    Notes
    -----
    The function will disable all possible logging at the terminal.

    Calling it again replaces the previous log file instead of
    logging into both files.

    """
    global logger
    new_logger = logging.getLogger('dragon_filehandler')
    new_logger.setLevel(logger.level)
    # ``getLogger`` returns the same object for the same name, so drop
    # (and close) the handlers left by an earlier call; otherwise every
    # record would be written once per call made so far.
    for stale_handler in list(new_logger.handlers):
        new_logger.removeHandler(stale_handler)
        stale_handler.close()
    # Do not hand records over to ancestor loggers, whose handlers may
    # still print to the terminal.
    new_logger.propagate = False
    file_handler = logging.FileHandler(log_file, mode="w", encoding="UTF-8")
    new_logger.addHandler(file_handler)
    logger = new_logger
\ No newline at end of file
......@@ -14,7 +14,6 @@ import dragon.protos.dragon_pb2 as pb
import numpy as np
import os
from dragon import *
from dragon.config import logger
from google.protobuf.message import Message
from six.moves import range as xrange
......@@ -339,7 +338,7 @@ def LogMetaGraph(meta_graph):
None
"""
from dragon.config import option
from dragon.config import option, logger
if option['log_meta_graph']:
logger.info(meta_graph)
......@@ -358,11 +357,12 @@ def GetOptimizedGraph(meta_graph):
The definition of optimized graph.
"""
from dragon.config import logger
graph_name = meta_graph.name
graph_tensor = 'GraphDef_' + graph_name
if not HasTensorCC(graph_tensor):
logger.info('graph: {} does not exist, ignore printing....'.format(graph_name))
logger.info('Graph({}) does not exist, ignore printing....'.format(graph_name))
return
opt_graph_def = pb.GraphDef()
......@@ -383,7 +383,7 @@ def LogOptimizedGraph(meta_graph):
None
"""
from dragon.config import option
from dragon.config import option, logger
if option['log_optimized_graph']:
optimized_graph = GetOptimizedGraph(meta_graph)
logger.info(optimized_graph)
......@@ -404,7 +404,7 @@ def ExportMetaGraph(meta_graph):
None
"""
from dragon.config import option
from dragon.config import option, logger
if option['export_meta_graph']:
if not os.path.exists(option['export_meta_graph']):
try:
......@@ -445,6 +445,7 @@ def Snapshot(tensors, filename, prefix='', suffix='.bin', format='default'):
Available formats: ['default', 'caffe'].
"""
from dragon.config import logger
filepath = prefix + filename + suffix
if mpi.Is_Init():
if not mpi.AllowSnapshot(): return
......@@ -488,6 +489,7 @@ def Restore(filepath, format='default'):
Available formats: ['default', 'caffe'].
"""
from dragon.config import logger
assert os.path.exists(filepath), 'model of path({}) does not exist.'.format(filepath)
if format == 'default':
content = cPickle.load(open(filepath, 'rb'))
......
......@@ -22,6 +22,7 @@ List Brief
`LogOptimizedGraph`_ Enable to log optimized graph globally.
`ExportMetaGraph`_ Enable to export all runnable meta graphs into text files.
`SetLoggingLevel`_ Set the minimum level of Logging.
`SetLoggingFile`_ Redirect the logging into the specific file.
==================== =============================================================================
API Reference
......@@ -41,3 +42,4 @@ API Reference
.. _LogOptimizedGraph: #dragon.config.LogOptimizedGraph
.. _ExportMetaGraph: #dragon.config.ExportMetaGraph
.. _SetLoggingLevel: #dragon.config.SetLoggingLevel
.. _SetLoggingFile: #dragon.config.SetLoggingFile
\ No newline at end of file
......@@ -12,11 +12,15 @@
.. |tanh_function| mathmacro:: \, y = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}
.. |relu_function| mathmacro:: \, y = \max(0, x)
.. |relu_function| mathmacro:: \, y = \left\{ \begin{array} \\ x & & (x > 0) \\ 0 & & (x <= 0) \\ \end{array} \right.
.. |lrelu_function| mathmacro:: \, y = \left\{ \begin{array} \\ x & & (x > 0) \\ Slope * x & & (x <= 0) \\ \end{array} \right.
.. |prelu_function| mathmacro:: \, y_{i} = \left\{ \begin{array} \\ x_{i} & & (x_{i} > 0) \\ \alpha_{i} * x_{i} & & (x_{i} <= 0) \\ \end{array} \right.
.. |elu_function| mathmacro:: \, y = \left\{ \begin{array} \\ x & & (x > 0) \\ Alpha * (e^{x} - 1) & & (x <= 0) \\ \end{array} \right.
.. |leaky_relu_function| mathmacro:: \, y = \max(x, 0) + Slope * \min(x, 0)
.. |selu_function| mathmacro:: \, y = 1.0507 \left\{ \begin{array} \\ x & & (x > 0) \\ 1.6733 * (e^{x} - 1) & & (x <= 0) \\ \end{array} \right.
.. |dropout_function| mathmacro:: \, y = x * Bernoulli(p=1 - prob)
......
......@@ -61,7 +61,9 @@ List Brief
`Tanh`_ Tanh function.
`Relu`_ Rectified Linear Unit function, introduced by `[Nair & Hinton, 2010] <http://www.csri.utoronto.ca/~hinton/absps/reluICML.pdf>`_.
`LRelu`_ Leaky Rectified Linear Unit function.
`PRelu`_ Parametric Rectified Linear Unit function, introduced by `[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_.
`Elu`_ Exponential Linear Unit function, introduced by `[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
`SElu`_ Scaled Exponential Linear Unit function, introduced by `[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
`Softmax`_ Softmax function.
`Dropout`_ Randomly set a unit into zero, introduced by `[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
=============== ======================================================================
......@@ -209,7 +211,9 @@ List Brief
.. _Tanh: operators/activation.html#dragon.operators.activation.Tanh
.. _Relu: operators/activation.html#dragon.operators.activation.Relu
.. _LRelu: operators/activation.html#dragon.operators.activation.LRelu
.. _PRelu: operators/activation.html#dragon.operators.activation.PRelu
.. _Elu: operators/activation.html#dragon.operators.activation.Elu
.. _SElu: operators/activation.html#dragon.operators.activation.SElu
.. _Softmax: operators/activation.html#dragon.operators.activation.Softmax
.. _Dropout: operators/activation.html#dragon.operators.activation.Dropout
......
......@@ -42,7 +42,9 @@ Neuron
List Brief
==================== =============================================================================
`ReLULayer`_ The implementation of ``ReLULayer``.
`PReLULayer`_ The implementation of ``PReLULayer``.
`ELULayer`_ The implementation of ``ELULayer``.
`SELULayer`_ The implementation of ``SELULayer``.
`SigmoidLayer`_ The implementation of ``SigmoidLayer``.
`TanHLayer`_ The implementation of ``TanHLayer``.
`DropoutLayer`_ The implementation of ``DropoutLayer``.
......@@ -154,7 +156,9 @@ API Reference
.. _BilinearResizeLayer: #dragon.vm.caffe.layers.vision.BilinearResizeLayer
.. _ReLULayer: #dragon.vm.caffe.layers.neuron.ReLULayer
.. _PReLULayer: #dragon.vm.caffe.layers.neuron.PReLULayer
.. _ELULayer: #dragon.vm.caffe.layers.neuron.ELULayer
.. _SELULayer: #dragon.vm.caffe.layers.neuron.SELULayer
.. _SigmoidLayer: #dragon.vm.caffe.layers.neuron.SigmoidLayer
.. _TanHLayer: #dragon.vm.caffe.layers.neuron.TanHLayer
.. _DropoutLayer: #dragon.vm.caffe.layers.neuron.DropoutLayer
......@@ -232,6 +236,8 @@ API Reference
.. _ResizeParameter.fy: https://github.com/neopenx/Dragon/tree/master/Dragon/python/dragon/vm/caffe/proto/caffe.proto#L1466
.. _ReLUParameter.negative_slope: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L1000
.. _PReLUParameter.filler: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L1409
.. _PReLUParameter.channel_shared: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L1411
.. _ELUParameter.alpha: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L717
.. _DropoutParameter.dropout_ratio: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L676
.. _DropoutParameter.scale_train: https://github.com/rbgirshick/caffe-fast-rcnn/blob/0dcd397b29507b8314e252e850518c5695efbb83/src/caffe/proto/caffe.proto#L638
......
......@@ -8,8 +8,6 @@ import numpy as np
from multiprocessing import Process
from six.moves import range as xrange
from dragon.config import logger
from .utils import GetProperty
class BlobFetcher(Process):
......@@ -40,6 +38,7 @@ class BlobFetcher(Process):
self.daemon = True
def cleanup():
from dragon.config import logger
logger.info('Terminating BlobFetcher......')
self.terminate()
self.join()
......
......@@ -10,7 +10,6 @@ from multiprocessing import Queue
from six.moves import range as xrange
import dragon.core.mpi as mpi
from dragon.config import logger
from .data_reader import DataReader
from .data_transformer import DataTransformer
......@@ -171,6 +170,7 @@ class DataBatch(object):
"""
Print I/O Information.
"""
from dragon.config import logger
logger.info('---------------------------------------------------------')
logger.info('BatchReader, Using config:')
params = {'prefetching': self._prefetch,
......
......@@ -9,7 +9,6 @@ import numpy.random as npr
from multiprocessing import Process
import dragon.config as config
from dragon.config import logger
from dragon.tools.db import LMDB
from .utils import GetProperty
......@@ -55,6 +54,7 @@ class DataReader(Process):
self.daemon = True
def cleanup():
from dragon.config import logger
logger.info('Terminating DataReader......')
self.terminate()
self.join()
......
......@@ -9,7 +9,6 @@ import numpy.random as npr
from multiprocessing import Process
import dragon.config as config
from dragon.config import logger
import dragon.vm.caffe.proto.caffe_pb2 as pb
from .utils import GetProperty
......@@ -72,6 +71,7 @@ class DataTransformer(Process):
self.daemon = True
def cleanup():
from dragon.config import logger
logger.info('Terminating DataTransformer......')
self.terminate()
self.join()
......
......@@ -44,7 +44,7 @@ def LRelu(inputs, slope=0.2, **kwargs):
Returns
-------
Tensor
The output tensor, calculated as: |leaky_relu_function|.
The output tensor, calculated as: |lrelu_function|.
"""
CheckInputs(inputs, 1)
......@@ -58,6 +58,35 @@ def LRelu(inputs, slope=0.2, **kwargs):
return output
def PRelu(inputs, channel_shared=False, data_format='NCHW', **kwargs):
    """Parametric Rectified Linear Unit function, introduced by `[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_.

    Parameters
    ----------
    inputs : list of Tensor
        The input and trainable parameter(slope).
    channel_shared : boolean
        Whether to share the parameter(slope) across channels.
    data_format : str
        The data format, ``NCHW`` or ``NHWC``.

    Returns
    -------
    Tensor
        The output tensor, calculated as: |prelu_function|

    """
    CheckInputs(inputs, 2)
    # ParseArguments receives locals(), so no additional local name may
    # be bound before this call.
    op_args = ParseArguments(locals())

    result = Tensor.CreateOperator(nout=1, op_type='PRelu', **op_args)

    # The activation is element-wise: copy the shape of the data input
    # whenever it is known.
    source_shape = inputs[0].shape
    if source_shape is not None:
        result.shape = source_shape[:]

    return result
def Elu(inputs, alpha=1.0, **kwargs):
"""Exponential Linear Unit function, introduces by `[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
......@@ -65,6 +94,8 @@ def Elu(inputs, alpha=1.0, **kwargs):
----------
inputs : Tensor
The input tensor.
alpha : float
The alpha.
Returns
-------
......@@ -83,6 +114,31 @@ def Elu(inputs, alpha=1.0, **kwargs):
return output
def SElu(inputs, **kwargs):
    """Scaled Exponential Linear Unit function, introduced by `[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.

    Parameters
    ----------
    inputs : Tensor
        The input tensor.

    Returns
    -------
    Tensor
        The output tensor, calculated as: |selu_function|

    """
    CheckInputs(inputs, 1)
    # ParseArguments receives locals(), so no additional local name may
    # be bound before this call.
    op_args = ParseArguments(locals())

    result = Tensor.CreateOperator(nout=1, op_type='SElu', **op_args)

    # The activation is element-wise: copy the input shape whenever it
    # is known.
    source_shape = inputs.shape
    if source_shape is not None:
        result.shape = source_shape[:]

    return result
def Sigmoid(inputs, **kwargs):
"""Sigmoid function.
......
......@@ -232,7 +232,7 @@ def ROIPooling(inputs, pool_h, pool_w, spatial_scale, **kwargs):
return Tensor.CreateOperator(nout=1, op_type='ROIPooling', **arguments)
def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **arguments):
def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs):
"""Max ROIAlign, introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
The first dimension of input must be ``1``.
......
......@@ -50,7 +50,9 @@ Sigmoid = act.Sigmoid
Tanh = act.Tanh
Relu = act.Relu
LRelu = act.LRelu
PRelu = act.PRelu
Elu = act.Elu
SElu = act.SElu
Softmax = act.Softmax
Dropout = act.Dropout
......
......@@ -6,7 +6,6 @@
import numpy as np
import pprint
from dragon.config import logger
import dragon.core.workspace as ws
from dragon.core.tensor import Tensor
......@@ -85,6 +84,7 @@ class BaseUpdater(object):
"""
Print Updater Information.
"""
from dragon.config import logger
logger.info('---------------------------------------------------------')
logger.info('Optimizer: {}, Using config:'.format(self._type.split('Update')[0]))
pprint.pprint(self._hyper_params)
......
......@@ -17,7 +17,9 @@ from .vision import ConvolutionLayer, \
BilinearResizeLayer
from .neuron import ReLULayer, \
PReLULayer, \
ELULayer, \
SELULayer, \
DropoutLayer, \
TanHLayer, \
PowerLayer
......
......@@ -5,6 +5,7 @@
# --------------------------------------------------------
import dragon.ops as ops
from dragon.core.tensor import Tensor
from ..layer import Layer
......@@ -29,6 +30,35 @@ class ReLULayer(Layer):
return ops.Relu(input, **self._param)
class PReLULayer(Layer):
    """The implementation of ``PReLULayer``.

    Parameters
    ----------
    filler : FillerParameter
        The filler of parameter(slope). Refer `PReLUParameter.filler`_.
    channel_shared : boolean
        Whether to share the parameter across channels. Refer `PReLUParameter.channel_shared`_.

    """
    def __init__(self, LayerParameter):
        super(PReLULayer, self).__init__(LayerParameter)
        param = LayerParameter.prelu_param
        self._param = {'channel_shared': param.channel_shared,
                       'data_format': 'NCHW'}
        # The slope is the only learnable blob of this layer.
        slope = Tensor(LayerParameter.name + '@param0')
        slope_diff = Tensor(LayerParameter.name + '@param0_grad')
        if param.HasField('filler'):
            self.Fill(slope, param, 'filler')
        else:
            # No filler was given: start from a constant slope of 0.25.
            slope.Constant(value=0.25)
        self._blobs.append({'data': slope, 'diff': slope_diff})

    def Setup(self, bottom):
        super(PReLULayer, self).Setup(bottom)
        # Accept a bare tensor as well as a list of bottoms, like the
        # sibling neuron layers do; list concatenation below would
        # otherwise misbehave for a non-list ``bottom``.
        if isinstance(bottom, (list, tuple)):
            inputs = list(bottom)
        else:
            inputs = [bottom]
        return ops.PRelu(inputs + [blob['data'] for blob in self._blobs], **self._param)
class ELULayer(Layer):
"""The implementation of ``ELULayer``.
......@@ -41,7 +71,7 @@ class ELULayer(Layer):
def __init__(self, LayerParameter):
super(ELULayer, self).__init__(LayerParameter)
param = LayerParameter.elu_param
self._param = {'alpha': param.alpha}
self._param = {'alpha': float(param.alpha)}
def Setup(self, bottom):
super(ELULayer, self).Setup(bottom)
......@@ -49,6 +79,19 @@ class ELULayer(Layer):
return ops.Elu(input, **self._param)
class SELULayer(Layer):
    """The implementation of ``SELULayer``.

    The layer has no parameters of its own.

    """
    def __init__(self, LayerParameter):
        super(SELULayer, self).__init__(LayerParameter)

    def Setup(self, bottom):
        super(SELULayer, self).Setup(bottom)
        # A list of bottoms contributes its first entry only.
        if isinstance(bottom, list):
            source = bottom[0]
        else:
            source = bottom
        return ops.SElu(source, **self._param)
class SigmoidLayer(Layer):
"""
The implementation of ``SigmoidLayer``.
......
......@@ -14,7 +14,6 @@ import dragon.tools.summary_writer as sw
import dragon.vm.theano as theano
from dragon.vm.caffe.proto import caffe_pb2 as pb
from dragon.config import logger
from dragon.vm.caffe.misc import root_solver
from dragon.vm.caffe.net import Net
from google.protobuf.text_format import Parse
......@@ -172,6 +171,7 @@ class Solver(object):
The implementation of `GetLearningRate(solver.cpp, L27)`_.
"""
from dragon.config import logger
policy = self._param.lr_policy
if policy == "step":
......@@ -232,6 +232,7 @@ class Solver(object):
The implementation of `Test(solver.cpp, L328)`_.
"""
from dragon.config import logger
test_score = []
output_id = []
test_iter = self._param.test_iter[test_idx]
......@@ -278,6 +279,7 @@ class Solver(object):
The implementation of `Step(solver.cpp, L180)`_.
"""
from dragon.config import logger
start_iter = self._iter; stop_iter = self._iter + iters
loss_vec = []; smoothed_loss = 0
tic = time.time()
......
#include "operators/activation/prelu_op.h"
#include "core/workspace.h"
#include "utils/filler.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace dragon {
// Typed forward pass: input(0) is the data, input(1) the slopes and
// output(0) the activation.
template <class Context> template <typename T>
void PReluOp<Context>::RunWithType() {
    // The slope tensor holds a single value when it is shared, otherwise
    // one value per channel. "channels" has been resolved by
    // RunOnDevice() from data_format, so it is also right for NHWC
    // inputs, whose channel axis is the last one rather than axis 1.
    if (channel_shared) {
        TENSOR_FILL(input(1), vector<TIndex>(1, 1));
    } else {
        TENSOR_FILL(input(1), vector<TIndex>(1, channels));
    }

    auto* Xdata = input(0).template data<T, Context>();
    auto* Wdata = input(1).template data<T, Context>();
    auto* Ydata = output(0)->template mutable_data<T, Context>();

    kernel::PRelu<T, Context>(output(0)->count(),
                                       channels,
                                            dim,
                                 channel_shared,
                                    data_format,
                                          Xdata,
                                          Wdata,
                                          Ydata);
}
// Resolve the channel geometry, shape the output like the input and
// dispatch on the element type (only float is supported).
template <class Context>
void PReluOp<Context>::RunOnDevice() {
    if (data_format == "NCHW") {
        channels = input(0).dim(1);
        dim = input(0).count(2);
    } else {
        channels = input(0).dim(-1);
        // Per-sample element count of one channel, so that "dim" means
        // the same thing as in the NCHW branch (the batch axis is
        // excluded).
        dim = input(0).count(1) / channels;
    }
    output(0)->ReshapeLike(input(0));

    if (input(0).template IsType<float>()) RunWithType<float>();
    else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(PRelu);
#ifdef WITH_CUDA
DEPLOY_CUDA(PRelu);
#endif
OPERATOR_SCHEMA(PRelu).NumInputs(2).NumOutputs(1);
// Typed backward pass.
// input(0) is the forward data, input(1) the slopes and input(-1) the
// gradient of the output; output(0) receives the data gradient and
// output(1) the slope gradient. An output named "ignore" is skipped.
template <class Context> template <typename T>
void PReluGradientOp<Context>::RunWithType() {
    auto* x_ptr = input(0).template data<T, Context>();
    auto* dy_ptr = input(-1).template data<T, Context>();

    // Gradient with respect to the slopes.
    if (output(1)->name() != "ignore") {
        INIT_MULTIPLIER(multiplier, channels * dim);
        // Scratch buffer with one entry per (channel, position) pair.
        bcast_dw = ws()->GetBuffer();
        bcast_dw->Reshape(vector<TIndex>(1, channels * dim));
        auto* dw_ptr = output(1)->template mutable_data<T, Context>();
        auto* scratch_ptr = bcast_dw->template mutable_data<T, Context>();
        kernel::PReluWGrad<T, Context>(
            input(0).dim(0),
            input(0).count(1),
            channels,
            dim,
            channel_shared,
            data_format,
            dy_ptr,
            x_ptr,
            multiplier->template data<T, Context>(),
            scratch_ptr,
            dw_ptr);
        ws()->ReleaseBuffer(bcast_dw);
    }

    // Gradient with respect to the data.
    if (output(0)->name() != "ignore") {
        auto* w_ptr = input(1).template data<T, Context>();
        auto* dx_ptr = output(0)->template mutable_data<T, Context>();
        kernel::PReluGrad<T, Context>(
            output(0)->count(),
            channels,
            dim,
            channel_shared,
            data_format,
            dy_ptr,
            x_ptr,
            w_ptr,
            dx_ptr);
    }
}
// Resolve the channel geometry, shape both gradients like their
// forward counterparts and dispatch on the element type (only float is
// supported).
template <class Context>
void PReluGradientOp<Context>::RunOnDevice() {
    if (data_format == "NCHW") {
        channels = input(0).dim(1);
        dim = input(0).count(2);
    } else {
        channels = input(0).dim(-1);
        // "dim" must be the per-sample element count of one channel.
        // RunWithType() sizes the scratch buffer as channels * dim and
        // lets the kernel walk input(0).dim(0) rows of count(1)
        // elements; dividing the total count() instead would include
        // the batch axis and make the kernel read past the end of the
        // inputs for batches larger than one.
        dim = input(0).count(1) / channels;
    }

    output(0)->ReshapeLike(input(0));
    output(1)->ReshapeLike(input(1));

    if (input(0).template IsType<float>()) RunWithType<float>();
    else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(PReluGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(PReluGradient);
#endif
OPERATOR_SCHEMA(PReluGradient).NumInputs(3).NumOutputs(2);
// Gradient maker of PRelu: emits one "<type>Gradient" operator that
// consumes both forward inputs plus the output gradient, and produces
// the gradients of both forward inputs.
class GetPReluGradient final : public GradientMakerBase {
 public:
    GRADIENT_MAKER_CTOR(GetPReluGradient);

    vector<OperatorDef> MakeDefs() override {
        return SingleDef(
            def.type() + "Gradient",
            "",
            vector<string> { I(0), I(1), GO(0) },
            vector<string> { GI(0), GI(1) });
    }
};
REGISTER_GRADIENT(PRelu, GetPReluGradient);
} // namespace dragon
\ No newline at end of file
#include "operators/activation/selu_op.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace dragon {
// Typed forward pass: applies the SELU kernel to every element of
// input(0) and writes the result into output(0).
template <class Context> template <typename T>
void SEluOp<Context>::RunWithType() {
    auto* src = input(0).template data<T, Context>();
    auto* dst = output(0)->template mutable_data<T, Context>();
    kernel::SElu<T, Context>(
        output(0)->count(),
        src,
        dst);
}
// Shape the output like the input and dispatch on the element type
// (only float is supported).
template <class Context>
void SEluOp<Context>::RunOnDevice() {
    output(0)->ReshapeLike(input(0));

    if (!input(0).template IsType<float>()) {
        LOG(FATAL) << "Unsupported input types.";
    } else {
        RunWithType<float>();
    }
}
DEPLOY_CPU(SElu);
#ifdef WITH_CUDA
DEPLOY_CUDA(SElu);
#endif
OPERATOR_SCHEMA(SElu).NumInputs(1).NumOutputs(1).Inplace({ { 0, 0 } });
// Typed backward pass: input(0) is the forward output, input(1) the
// gradient of that output, and output(0) receives the input gradient.
template <class Context> template <typename T>
void SEluGradientOp<Context>::RunWithType() {
    auto* y_ptr = input(0).template data<T, Context>();
    auto* dy_ptr = input(1).template data<T, Context>();
    auto* dx_ptr = output(0)->template mutable_data<T, Context>();
    kernel::SEluGrad<T, Context>(
        output(0)->count(),
        dy_ptr,
        y_ptr,
        dx_ptr);
}
// Shape the gradient like input(0) and dispatch on the element type
// (only float is supported).
template <class Context>
void SEluGradientOp<Context>::RunOnDevice() {
    output(0)->ReshapeLike(input(0));

    if (!input(0).template IsType<float>()) {
        LOG(FATAL) << "Unsupported input types.";
    } else {
        RunWithType<float>();
    }
}
DEPLOY_CPU(SEluGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(SEluGradient);
#endif
OPERATOR_SCHEMA(SEluGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 }});
// Gradient maker of SElu: emits one "<type>Gradient" operator that
// consumes the forward output plus its gradient, and produces the
// gradient of the forward input.
class GetSEluGradient final : public GradientMakerBase {
 public:
    GRADIENT_MAKER_CTOR(GetSEluGradient);

    vector<OperatorDef> MakeDefs() override {
        return SingleDef(
            def.type() + "Gradient",
            "",
            vector<string> { O(0), GO(0) },
            vector<string> { GI(0) });
    }
};
REGISTER_GRADIENT(SElu, GetSEluGradient);
} // namespace dragon
\ No newline at end of file
......@@ -11,7 +11,7 @@ void ROIAlignOp<Context>::RunWithType() {
pool_h, pool_w,
&input(0),
&input(1),
mask_h, mask_w,
mask,
output(0));
}
......@@ -20,10 +20,8 @@ void ROIAlignOp<Context>::RunOnDevice() {
vector<TIndex> dims({input(1).dim(0), input(0).dim(1), pool_h, pool_w});
output(0)->Reshape(dims);
mask_h = ws()->CreateTensor("_t_" + anchor() + "_roi_align_mask_h");
mask_h->Reshape(dims);
mask_w = ws()->CreateTensor("_t_" + anchor() + "_roi_align_mask_w");
mask_w->Reshape(dims);
mask = ws()->CreateTensor("_t_" + anchor() + "_roi_align_mask");
mask->Reshape(dims);
if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -41,7 +39,7 @@ void ROIAlignGradientOp<Context>::RunWithType() {
pool_h, pool_w,
&input(-1),
&input(1),
mask_h, mask_w,
mask,
output(0));
}
......@@ -49,8 +47,7 @@ template <class Context>
void ROIAlignGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
mask_h = ws()->GetTensor("_t_" + anchor() + "_roi_align_mask_h");
mask_w = ws()->GetTensor("_t_" + anchor() + "_roi_align_mask_w");
mask = ws()->GetTensor("_t_" + anchor() + "_roi_align_mask");
if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -59,8 +56,7 @@ void ROIAlignGradientOp<Context>::RunOnDevice() {
template <class Context>
void ROIAlignGradientOp<Context>::CleanResource() {
Operator<Context>::CleanResource();
ws()->ReleaseBuffer(mask_h, "Common", true);
ws()->ReleaseBuffer(mask_w, "Common", true);
ws()->ReleaseBuffer(mask, "Common", true);
}
DEPLOY_CPU(ROIAlignGradient);
......
......@@ -72,6 +72,124 @@ template<> void EluGrad<float, CPUContext>(const int count,
}
}
/******************** activation.prelu ********************/
// CPU/float PReLU forward: y = max(x, 0) + slope * min(x, 0).
// The slope is w[0] when shared, otherwise the entry of w selected by
// the channel of the element ("NCHW": (i / dim) % channels,
// "NHWC": i % channels). Any other data_format is fatal.
template<> void PRelu<float, CPUContext>(const int count,
                                         const int channels,
                                         const int dim,
                                         const bool channel_shared,
                                         const string& data_format,
                                         const float* x,
                                         const float* w,
                                         float* y) {
    if (channel_shared) {
#ifdef WITH_OMP
        #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
        for (int idx = 0; idx < count; ++idx) {
            const float positive = std::max(x[idx], float(0));
            const float negative = std::min(x[idx], float(0));
            y[idx] = positive + w[0] * negative;
        }
        return;
    }

    if (data_format == "NCHW") {
#ifdef WITH_OMP
        #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
        for (int idx = 0; idx < count; ++idx) {
            const int channel = (idx / dim) % channels;
            const float positive = std::max(x[idx], float(0));
            const float negative = std::min(x[idx], float(0));
            y[idx] = positive + w[channel] * negative;
        }
        return;
    }

    if (data_format == "NHWC") {
#ifdef WITH_OMP
        #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
        for (int idx = 0; idx < count; ++idx) {
            const int channel = idx % channels;
            const float positive = std::max(x[idx], float(0));
            const float negative = std::min(x[idx], float(0));
            y[idx] = positive + w[channel] * negative;
        }
        return;
    }

    LOG(FATAL) << "Unknown data format: " << data_format;
}
// CPU/float PReLU backward with respect to the input:
// dx = dy * ((x > 0) + slope * (x <= 0)).
// The slope is w[0] when shared, otherwise the entry of w selected by
// the channel of the element ("NCHW": (i / dim) % channels,
// "NHWC": i % channels). Any other data_format is fatal.
template<> void PReluGrad<float, CPUContext>(const int count,
                                             const int channels,
                                             const int dim,
                                             const bool channel_shared,
                                             const string& data_format,
                                             const float* dy,
                                             const float* x,
                                             const float* w,
                                             float* dx) {
    if (channel_shared) {
#ifdef WITH_OMP
        #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
        for (int idx = 0; idx < count; ++idx) {
            dx[idx] = dy[idx] * ((x[idx] > 0) + w[0] * (x[idx] <= 0));
        }
        return;
    }

    if (data_format == "NCHW") {
#ifdef WITH_OMP
        #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
        for (int idx = 0; idx < count; ++idx) {
            const int channel = (idx / dim) % channels;
            dx[idx] = dy[idx] * ((x[idx] > 0) + w[channel] * (x[idx] <= 0));
        }
        return;
    }

    if (data_format == "NHWC") {
#ifdef WITH_OMP
        #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
        for (int idx = 0; idx < count; ++idx) {
            const int channel = idx % channels;
            dx[idx] = dy[idx] * ((x[idx] > 0) + w[channel] * (x[idx] <= 0));
        }
        return;
    }

    LOG(FATAL) << "Unknown data format: " << data_format;
}
// CPU/float PReLU backward with respect to the slopes.
// Step 1 sums dy * x * (x <= 0) over all rows for each of the
// channels * dim positions into bcast_dw. Step 2 reduces bcast_dw into
// dw, either to a single shared value or to one value per channel.
// Both reductions add into dw rather than overwrite it.
template<> void PReluWGrad<float, CPUContext>(const int rows,
                                              const int row_offset,
                                              const int channels,
                                              const int dim,
                                              const bool channel_shared,
                                              const string& data_format,
                                              const float* dy,
                                              const float* x,
                                              const float* multiplier,
                                              float* bcast_dw,
                                              float* dw) {
    const int cdim = channels * dim;
#ifdef WITH_OMP
    #pragma omp parallel for num_threads(GET_OMP_THREADS(cdim))
#endif
    for (int pos = 0; pos < cdim; ++pos) {
        // The first row initializes the sum; the rest accumulate.
        float acc = dy[pos] * x[pos] * (x[pos] <= 0);
        for (int row = 1; row < rows; row++) {
            const int src = pos + row * row_offset;
            acc += dy[src] * x[src] * (x[src] <= 0);
        }
        bcast_dw[pos] = acc;
    }

    if (channel_shared) {
        // A single slope: collapse everything into one scalar.
        const float w_sum = math::Dot<float, CPUContext>(
            channels * dim, bcast_dw, multiplier);
        math::AddScalar<float, CPUContext>(1, w_sum, dw);
        return;
    }

    if (data_format == "NCHW") {
        // bcast_dw is read as a (channels x dim) matrix.
        math::Gemv<float, CPUContext>(CblasNoTrans, channels, dim,
                                      1.0,
                                      bcast_dw, multiplier,
                                      1.0,
                                      dw);
    } else if (data_format == "NHWC") {
        // bcast_dw is read as a transposed (dim x channels) matrix.
        math::Gemv<float, CPUContext>(CblasTrans, dim, channels,
                                      1.0,
                                      bcast_dw, multiplier,
                                      1.0,
                                      dw);
    } else {
        LOG(FATAL) << "Unknown data format: " << data_format;
    }
}
/******************** activation.relu ********************/
template<> void Relu<float, CPUContext>(const int count,
......@@ -99,6 +217,32 @@ template<> void ReluGrad<float, CPUContext>(const int count,
}
}
/******************** activation.selu ********************/
// CPU/float SELU forward:
// y = 1.0507 * max(x, 0) + 1.7581 * (exp(min(x, 0)) - 1).
// The double literals are kept on purpose so the intermediate
// arithmetic stays in double precision before narrowing to float.
template<> void SElu<float, CPUContext>(const int count,
                                        const float* x,
                                        float* y) {
#ifdef WITH_OMP
    #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
    for (int idx = 0; idx < count; ++idx) {
        const float positive = std::max(x[idx], float(0));
        const float negative = std::min(x[idx], float(0));
        y[idx] = 1.0507 * positive
            + 1.7581 * (std::exp(negative) - float(1));
    }
}
// CPU/float SELU backward, expressed through the forward output y:
// dx = 1.0507 * dy when y > 0, otherwise dx = (1.7581 + y) * dy.
template<> void SEluGrad<float, CPUContext>(const int count,
                                            const float* dy,
                                            const float* y,
                                            float* dx) {
#ifdef WITH_OMP
    #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
    for (int idx = 0; idx < count; ++idx) {
        if (y[idx] > 0) {
            dx[idx] = 1.0507 * dy[idx];
        } else {
            dx[idx] = (1.7581 + y[idx]) * dy[idx];
        }
    }
}
/******************** activation.sigmoid ********************/
template <typename T>
......@@ -1824,7 +1968,7 @@ template<> void ROIAlign<float, CPUContext>(const float spatial_scale,
const int pool_h, const int pool_w,
Tensor* x,
Tensor* roi,
Tensor* mask_h, Tensor* mask_w,
Tensor* mask,
Tensor* y) {
NOT_IMPLEMENTED;
}
......@@ -1833,7 +1977,7 @@ template<> void ROIAlignGrad<float, CPUContext>(const float spatial_scale,
const int pool_h, const int pool_w,
Tensor* dy,
Tensor* roi,
Tensor* mask_h, Tensor* mask_w,
Tensor* mask,
Tensor* dx) {
NOT_IMPLEMENTED;
}
......
......@@ -85,6 +85,219 @@ template<> void DropoutGrad<float, CUDAContext>(const int count,
CUDA_POST_KERNEL_CHECK;
}
/******************** activation.prelu ********************/
// CUDA PReLU forward with a single shared slope (w[0]).
// "channels" and "dim" are accepted for signature parity with the
// per-channel kernels but are not used here.
template <typename T>
__global__ void _PRelu(const int count,
                       const int channels,
                       const int dim,
                       const T* x,
                       const T* w,
                       T* y) {
    CUDA_KERNEL_LOOP(i, count) {
        const T value = x[i];
        y[i] = (value > 0) * value + (value < 0) * value * w[0];
    }
}
// CUDA PReLU forward with one slope per channel, NCHW layout: the
// channel of element i is (i / dim) % channels.
template <typename T>
__global__ void _PReluNCHW(const int count,
                           const int channels,
                           const int dim,
                           const T* x,
                           const T* w,
                           T* y) {
    CUDA_KERNEL_LOOP(i, count) {
        const int channel = (i / dim) % channels;
        const T value = x[i];
        y[i] = (value > 0) * value + (value < 0) * value * w[channel];
    }
}
// PReLU forward kernel with one slope per channel, where the channel of
// a flat index is idx % channels (channel is the fastest-varying index).
// `dim` is unused; it keeps the signature parallel to the NCHW kernel.
template <typename T>
__global__ void _PReluNHWC(const int count,
                           const int channels,
                           const int dim,
                           const T* x,
                           const T* w,
                           T* y) {
    CUDA_KERNEL_LOOP(idx, count) {
        const int c = idx % channels;
        const T val = x[idx];
        const T slope = w[c];
        // Mask-multiply selection; infinite inputs yield NaN (0 * inf).
        y[idx] = (val > 0) * val + (val < 0) * val * slope;
    }
}
// Launches the PReLU forward kernel matching the requested configuration:
//   - channel_shared            -> _PRelu      (single slope w[0])
//   - data_format == "NCHW"     -> _PReluNCHW  (slope per channel)
//   - data_format == "NHWC"     -> _PReluNHWC  (slope per channel)
// Any other data_format is reported through LOG(FATAL).
// x, w and y are passed straight to the kernels, so they are expected to
// be device-accessible pointers.
template<> void PRelu<float, CUDAContext>(const int count,
                                          const int channels,
                                          const int dim,
                                          const bool channel_shared,
                                          const string& data_format,
                                          const float* x,
                                          const float* w,
                                          float* y) {
    if (channel_shared) {
        _PRelu<float>
            <<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
                count, channels, dim, x, w, y);
    } else if (data_format == "NCHW") {
        _PReluNCHW<float>
            <<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
                count, channels, dim, x, w, y);
    } else if (data_format == "NHWC") {
        _PReluNHWC<float>
            <<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
                count, channels, dim, x, w, y);
    } else {
        LOG(FATAL) << "Unknown data format: " << data_format;
    }
    CUDA_POST_KERNEL_CHECK;
}
// PReLU input-gradient kernel using a single slope w[0]:
//     dx = dy        for x > 0
//     dx = dy * w[0] for x <= 0
// `channels` and `dim` are unused; they keep the signature parallel to
// the per-channel kernels.
template <typename T>
__global__ void _PReluGrad(const int count,
                           const int channels,
                           const int dim,
                           const T* dy,
                           const T* x,
                           const T* w,
                           T* dx) {
    CUDA_KERNEL_LOOP(idx, count) {
        const T val = x[idx];
        // Local derivative built from 0/1 comparison results (no branch).
        const T local_grad = (val > 0) + (val <= 0) * w[0];
        dx[idx] = dy[idx] * local_grad;
    }
}
// PReLU input-gradient kernel with one slope per channel, where the
// channel of a flat index is (idx / dim) % channels:
//     dx = dy        for x > 0
//     dx = dy * w[c] for x <= 0
template <typename T>
__global__ void _PReluGradNCHW(const int count,
                               const int channels,
                               const int dim,
                               const T* dy,
                               const T* x,
                               const T* w,
                               T* dx) {
    CUDA_KERNEL_LOOP(idx, count) {
        const int c = (idx / dim) % channels;
        const T val = x[idx];
        // Local derivative built from 0/1 comparison results (no branch).
        const T local_grad = (val > 0) + (val <= 0) * w[c];
        dx[idx] = dy[idx] * local_grad;
    }
}
// PReLU input-gradient kernel with one slope per channel, where the
// channel of a flat index is idx % channels:
//     dx = dy        for x > 0
//     dx = dy * w[c] for x <= 0
// `dim` is unused; it keeps the signature parallel to the NCHW kernel.
template <typename T>
__global__ void _PReluGradNHWC(const int count,
                               const int channels,
                               const int dim,
                               const T* dy,
                               const T* x,
                               const T* w,
                               T* dx) {
    CUDA_KERNEL_LOOP(idx, count) {
        const int c = idx % channels;
        const T val = x[idx];
        // Local derivative built from 0/1 comparison results (no branch).
        const T local_grad = (val > 0) + (val <= 0) * w[c];
        dx[idx] = dy[idx] * local_grad;
    }
}
// Launches the PReLU input-gradient kernel matching the configuration:
//   - channel_shared            -> _PReluGrad      (single slope w[0])
//   - data_format == "NCHW"     -> _PReluGradNCHW  (slope per channel)
//   - data_format == "NHWC"     -> _PReluGradNHWC  (slope per channel)
// Any other data_format is reported through LOG(FATAL).
// dy, x, w and dx are passed straight to the kernels, so they are
// expected to be device-accessible pointers.
template<> void PReluGrad<float, CUDAContext>(const int count,
                                              const int channels,
                                              const int dim,
                                              const bool channel_shared,
                                              const string& data_format,
                                              const float* dy,
                                              const float* x,
                                              const float* w,
                                              float* dx) {
    if (channel_shared) {
        _PReluGrad<float>
            <<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
                count, channels, dim, dy, x, w, dx);
    } else if (data_format == "NCHW") {
        _PReluGradNCHW<float>
            <<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
                count, channels, dim, dy, x, w, dx);
    } else if (data_format == "NHWC") {
        _PReluGradNHWC<float>
            <<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
                count, channels, dim, dy, x, w, dx);
    } else {
        LOG(FATAL) << "Unknown data format: " << data_format;
    }
    CUDA_POST_KERNEL_CHECK;
}
// First stage of the PReLU slope gradient: reduces dy * x over `rows`
// slices for the non-positive inputs only.
//
// For each idx in [0, count):
//     bcast_dw[idx] = sum_{n=0}^{rows-1} dy[idx + n*row_offset]
//                                        * x[idx + n*row_offset]
//                                        * (x[idx + n*row_offset] <= 0)
//
// count      : number of output elements in bcast_dw.
// rows       : number of slices summed together.
// row_offset : element distance between the same position in two
//              consecutive slices.
// The slice n = 0 is always read, even when rows <= 0.
template <typename T>
__global__ void _PReluWGradBcast(const int count,
                                 const int rows,
                                 const int row_offset,
                                 const T* dy,
                                 const T* x,
                                 T* bcast_dw) {
    CUDA_KERNEL_LOOP(idx, count) {
        // Accumulate in a local so the running sum is not written to and
        // re-read from bcast_dw on every iteration; the output is stored
        // once at the end. Summation order is unchanged.
        T acc = dy[idx] * x[idx] * (x[idx] <= 0);
        for (int n = 1; n < rows; n++) {
            const int cur_idx = idx + n * row_offset;
            acc += dy[cur_idx] * x[cur_idx] * (x[cur_idx] <= 0);
        }
        bcast_dw[idx] = acc;
    }
}
// Computes the PReLU slope gradient in two stages:
//   1. _PReluWGradBcast reduces dy * x * (x <= 0) over `rows` slices into
//      bcast_dw, which holds channels * dim values.
//   2. bcast_dw is reduced further against `multiplier` into dw:
//        - channel_shared : math::Dot over all channels * dim values,
//                           then math::AddScalar on dw.
//        - "NCHW"         : math::Gemv, CblasNoTrans, channels x dim.
//        - "NHWC"         : math::Gemv, CblasTrans, dim x channels.
// Any other data_format is reported through LOG(FATAL) (after stage 1
// has already been launched).
// NOTE(review): the trailing 1.0 passed to Gemv presumably is beta, so the
// result is accumulated into dw rather than overwriting it, and
// `multiplier` presumably is a vector of ones - confirm against
// math::Gemv and the caller.
template<> void PReluWGrad<float, CUDAContext>(const int rows,
                                               const int row_offset,
                                               const int channels,
                                               const int dim,
                                               const bool channel_shared,
                                               const string& data_format,
                                               const float* dy,
                                               const float* x,
                                               const float* multiplier,
                                               float* bcast_dw,
                                               float* dw) {
    // Stage 1: per-position reduction over rows.
    const int cdim = channels * dim;
    _PReluWGradBcast<float>
        <<<GET_BLOCKS(cdim), CUDA_NUM_THREADS>>>(
            cdim, rows, row_offset, dy, x, bcast_dw);
    CUDA_POST_KERNEL_CHECK;

    // Stage 2: collapse bcast_dw into the slope gradient.
    if (channel_shared) {
        const float w_sum = math::Dot<float, CUDAContext>(
            cdim, bcast_dw, multiplier);
        math::AddScalar<float, CUDAContext>(1, w_sum, dw);
    } else if (data_format == "NCHW") {
        math::Gemv<float, CUDAContext>(
            CblasNoTrans, channels, dim,
            1.0, bcast_dw, multiplier,
            1.0, dw);
    } else if (data_format == "NHWC") {
        math::Gemv<float, CUDAContext>(
            CblasTrans, dim, channels,
            1.0, bcast_dw, multiplier,
            1.0, dw);
    } else {
        LOG(FATAL) << "Unknown data format: " << data_format;
    }
}
/******************** activation.elu ********************/
template <typename T>
......@@ -109,7 +322,7 @@ __global__ void _EluGrad(const int count,
const float alpha,
T* dx) {
CUDA_KERNEL_LOOP(idx, count) {
dx[idx] = y[idx] > 0 ? dy[idx] : dy[idx] * (y[idx] + alpha);
dx[idx] = dy[idx] * ((y[idx] > 0) + (alpha + y[idx]) * (y[idx] <= 0));
}
}
......@@ -191,6 +404,43 @@ template<> void ReluGrad<float, CUDAContext>(const int count,
CUDA_POST_KERNEL_CHECK;
}
/******************** activation.selu ********************/
// SELU forward kernel:
//     y = 1.0507 * x              for x > 0
//     y = 1.7581 * (exp(x) - 1)   otherwise
// The literals are double, so each product is evaluated in double and
// converted to T on the store.
template <typename T>
__global__ void _SElu(const int count, const T* x, T* y) {
    CUDA_KERNEL_LOOP(idx, count) {
        const T val = x[idx];
        if (val > 0) {
            y[idx] = 1.0507 * val;
        } else {
            y[idx] = 1.7581 * (std::exp(val) - 1);
        }
    }
}
// Launches the SELU forward kernel over `count` float elements and then
// runs CUDA_POST_KERNEL_CHECK. x and y are handed directly to the kernel.
template<> void SElu<float, CUDAContext>(const int count,
                                         const float* x,
                                         float* y) {
    _SElu<float>
        <<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
            count, x, y);
    CUDA_POST_KERNEL_CHECK;
}
// SELU backward kernel, expressed in terms of the forward output y:
//     dx = 1.0507 * dy          for y > 0
//     dx = (1.7581 + y) * dy    otherwise
// The literals are double, so each product is evaluated in double and
// converted to T on the store.
template <typename T>
__global__ void _SEluGrad(const int count,
                          const T* dy,
                          const T* y,
                          T* dx) {
    CUDA_KERNEL_LOOP(idx, count) {
        const T out = y[idx];
        if (out > 0) {
            dx[idx] = 1.0507 * dy[idx];
        } else {
            dx[idx] = (1.7581 + out) * dy[idx];
        }
    }
}
// Launches the SELU backward kernel over `count` float elements and then
// runs CUDA_POST_KERNEL_CHECK. dy, y and dx are handed directly to the
// kernel.
template<> void SEluGrad<float, CUDAContext>(const int count,
                                             const float* dy,
                                             const float* y,
                                             float* dx) {
    _SEluGrad<float>
        <<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
            count, dy, y, dx);
    CUDA_POST_KERNEL_CHECK;
}
/******************** activation.sigmoid ********************/
template <typename T>
......@@ -3154,8 +3404,7 @@ __global__ void _ROIAlign(const int count,
const int pool_h, const int pool_w,
const T* x,
const T* roi,
T* mask_h,
T* mask_w,
T* mask,
T* y) {
CUDA_KERNEL_LOOP(idx, count) {
int pw = idx % pool_w;
......@@ -3164,76 +3413,60 @@ __global__ void _ROIAlign(const int count,
int n = idx / pool_w / pool_h / channels;
roi += n * 5;
int im_idx = roi[0];
T x1 = roi[1] * spatial_scale;
T y1 = roi[2] * spatial_scale;
T x2 = roi[3] * spatial_scale;
T y2 = roi[4] * spatial_scale;
T roi_height = max(y2 - y1, T(1));
T roi_width = max(x2 - x1, T(1));
const T bin_size_h = roi_height / pool_h;
const T bin_size_w = roi_width / pool_w;
T start_h = bin_size_h * ph;
T start_w = bin_size_w * pw;
T end_h = bin_size_h * (ph + 1);
T end_w = bin_size_w * (pw + 1);
start_h = max(start_h + y1, T(0));
start_w = max(start_w + x1, T(0));
end_h = max(end_h + y1, T(0));
end_w = max(end_w + x1, T(0));
start_h = min(start_h, T(height));
start_w = min(start_w, T(width));
end_h = min(end_h, T(height));
end_w = min(end_w, T(width));
bool is_empty = (end_h <= start_h) || (end_w <= start_w);
T max_val = is_empty ? 0 : -FLT_MAX;
T max_h = -1, max_w = -1;
x += ((im_idx * channels + c) * height * width);
for (T h = start_h; h < end_h; ++h) {
for (T w = start_w; w < end_w; ++w) {
if (int(ceil(h)) == height) h = height - 1;
if (int(ceil(w)) == width) w = width - 1;
int h1 = h, h2 = int(ceil(h));
int w1 = int(w), w2 = int(ceil(w));
T q11 = x[h1 * width + w1];
T q21 = x[h2 * width + w1];
T q12 = x[h1 * width + w2];
T q22 = x[h2 * width + w2];
T val;
if (h1 == h2) {
if (w1 == w2) val = q11;
else val = q11 * (w2 - w) + q12 * (w - w1);
} else if (w1 == w2) {
val = q11 * (h2 - h) + q21 * (h - h1);
} else {
val = q11 * (h2 - h) * (w2 - w) +
q12 * (h2 - h) * (w - w1) +
q21 * (h - h1) * (w2 - w) +
q22 * (h - h1) * (w - w1);
}
if (val > max_val) {
max_val = val;
max_h = h;
max_w = w;
}
} //end w
} // end h
y[idx] = max_val;
mask_h[idx] = max_h;
mask_w[idx] = max_w;
int roi_batch_ind = roi[0];
T roi_start_w = (roi[1]) * spatial_scale;
T roi_start_h = (roi[2]) * spatial_scale;
T roi_end_w = (roi[3]) * spatial_scale;
T roi_end_h = (roi[4]) * spatial_scale;
T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(1));
T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pool_h);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pool_w);
T hstart = static_cast<T>((ph)* bin_size_h);
T wstart = static_cast<T>((pw)* bin_size_w);
T hend = static_cast<T>((ph + 1) * bin_size_h);
T wend = static_cast<T>((pw + 1) * bin_size_w);
hstart = min(max(hstart + roi_start_h, static_cast<T>(0)), static_cast<T>(height));
hend = min(max(hend + roi_start_h, static_cast<T>(0)), static_cast<T>(height));
wstart = min(max(wstart + roi_start_w, static_cast<T>(0)), static_cast<T>(width));
wend = min(max(wend + roi_start_w, static_cast<T>(0)), static_cast<T>(width));
bool is_empty = (hend <= hstart) || (wend <= wstart);
T maxval = is_empty ? 0 : -FLT_MAX;
int maxidx = -1;
int x_idx = 0;
x += (roi_batch_ind * channels + c) * height * width;
T h_stride = (hend - hstart) / 3.0;
T w_stride = (wend - wstart) / 3.0;
for (T h = hstart + h_stride; h <= hend - h_stride + 0.01; h += max(h_stride, 0.01)) {
for (T w = wstart + w_stride; w <= wend - w_stride + 0.01; w += max(w_stride, 0.01)) {
x_idx++;
int hlow = min(max(static_cast<int>(floor(h)), 0), height - 1);
int hhigh = hlow + 1;
int wleft = min(max(static_cast<int>(floor(w)), 0), width - 1);
int wright = wleft + 1;
int topleft = hlow * width + wleft;
int topright = hlow * width + wright;
int bottomleft = hhigh * width + wleft;
int bottomright = hhigh * width + wright;
T alpha = (hlow == hhigh) ? static_cast<T>(0.5) : (h - hlow) / (hhigh - hlow);
T beta = (wleft == wright) ? static_cast<T>(0.5) : (w - wleft) / (wright - wleft);
T value = (1 - alpha) * (1 - beta) * x[topleft] + alpha * (1 - beta) * x[bottomleft]
+ (1 - alpha) * beta * x[topright] + alpha * beta * x[bottomright];
if (value > maxval) {
maxval = value;
maxidx = x_idx;
}
}
}
y[idx] = maxval;
mask[idx] = maxidx;
}
}
......@@ -3241,13 +3474,12 @@ template<> void ROIAlign<float, CUDAContext>(const float spatial_scale,
const int pool_h, const int pool_w,
Tensor* x,
Tensor* roi,
Tensor* mask_h, Tensor* mask_w,
Tensor* mask,
Tensor* y) {
auto* Xdata = x->data<float, CUDAContext>();
auto* Rdata = roi->data<float, CUDAContext>();
auto* Ydata = y->mutable_data<float, CUDAContext>();
auto* MHdata = mask_h->mutable_data<float, CUDAContext>();
auto* MWdata = mask_w->mutable_data<float, CUDAContext>();
auto* Mdata = mask->mutable_data<float, CUDAContext>();
TIndex channels = x->dim(1), count = y->count();
TIndex height = x->dim(2), width = x->dim(3);
_ROIAlign<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
......@@ -3257,7 +3489,7 @@ template<> void ROIAlign<float, CUDAContext>(const float spatial_scale,
pool_h, pool_w,
Xdata,
Rdata,
MHdata, MWdata,
Mdata,
Ydata);
CUDA_POST_KERNEL_CHECK;
}
......@@ -3271,70 +3503,84 @@ __global__ void _ROIAlignGrad(const int count,
const int pool_h, const int pool_w,
const T* dy,
const T* roi,
const T* mask_h, const T* mask_w,
const T* mask,
T* dx) {
CUDA_KERNEL_LOOP(idx, count) {
int w = idx % width;
int h = (idx / width) % height;
int c = (idx / width / height) % channels;
int im_idx = idx / width / height / channels;
T diff = 0;
for (int n = 0; n < num_rois; n++) {
const T* cur_roi = roi + n * 5;
const int im_idx_spec = cur_roi[0];
// ignore wrong im_batch_idx
if (im_idx != im_idx_spec) continue;
T x1 = cur_roi[1] * spatial_scale;
T y1 = cur_roi[2] * spatial_scale;
T x2 = cur_roi[3] * spatial_scale;
T y2 = cur_roi[4] * spatial_scale;
const bool is_in = (w + 1 > x1 && w < x2 + 1 && h + 1 > y1 && h < y2 + 1);
if (!is_in) continue;
int n = idx / width / height / channels;
T roi_height = max(y2 - y1, T(1));
T roi_width = max(x2 - x1, T(1));
const T bin_size_h = roi_height / pool_h;
const T bin_size_w = roi_width / pool_w;
int start_ph = ceil((h - 1 - y1) / bin_size_h - 1);
int end_ph = ceil((h + 1 - y1) / bin_size_h);
int start_pw = ceil((w - 1 - x1) / bin_size_w - 1);
int end_pw = ceil((w + 1 - x1) / bin_size_w);
start_ph = min(max(start_ph, 0), pool_h);
start_pw = min(max(start_pw, 0), pool_w);
end_ph = min(max(end_ph, 0), pool_h);
end_pw = min(max(end_pw, 0), pool_w);
int y_offset = (n * channels + c) * pool_h * pool_w;
const T* dy_off = dy + y_offset;
const T* mask_h_off = mask_h + y_offset;
const T* mask_w_off = mask_w + y_offset;
for (int ph = start_ph; ph < end_ph; ++ph) {
for (int pw = start_pw; pw < end_pw; ++pw) {
T mh = mask_h_off[ph * pool_w + pw];
T mw = mask_w_off[ph * pool_w + pw];
int h1 = int(mh), h2 = int(ceil(mh));
int w1 = int(mw), w2 = int(ceil(mw));
if (h1 <= h && h <= h2 && w1 <= w && w <= w2) {
T gradient_factor = 1.0;
if (h == h1) gradient_factor *= h2 - mh;
else gradient_factor *= mh - h1;
if (w == w1) gradient_factor *= w2 - mw;
else gradient_factor *= mw - w1;
diff += dy_off[ph * pool_w + pw] * gradient_factor;
T gradient = 0;
for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
const T* offset_roi = roi + roi_n * 5;
int roi_batch_ind = offset_roi[0];
if (n != roi_batch_ind) continue;
T roi_start_w = (offset_roi[1]) * spatial_scale;
T roi_start_h = (offset_roi[2]) * spatial_scale;
T roi_end_w = (offset_roi[3]) * spatial_scale;
T roi_end_h = (offset_roi[4]) * spatial_scale;
const bool in_roi = (w > roi_start_w - 1.0 &&
w < roi_end_w + 1.0 &&
h > roi_start_h - 1.0
&& h < roi_end_h + 1.0);
if (!in_roi) continue;
int offset = (roi_n * channels + c) * pool_h * pool_w;
const T* offset_dy = dy + offset;
const T* offset_mask = mask + offset;
T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(1));
T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pool_h);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pool_w);
for (int ph = 0; ph < pool_h; ++ph) {
for (int pw = 0; pw < pool_w; ++pw) {
T hstart = static_cast<T>((ph)* bin_size_h);
T wstart = static_cast<T>((pw)* bin_size_w);
T hend = static_cast<T>((ph + 1) * bin_size_h);
T wend = static_cast<T>((pw + 1) * bin_size_w);
hstart = min(max(hstart + roi_start_h, static_cast<T>(0)), static_cast<T>(height));
hend = min(max(hend + roi_start_h, static_cast<T>(0)), static_cast<T>(height));
wstart = min(max(wstart + roi_start_w, static_cast<T>(0)), static_cast<T>(width));
wend = min(max(wend + roi_start_w, static_cast<T>(0)), static_cast<T>(width));
bool in_bin = (w > wstart - 1.0 &&
w < wend + 1.0 &&
h > hstart - 1.0
&& h < hend + 1.0);
if (!in_bin) continue;
const int pool_idx = ph * pool_w + pw;
int x_idx = 0;
T h_stride = (hend - hstart) / 3.0;
T w_stride = (wend - wstart) / 3.0;
for (T rh = hstart + h_stride; rh <= hend - h_stride + 0.01; rh += max(h_stride, 0.01)) {
for (T rw = wstart + w_stride; rw <= wend - w_stride + 0.01; rw += max(w_stride, 0.01)) {
x_idx++;
if (offset_mask[pool_idx] != x_idx) continue;
int hlow = min(max(static_cast<int>(floor(rh)), 0), height - 1);
int hhigh = hlow + 1;
int wleft = min(max(static_cast<int>(floor(rw)), 0), width - 1);
int wright = wleft + 1;
if (h != hlow && h != hhigh && w != wleft && w != wright) continue;
T alpha = (hlow == hhigh) ? static_cast<T>(0.5) : (rh - hlow) / (hhigh - hlow);
T beta = (wleft == wright) ? static_cast<T>(0.5) : (rw - wleft) / (wright - wleft);
if (h == hlow && w == wleft) gradient += offset_dy[pool_idx] * (1 - alpha) * (1 - beta);
else if (h == hlow && w == wright) gradient += offset_dy[pool_idx] * (1 - alpha) * beta;
else if (h == hhigh && w == wleft) gradient += offset_dy[pool_idx] * alpha * (1 - beta);
else if (h == hhigh && w == wright) gradient += offset_dy[pool_idx] * alpha * beta;
}
}
} // end pw
} // end ph
} // end n
dx[idx] = diff;
}
}
}
dx[idx] = gradient;
}
}
......@@ -3342,12 +3588,11 @@ template<> void ROIAlignGrad<float, CUDAContext>(const float spatial_scale,
const int pool_h, const int pool_w,
Tensor* dy,
Tensor* roi,
Tensor* mask_h, Tensor* mask_w,
Tensor* mask,
Tensor* dx) {
auto* dYdata = dy->data<float, CUDAContext>();
auto* Rdata = roi->data<float, CUDAContext>();
auto* MHdata = mask_h->data<float, CUDAContext>();
auto* MWdata = mask_w->data<float, CUDAContext>();
auto* Mdata = mask->data<float, CUDAContext>();
auto* dXdata = dx->mutable_data<float, CUDAContext>();
TIndex channels = dx->dim(1), count = dx->count();
TIndex height = dx->dim(2), width = dx->dim(3);
......@@ -3359,7 +3604,7 @@ template<> void ROIAlignGrad<float, CUDAContext>(const float spatial_scale,
pool_h, pool_w,
dYdata,
Rdata,
MHdata, MWdata,
Mdata,
dXdata);
CUDA_POST_KERNEL_CHECK;
}
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!