Commit 771e3d5a by Ting PAN

Add SELU & PReLU support

1 parent 4bef6a6b
@@ -17,6 +17,7 @@
 #include <map>
 #include <unordered_map>
 #include <unordered_set>
+#include <algorithm>
 #include <mutex>
 #include "core/types.h"
...
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------

#ifndef DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_

#include "core/operator.h"

namespace dragon {

template <class Context>
class PReluOp : public Operator<Context> {
 public:
    PReluOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
          data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}

    void RunOnDevice() override;
    template <typename T> void RunWithType();

 protected:
    bool channel_shared;
    string data_format;
    TIndex channels, dim;
};

template <class Context>
class PReluGradientOp : public Operator<Context> {
 public:
    PReluGradientOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          channel_shared(OperatorBase::GetSingleArg<bool>("channel_shared", false)),
          data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}

    void RunOnDevice() override;
    template <typename T> void RunWithType();

 protected:
    bool channel_shared;
    string data_format;
    TIndex channels, dim;
    Tensor* bcast_dw, *multiplier;
};

} // namespace dragon

#endif // DRAGON_OPERATORS_ACTIVATION_PRELU_OP_H_
\ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------

#ifndef DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_

#include "core/operator.h"

namespace dragon {

template <class Context>
class SEluOp : public Operator<Context> {
 public:
    SEluOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws) {}

    void RunOnDevice() override;
    template <typename T> void RunWithType();
};

template <class Context>
class SEluGradientOp : public Operator<Context> {
 public:
    SEluGradientOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws) {
        DISABLE_SHARE_GRADIENT;
    }

    void RunOnDevice() override;
    template <typename T> void RunWithType();
};

} // namespace dragon

#endif // DRAGON_OPERATORS_ACTIVATION_SELU_OP_H_
\ No newline at end of file
@@ -29,7 +29,7 @@ class ROIAlignOp : public Operator<Context> {
  protected:
     int pool_h, pool_w;
     float spatial_scale;
-    Tensor* mask_h, *mask_w;
+    Tensor* mask;
 };

 template <class Context>
@@ -51,7 +51,7 @@ class ROIAlignGradientOp : public Operator<Context> {
  protected:
     int pool_h, pool_w;
     float spatial_scale;
-    Tensor* mask_h, *mask_w;
+    Tensor* mask;
 };

 } // namespace dragon
...
@@ -49,6 +49,42 @@ void EluGrad(const int count,
              const float alpha,
              T* dx);

+/******************** activation.prelu ********************/
+
+template <typename T, class Context>
+void PRelu(const int count,
+           const int channels,
+           const int dim,
+           const bool channel_shared,
+           const string& data_format,
+           const T* x,
+           const T* w,
+           T* y);
+
+template <typename T, class Context>
+void PReluGrad(const int count,
+               const int channels,
+               const int dim,
+               const bool channel_shared,
+               const string& data_format,
+               const T* dy,
+               const T* x,
+               const T* w,
+               T* dx);
+
+template <typename T, class Context>
+void PReluWGrad(const int rows,
+                const int row_offset,
+                const int channels,
+                const int dim,
+                const bool channel_shared,
+                const string& data_format,
+                const T* dy,
+                const T* x,
+                const T* multiplier,
+                T* bcast_dw,
+                T* dw);
+
 /******************** activation.relu ********************/

 template <typename T, class Context>
@@ -61,6 +97,14 @@ void ReluGrad(const int count,
               const float slope,
               T* dx);

+/******************** activation.selu ********************/
+
+template <typename T, class Context>
+void SElu(const int count, const T* x, T* y);
+
+template <typename T, class Context>
+void SEluGrad(const int count, const T* dy, const T* y, T* dx);
+
 /******************** activation.sigmoid ********************/

 template <typename T, class Context>
@@ -745,8 +789,7 @@ void ROIAlign(const float spatial_scale,
               const int pool_w,
               Tensor* x,
               Tensor* roi,
-              Tensor* mask_h,
-              Tensor* mask_y,
+              Tensor* mask,
               Tensor* y);

 template <typename T, class Context>
@@ -755,8 +798,7 @@ void ROIAlignGrad(const float spatial_scale,
                   const int pool_w,
                   Tensor* dy,
                   Tensor* roi,
-                  Tensor* mask_h,
-                  Tensor* mask_y,
+                  Tensor* mask,
                   Tensor* dx);

 } // namespace kernel
...
@@ -242,4 +242,25 @@ def SetLoggingLevel(level):
         'WARNING': logging.WARNING,
         'ERROR': logging.ERROR,
         'FATAL': logging.CRITICAL
-    }[level])
\ No newline at end of file
+    }[level])
+
+
+def SetLoggingFile(log_file):
+    """Redirect the logging to a specific file.
+
+    Parameters
+    ----------
+    log_file : str
+        The file for logging.
+
+    Notes
+    -----
+    This function disables all logging at the terminal.
+
+    """
+    global logger
+    new_logger = logging.getLogger('dragon_filehandler')
+    new_logger.setLevel(logger.level)
+    file_handler = logging.FileHandler(log_file, mode="w", encoding="UTF-8")
+    new_logger.addHandler(file_handler)
+    logger = new_logger
\ No newline at end of file
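
A minimal usage sketch of the new function (the log file name is illustrative):

```python
# Sketch: all subsequent Dragon logging goes to dragon.log; logging at the
# terminal is disabled, as the Notes section above states.
import dragon.config as config

config.SetLoggingFile('dragon.log')
```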
@@ -14,7 +14,6 @@ import dragon.protos.dragon_pb2 as pb
 import numpy as np
 import os
 from dragon import *
-from dragon.config import logger
 from google.protobuf.message import Message
 from six.moves import range as xrange
@@ -339,7 +338,7 @@ def LogMetaGraph(meta_graph):
     None

     """
-    from dragon.config import option
+    from dragon.config import option, logger
     if option['log_meta_graph']:
         logger.info(meta_graph)
@@ -358,11 +357,12 @@ def GetOptimizedGraph(meta_graph):
         The definition of optimized graph.

     """
+    from dragon.config import logger
     graph_name = meta_graph.name
     graph_tensor = 'GraphDef_' + graph_name
     if not HasTensorCC(graph_tensor):
-        logger.info('graph: {} does not exist, ignore printing....'.format(graph_name))
+        logger.info('Graph({}) does not exist, ignore printing....'.format(graph_name))
         return
     opt_graph_def = pb.GraphDef()
@@ -383,7 +383,7 @@ def LogOptimizedGraph(meta_graph):
     None

     """
-    from dragon.config import option
+    from dragon.config import option, logger
     if option['log_optimized_graph']:
         optimized_graph = GetOptimizedGraph(meta_graph)
         logger.info(optimized_graph)
@@ -404,7 +404,7 @@ def ExportMetaGraph(meta_graph):
     None

     """
-    from dragon.config import option
+    from dragon.config import option, logger
     if option['export_meta_graph']:
         if not os.path.exists(option['export_meta_graph']):
             try:
@@ -445,6 +445,7 @@ def Snapshot(tensors, filename, prefix='', suffix='.bin', format='default'):
         Available formats: ['default', 'caffe'].

     """
+    from dragon.config import logger
     filepath = prefix + filename + suffix
     if mpi.Is_Init():
         if not mpi.AllowSnapshot(): return
@@ -488,6 +489,7 @@ def Restore(filepath, format='default'):
         Available formats: ['default', 'caffe'].

     """
+    from dragon.config import logger
     assert os.path.exists(filepath), 'model of path({}) does not exist.'.format(filepath)
     if format == 'default':
         content = cPickle.load(open(filepath, 'rb'))
...
@@ -22,6 +22,7 @@ List Brief
 `LogOptimizedGraph`_    Enable to log optimized graph globally.
 `ExportMetaGraph`_      Enable to export all runnable meta graphs into text files.
 `SetLoggingLevel`_      Set the minimum level of Logging.
+`SetLoggingFile`_       Redirect the logging to a specific file.
 ====================    =============================================================================

 API Reference
@@ -40,4 +41,5 @@ API Reference
 .. _LogMetaGraph: #dragon.config.LogMetaGraph
 .. _LogOptimizedGraph: #dragon.config.LogOptimizedGraph
 .. _ExportMetaGraph: #dragon.config.ExportMetaGraph
-.. _SetLoggingLevel: #dragon.config.SetLoggingLevel
\ No newline at end of file
+.. _SetLoggingLevel: #dragon.config.SetLoggingLevel
+.. _SetLoggingFile: #dragon.config.SetLoggingFile
\ No newline at end of file
@@ -12,11 +12,15 @@
 .. |tanh_function| mathmacro:: \, y = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}

-.. |relu_function| mathmacro:: \, y = \max(0, x)
+.. |relu_function| mathmacro:: \, y = \left\{ \begin{array} \\ x & & (x > 0) \\ 0 & & (x <= 0) \\ \end{array} \right.
+
+.. |lrelu_function| mathmacro:: \, y = \left\{ \begin{array} \\ x & & (x > 0) \\ Slope * x & & (x <= 0) \\ \end{array} \right.
+
+.. |prelu_function| mathmacro:: \, y_{i} = \left\{ \begin{array} \\ x_{i} & & (x_{i} > 0) \\ \alpha_{i} * x_{i} & & (x_{i} <= 0) \\ \end{array} \right.

 .. |elu_function| mathmacro:: \, y = \left\{ \begin{array} \\ x & & (x > 0) \\ Alpha * (e^{x} - 1) & & (x <= 0) \\ \end{array} \right.

-.. |leaky_relu_function| mathmacro:: \, y = \max(x, 0) + Slope * \min(x, 0)
+.. |selu_function| mathmacro:: \, y = 1.0507 \left\{ \begin{array} \\ x & & (x > 0) \\ 1.6733 * (e^{x} - 1) & & (x <= 0) \\ \end{array} \right.

 .. |dropout_function| mathmacro:: \, y = x * Bernoulli(p=1 - prob)
...
@@ -61,7 +61,9 @@ List Brief
 `Tanh`_          Tanh function.
 `Relu`_          Rectified Linear Unit function, introduced by `[Nair & Hinton, 2010] <http://www.csri.utoronto.ca/~hinton/absps/reluICML.pdf>`_.
 `LRelu`_         Leaky Rectified Linear Unit function.
+`PRelu`_         Parametric Rectified Linear Unit function, introduced by `[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_.
 `Elu`_           Exponential Linear Unit function, introduced by `[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
+`SElu`_          Scaled Exponential Linear Unit function, introduced by `[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
 `Softmax`_       Softmax function.
 `Dropout`_       Randomly set a unit into zero, introduced by `[Srivastava et.al, 2014] <http://jmlr.org/papers/v15/srivastava14a.html>`_.
 ===============  ======================================================================
@@ -209,7 +211,9 @@
 .. _Tanh: operators/activation.html#dragon.operators.activation.Tanh
 .. _Relu: operators/activation.html#dragon.operators.activation.Relu
 .. _LRelu: operators/activation.html#dragon.operators.activation.LRelu
+.. _PRelu: operators/activation.html#dragon.operators.activation.PRelu
 .. _Elu: operators/activation.html#dragon.operators.activation.Elu
+.. _SElu: operators/activation.html#dragon.operators.activation.SElu
 .. _Softmax: operators/activation.html#dragon.operators.activation.Softmax
 .. _Dropout: operators/activation.html#dragon.operators.activation.Dropout
...
@@ -42,7 +42,9 @@ Neuron
 List                 Brief
 ====================  =============================================================================
 `ReLULayer`_          The implementation of ``ReLULayer``.
+`PReLULayer`_         The implementation of ``PReLULayer``.
 `ELULayer`_           The implementation of ``ELULayer``.
+`SELULayer`_          The implementation of ``SELULayer``.
 `SigmoidLayer`_       The implementation of ``SigmoidLayer``.
 `TanHLayer`_          The implementation of ``TanHLayer``.
 `DropoutLayer`_       The implementation of ``DropoutLayer``.
@@ -154,7 +156,9 @@ API Reference
 .. _BilinearResizeLayer: #dragon.vm.caffe.layers.vision.BilinearResizeLayer
 .. _ReLULayer: #dragon.vm.caffe.layers.neuron.ReLULayer
+.. _PReLULayer: #dragon.vm.caffe.layers.neuron.PReLULayer
 .. _ELULayer: #dragon.vm.caffe.layers.neuron.ELULayer
+.. _SELULayer: #dragon.vm.caffe.layers.neuron.SELULayer
 .. _SigmoidLayer: #dragon.vm.caffe.layers.neuron.SigmoidLayer
 .. _TanHLayer: #dragon.vm.caffe.layers.neuron.TanHLayer
 .. _DropoutLayer: #dragon.vm.caffe.layers.neuron.DropoutLayer
@@ -232,6 +236,8 @@ API Reference
 .. _ResizeParameter.fy: https://github.com/neopenx/Dragon/tree/master/Dragon/python/dragon/vm/caffe/proto/caffe.proto#L1466
 .. _ReLUParameter.negative_slope: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L1000
+.. _PReLUParameter.filler: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L1409
+.. _PReLUParameter.channel_shared: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L1411
 .. _ELUParameter.alpha: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L717
 .. _DropoutParameter.dropout_ratio: https://github.com/BVLC/caffe/blob/effcdb0b62410b2a6a54f18f23cf90733a115673/src/caffe/proto/caffe.proto#L676
 .. _DropoutParameter.scale_train: https://github.com/rbgirshick/caffe-fast-rcnn/blob/0dcd397b29507b8314e252e850518c5695efbb83/src/caffe/proto/caffe.proto#L638
...
@@ -8,8 +8,6 @@ import numpy as np
 from multiprocessing import Process
 from six.moves import range as xrange

-from dragon.config import logger
-
 from .utils import GetProperty

 class BlobFetcher(Process):
@@ -40,6 +38,7 @@ class BlobFetcher(Process):
         self.daemon = True

         def cleanup():
+            from dragon.config import logger
             logger.info('Terminating BlobFetcher......')
             self.terminate()
             self.join()
...
@@ -10,7 +10,6 @@ from multiprocessing import Queue
 from six.moves import range as xrange

 import dragon.core.mpi as mpi
-from dragon.config import logger

 from .data_reader import DataReader
 from .data_transformer import DataTransformer
@@ -171,6 +170,7 @@ class DataBatch(object):
         """
         Print I/O Information.
         """
+        from dragon.config import logger
         logger.info('---------------------------------------------------------')
         logger.info('BatchReader, Using config:')
         params = {'prefetching': self._prefetch,
...
@@ -9,7 +9,6 @@ import numpy.random as npr
 from multiprocessing import Process

 import dragon.config as config
-from dragon.config import logger
 from dragon.tools.db import LMDB
 from .utils import GetProperty
@@ -55,6 +54,7 @@ class DataReader(Process):
         self.daemon = True

         def cleanup():
+            from dragon.config import logger
             logger.info('Terminating DataReader......')
             self.terminate()
             self.join()
...
@@ -9,7 +9,6 @@ import numpy.random as npr
 from multiprocessing import Process

 import dragon.config as config
-from dragon.config import logger
 import dragon.vm.caffe.proto.caffe_pb2 as pb
 from .utils import GetProperty
@@ -72,6 +71,7 @@ class DataTransformer(Process):
         self.daemon = True

         def cleanup():
+            from dragon.config import logger
             logger.info('Terminating DataTransformer......')
             self.terminate()
             self.join()
...
@@ -44,7 +44,7 @@ def LRelu(inputs, slope=0.2, **kwargs):
     Returns
     -------
     Tensor
-        The output tensor, calculated as: |leaky_relu_function|.
+        The output tensor, calculated as: |lrelu_function|.

     """
     CheckInputs(inputs, 1)
@@ -58,6 +58,35 @@ def LRelu(inputs, slope=0.2, **kwargs):
     return output

+
+def PRelu(inputs, channel_shared=False, data_format='NCHW', **kwargs):
+    """Parametric Rectified Linear Unit function, introduced by `[He et.al, 2015] <https://arxiv.org/abs/1502.01852>`_.
+
+    Parameters
+    ----------
+    inputs : list of Tensor
+        The input and the trainable parameter (slope).
+    channel_shared : boolean
+        Whether to share the parameter (slope) across channels.
+    data_format : str
+        The data format, ``NCHW`` or ``NHWC``.
+
+    Returns
+    -------
+    Tensor
+        The output tensor, calculated as: |prelu_function|.
+
+    """
+    CheckInputs(inputs, 2)
+    arguments = ParseArguments(locals())
+
+    output = Tensor.CreateOperator(nout=1, op_type='PRelu', **arguments)
+
+    if inputs[0].shape is not None:
+        output.shape = inputs[0].shape[:]
+
+    return output
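
A minimal usage sketch of the new op (tensor names are illustrative; the `Constant(value=0.25)` default mirrors what `PReLULayer` does below when no filler is given):

```python
import dragon.ops as ops
from dragon.core.tensor import Tensor

x = Tensor('x')          # input, assumed NCHW
slope = Tensor('slope')  # trainable slope, one value per channel
slope.Constant(value=0.25)
y = ops.PRelu([x, slope], channel_shared=False, data_format='NCHW')
```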
 def Elu(inputs, alpha=1.0, **kwargs):
     """Exponential Linear Unit function, introduced by `[Clevert et.al, 2015] <https://arxiv.org/abs/1511.07289>`_.
@@ -65,6 +94,8 @@ def Elu(inputs, alpha=1.0, **kwargs):
     ----------
     inputs : Tensor
         The input tensor.
+    alpha : float
+        The alpha, applied to the negative part: Alpha * (e^{x} - 1).

     Returns
     -------
@@ -83,6 +114,31 @@ def Elu(inputs, alpha=1.0, **kwargs):
     return output

+
+def SElu(inputs, **kwargs):
+    """Scaled Exponential Linear Unit function, introduced by `[Klambauer et.al, 2017] <https://arxiv.org/abs/1706.02515>`_.
+
+    Parameters
+    ----------
+    inputs : Tensor
+        The input tensor.
+
+    Returns
+    -------
+    Tensor
+        The output tensor, calculated as: |selu_function|.
+
+    """
+    CheckInputs(inputs, 1)
+    arguments = ParseArguments(locals())
+
+    output = Tensor.CreateOperator(nout=1, op_type='SElu', **arguments)
+
+    if inputs.shape is not None:
+        output.shape = inputs.shape[:]
+
+    return output
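
A matching sketch for the SElu op, which takes no extra arguments:

```python
import dragon.ops as ops
from dragon.core.tensor import Tensor

# y = 1.0507 * x for x > 0, 1.7581 * (e^x - 1) otherwise (see the CPU kernel below)
y = ops.SElu(Tensor('x'))
```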
 def Sigmoid(inputs, **kwargs):
     """Sigmoid function.
...
@@ -232,7 +232,7 @@ def ROIPooling(inputs, pool_h, pool_w, spatial_scale, **kwargs):
     return Tensor.CreateOperator(nout=1, op_type='ROIPooling', **arguments)

-def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **arguments):
+def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs):
     """Max ROIAlign, introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.

     The first dimension of input must be ``1``.
...
@@ -50,7 +50,9 @@ Sigmoid = act.Sigmoid
 Tanh = act.Tanh
 Relu = act.Relu
 LRelu = act.LRelu
+PRelu = act.PRelu
 Elu = act.Elu
+SElu = act.SElu
 Softmax = act.Softmax
 Dropout = act.Dropout
...
@@ -6,7 +6,6 @@
 import numpy as np
 import pprint

-from dragon.config import logger
 import dragon.core.workspace as ws
 from dragon.core.tensor import Tensor
@@ -85,6 +84,7 @@ class BaseUpdater(object):
         """
         Print Updater Information.
         """
+        from dragon.config import logger
         logger.info('---------------------------------------------------------')
         logger.info('Optimizer: {}, Using config:'.format(self._type.split('Update')[0]))
         pprint.pprint(self._hyper_params)
...
@@ -17,7 +17,9 @@ from .vision import ConvolutionLayer, \
                     BilinearResizeLayer

 from .neuron import ReLULayer, \
+                    PReLULayer, \
                     ELULayer, \
+                    SELULayer, \
                     DropoutLayer, \
                     TanHLayer, \
                     PowerLayer
...
@@ -5,6 +5,7 @@
 # --------------------------------------------------------

 import dragon.ops as ops
+from dragon.core.tensor import Tensor
 from ..layer import Layer
@@ -29,6 +30,35 @@ class ReLULayer(Layer):
         return ops.Relu(input, **self._param)

+
+class PReLULayer(Layer):
+    """The implementation of ``PReLULayer``.
+
+    Parameters
+    ----------
+    filler : FillerParameter
+        The filler of the parameter (slope). Refer `PReLUParameter.filler`_.
+    channel_shared : boolean
+        Whether to share the parameter across channels. Refer `PReLUParameter.channel_shared`_.
+
+    """
+    def __init__(self, LayerParameter):
+        super(PReLULayer, self).__init__(LayerParameter)
+        param = LayerParameter.prelu_param
+        self._param = {'channel_shared': param.channel_shared,
+                       'data_format': 'NCHW'}
+        slope = Tensor(LayerParameter.name + '@param0')
+        slope_diff = Tensor(LayerParameter.name + '@param0_grad')
+        if param.HasField('filler'):
+            self.Fill(slope, param, 'filler')
+        else:
+            slope.Constant(value=0.25)
+        self._blobs.append({'data': slope, 'diff': slope_diff})
+
+    def Setup(self, bottom):
+        super(PReLULayer, self).Setup(bottom)
+        return ops.PRelu(bottom + [blob['data'] for blob in self._blobs], **self._param)
+
 class ELULayer(Layer):
     """The implementation of ``ELULayer``.
@@ -41,7 +71,7 @@ class ELULayer(Layer):
     def __init__(self, LayerParameter):
         super(ELULayer, self).__init__(LayerParameter)
         param = LayerParameter.elu_param
-        self._param = {'alpha': param.alpha}
+        self._param = {'alpha': float(param.alpha)}

     def Setup(self, bottom):
         super(ELULayer, self).Setup(bottom)
@@ -49,6 +79,19 @@ class ELULayer(Layer):
         return ops.Elu(input, **self._param)

+
+class SELULayer(Layer):
+    """
+    The implementation of ``SELULayer``.
+    """
+    def __init__(self, LayerParameter):
+        super(SELULayer, self).__init__(LayerParameter)
+
+    def Setup(self, bottom):
+        super(SELULayer, self).Setup(bottom)
+        input = bottom[0] if isinstance(bottom, list) else bottom
+        return ops.SElu(input, **self._param)
+
 class SigmoidLayer(Layer):
     """
     The implementation of ``SigmoidLayer``.
...
@@ -14,7 +14,6 @@ import dragon.tools.summary_writer as sw
 import dragon.vm.theano as theano
 from dragon.vm.caffe.proto import caffe_pb2 as pb

-from dragon.config import logger
 from dragon.vm.caffe.misc import root_solver
 from dragon.vm.caffe.net import Net
 from google.protobuf.text_format import Parse
@@ -172,6 +171,7 @@ class Solver(object):
         The implementation of `GetLearningRate(solver.cpp, L27)`_.

         """
+        from dragon.config import logger
         policy = self._param.lr_policy

         if policy == "step":
@@ -232,6 +232,7 @@ class Solver(object):
         The implementation of `Test(solver.cpp, L328)`_.

         """
+        from dragon.config import logger
         test_score = []
         output_id = []
         test_iter = self._param.test_iter[test_idx]
@@ -278,6 +279,7 @@ class Solver(object):
         The implementation of `Step(solver.cpp, L180)`_.

         """
+        from dragon.config import logger
         start_iter = self._iter; stop_iter = self._iter + iters
         loss_vec = []; smoothed_loss = 0
         tic = time.time()
...
#include "operators/activation/prelu_op.h"
#include "core/workspace.h"
#include "utils/filler.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename T>
void PReluOp<Context>::RunWithType() {
if (channel_shared) {
TENSOR_FILL(input(1), vector<TIndex>(1, 1));
} else {
TENSOR_FILL(input(1), vector<TIndex>(1, input(0).dim(1)));
}
auto* Xdata = input(0).template data<T, Context>();
auto* Wdata = input(1).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::PRelu<T, Context>(output(0)->count(),
channels,
dim,
channel_shared,
data_format,
Xdata,
Wdata,
Ydata);
}
template <class Context>
void PReluOp<Context>::RunOnDevice() {
    if (data_format == "NCHW") {
        channels = input(0).dim(1);
        dim = input(0).count(2);
    } else {
        channels = input(0).dim(-1);
        dim = input(0).count() / channels;
    }

    output(0)->ReshapeLike(input(0));

    if (input(0).template IsType<float>()) RunWithType<float>();
    else LOG(FATAL) << "Unsupported input types.";
}

DEPLOY_CPU(PRelu);
#ifdef WITH_CUDA
DEPLOY_CUDA(PRelu);
#endif
OPERATOR_SCHEMA(PRelu).NumInputs(2).NumOutputs(1);
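
// Backward pass: PReluGrad produces dX from (dY, X, W), while PReluWGrad
// reduces dY * X over the broadcast axes to produce the slope gradient dW.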
template <class Context> template <typename T>
void PReluGradientOp<Context>::RunWithType() {
    auto* Xdata = input(0).template data<T, Context>();
    auto* dYdata = input(-1).template data<T, Context>();

    if (output(1)->name() != "ignore") {
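        // "multiplier" is a ones vector (INIT_MULTIPLIER) consumed by the
        // Gemv/Dot reduction inside the kernel; "bcast_dw" is workspace
        // scratch holding per-position partial sums before reduction to dW.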
        INIT_MULTIPLIER(multiplier, channels * dim);
        bcast_dw = ws()->GetBuffer();
        bcast_dw->Reshape(vector<TIndex>(1, channels * dim));
        auto* dWdata = output(1)->template mutable_data<T, Context>();
        auto* dWBdata = bcast_dw->template mutable_data<T, Context>();
        kernel::PReluWGrad<T, Context>(input(0).dim(0),
                                       input(0).count(1),
                                       channels,
                                       dim,
                                       channel_shared,
                                       data_format,
                                       dYdata,
                                       Xdata,
                                       multiplier->template data<T, Context>(),
                                       dWBdata,
                                       dWdata);
        ws()->ReleaseBuffer(bcast_dw);
    }

    if (output(0)->name() != "ignore") {
        auto* Wdata = input(1).template data<T, Context>();
        auto* dXdata = output(0)->template mutable_data<T, Context>();
        kernel::PReluGrad<T, Context>(output(0)->count(),
                                      channels,
                                      dim,
                                      channel_shared,
                                      data_format,
                                      dYdata,
                                      Xdata,
                                      Wdata,
                                      dXdata);
    }
}

template <class Context>
void PReluGradientOp<Context>::RunOnDevice() {
    if (data_format == "NCHW") {
        channels = input(0).dim(1);
        dim = input(0).count(2);
    } else {
        channels = input(0).dim(-1);
        dim = input(0).count() / channels;
    }

    output(0)->ReshapeLike(input(0));
    output(1)->ReshapeLike(input(1));

    if (input(0).template IsType<float>()) RunWithType<float>();
    else LOG(FATAL) << "Unsupported input types.";
}

DEPLOY_CPU(PReluGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(PReluGradient);
#endif
OPERATOR_SCHEMA(PReluGradient).NumInputs(3).NumOutputs(2);

class GetPReluGradient final : public GradientMakerBase {
 public:
    GRADIENT_MAKER_CTOR(GetPReluGradient);
    vector<OperatorDef> MakeDefs() override {
        return SingleDef(def.type() + "Gradient", "",
                         vector<string> {I(0), I(1), GO(0)},
                         vector<string> {GI(0), GI(1)});
    }
};
REGISTER_GRADIENT(PRelu, GetPReluGradient);

} // namespace dragon
\ No newline at end of file
#include "operators/activation/selu_op.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename T>
void SEluOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::SElu<T, Context>(output(0)->count(), Xdata, Ydata);
}
template <class Context>
void SEluOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(SElu);
#ifdef WITH_CUDA
DEPLOY_CUDA(SElu);
#endif
OPERATOR_SCHEMA(SElu).NumInputs(1).NumOutputs(1).Inplace({ { 0, 0 } });
template <class Context> template <typename T>
void SEluGradientOp<Context>::RunWithType() {
    auto* Ydata = input(0).template data<T, Context>();
    auto* dYdata = input(1).template data<T, Context>();
    auto* dXdata = output(0)->template mutable_data<T, Context>();
    kernel::SEluGrad<T, Context>(output(0)->count(), dYdata, Ydata, dXdata);
}

template <class Context>
void SEluGradientOp<Context>::RunOnDevice() {
    output(0)->ReshapeLike(input(0));

    if (input(0).template IsType<float>()) RunWithType<float>();
    else LOG(FATAL) << "Unsupported input types.";
}

DEPLOY_CPU(SEluGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(SEluGradient);
#endif
OPERATOR_SCHEMA(SEluGradient).NumInputs(2).NumOutputs(1).Inplace({ { 1, 0 } });
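
// Note: the gradient is written in terms of the op's output Y rather than its
// input X (the maker below feeds O(0), not I(0)); this is what makes the
// in-place schemas above safe: Y may alias X, and dX may reuse dY's buffer.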
class GetSEluGradient final : public GradientMakerBase {
 public:
    GRADIENT_MAKER_CTOR(GetSEluGradient);
    vector<OperatorDef> MakeDefs() override {
        return SingleDef(def.type() + "Gradient", "",
                         vector<string> {O(0), GO(0)},
                         vector<string> {GI(0)});
    }
};
REGISTER_GRADIENT(SElu, GetSEluGradient);

} // namespace dragon
\ No newline at end of file
@@ -33,11 +33,11 @@ void BiasAddOp<Context>::RunOnDevice() {
     if (data_format == "NCHW") {
         if (input(0).template IsType<float>()) NCHWRunWithType<float>();
         else LOG(FATAL) << "Unsupported input types.";
     }
     else if (data_format == "NHWC") {
         if (input(0).template IsType<float>()) NHWCRunWithType<float>();
         else LOG(FATAL) << "Unsupported input types.";
     }
     else {
         LOG(FATAL) << "Unknown data format: " << data_format;
     }
...
@@ -7,11 +7,11 @@ namespace dragon {

 template <class Context> template <typename T>
 void ROIAlignOp<Context>::RunWithType() {
     kernel::ROIAlign<T, Context>(spatial_scale,
                                  pool_h, pool_w,
                                  &input(0),
                                  &input(1),
-                                 mask_h, mask_w,
+                                 mask,
                                  output(0));
 }
@@ -20,10 +20,8 @@ void ROIAlignOp<Context>::RunOnDevice() {
     vector<TIndex> dims({input(1).dim(0), input(0).dim(1), pool_h, pool_w});
     output(0)->Reshape(dims);

-    mask_h = ws()->CreateTensor("_t_" + anchor() + "_roi_align_mask_h");
-    mask_h->Reshape(dims);
-    mask_w = ws()->CreateTensor("_t_" + anchor() + "_roi_align_mask_w");
-    mask_w->Reshape(dims);
+    mask = ws()->CreateTensor("_t_" + anchor() + "_roi_align_mask");
+    mask->Reshape(dims);

     if (input(0).template IsType<float>()) return RunWithType<float>();
     else LOG(FATAL) << "Unsupported input types.";
@@ -41,7 +39,7 @@ void ROIAlignGradientOp<Context>::RunWithType() {
                                      pool_h, pool_w,
                                      &input(-1),
                                      &input(1),
-                                     mask_h, mask_w,
+                                     mask,
                                      output(0));
 }
@@ -49,8 +47,7 @@ template <class Context>
 void ROIAlignGradientOp<Context>::RunOnDevice() {
     output(0)->ReshapeLike(input(0));

-    mask_h = ws()->GetTensor("_t_" + anchor() + "_roi_align_mask_h");
-    mask_w = ws()->GetTensor("_t_" + anchor() + "_roi_align_mask_w");
+    mask = ws()->GetTensor("_t_" + anchor() + "_roi_align_mask");

     if (input(0).template IsType<float>()) return RunWithType<float>();
     else LOG(FATAL) << "Unsupported input types.";
@@ -59,8 +56,7 @@ void ROIAlignGradientOp<Context>::RunOnDevice() {

 template <class Context>
 void ROIAlignGradientOp<Context>::CleanResource() {
     Operator<Context>::CleanResource();
-    ws()->ReleaseBuffer(mask_h, "Common", true);
-    ws()->ReleaseBuffer(mask_w, "Common", true);
+    ws()->ReleaseBuffer(mask, "Common", true);
 }

 DEPLOY_CPU(ROIAlignGradient);
...
@@ -72,6 +72,124 @@ template<> void EluGrad<float, CPUContext>(const int count,
     }
 }

+/******************** activation.prelu ********************/
+
+template<> void PRelu<float, CPUContext>(const int count,
+                                         const int channels,
+                                         const int dim,
+                                         const bool channel_shared,
+                                         const string& data_format,
+                                         const float* x,
+                                         const float* w,
+                                         float* y) {
+    if (channel_shared) {
+#ifdef WITH_OMP
+        #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
+#endif
+        for (int i = 0; i < count; ++i) {
+            y[i] = std::max(x[i], float(0)) + w[0] * std::min(x[i], float(0));
+        }
+    } else {
+        if (data_format == "NCHW") {
+#ifdef WITH_OMP
+            #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
+#endif
+            for (int i = 0; i < count; ++i) {
+                int c = (i / dim) % channels;
+                y[i] = std::max(x[i], float(0)) + w[c] * std::min(x[i], float(0));
+            }
+        } else if (data_format == "NHWC") {
+#ifdef WITH_OMP
+            #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
+#endif
+            for (int i = 0; i < count; ++i) {
+                int c = i % channels;
+                y[i] = std::max(x[i], float(0)) + w[c] * std::min(x[i], float(0));
+            }
+        } else LOG(FATAL) << "Unknown data format: " << data_format;
+    }
+}
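
The forward kernel above is equivalent to the following NumPy sketch for the NCHW branch (shapes and values are illustrative; `prelu_nchw` is a hypothetical helper, not part of Dragon):

```python
import numpy as np

def prelu_nchw(x, w):
    """x: (N, C, H, W); w: (C,) per-channel slopes, or a single shared slope."""
    c = np.asarray(w).reshape(1, -1, 1, 1)    # broadcast slopes over N, H, W
    return np.maximum(x, 0) + c * np.minimum(x, 0)

x = np.random.randn(2, 3, 4, 4).astype('float32')
w = np.full(3, 0.25, dtype='float32')
y = prelu_nchw(x, w)                          # matches the kernel's NCHW loop
```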
+template<> void PReluGrad<float, CPUContext>(const int count,
+                                             const int channels,
+                                             const int dim,
+                                             const bool channel_shared,
+                                             const string& data_format,
+                                             const float* dy,
+                                             const float* x,
+                                             const float* w,
+                                             float* dx) {
+    if (channel_shared) {
+#ifdef WITH_OMP
+        #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
+#endif
+        for (int i = 0; i < count; ++i) {
+            dx[i] = dy[i] * ((x[i] > 0) + w[0] * (x[i] <= 0));
+        }
+    } else {
+        if (data_format == "NCHW") {
+#ifdef WITH_OMP
+            #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
+#endif
+            for (int i = 0; i < count; ++i) {
+                int c = (i / dim) % channels;
+                dx[i] = dy[i] * ((x[i] > 0) + w[c] * (x[i] <= 0));
+            }
+        } else if (data_format == "NHWC") {
+#ifdef WITH_OMP
+            #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
+#endif
+            for (int i = 0; i < count; ++i) {
+                int c = i % channels;
+                dx[i] = dy[i] * ((x[i] > 0) + w[c] * (x[i] <= 0));
+            }
+        } else LOG(FATAL) << "Unknown data format: " << data_format;
+    }
+}
+template<> void PReluWGrad<float, CPUContext>(const int rows,
+                                              const int row_offset,
+                                              const int channels,
+                                              const int dim,
+                                              const bool channel_shared,
+                                              const string& data_format,
+                                              const float* dy,
+                                              const float* x,
+                                              const float* multiplier,
+                                              float* bcast_dw,
+                                              float* dw) {
+    const int cdim = channels * dim;
+#ifdef WITH_OMP
+    #pragma omp parallel for num_threads(GET_OMP_THREADS(cdim))
+#endif
+    for (int i = 0; i < cdim; ++i) {
+        bcast_dw[i] = dy[i] * x[i] * (x[i] <= 0);
+        for (int n = 1; n < rows; n++) {
+            const int cur_idx = i + n * row_offset;
+            bcast_dw[i] += dy[cur_idx] * x[cur_idx] * (x[cur_idx] <= 0);
+        }
+    }
+    if (channel_shared) {
+        float w_sum = math::Dot<float, CPUContext>(channels * dim, bcast_dw, multiplier);
+        math::AddScalar<float, CPUContext>(1, w_sum, dw);
+    } else {
+        if (data_format == "NCHW") {
+            math::Gemv<float, CPUContext>(CblasNoTrans, channels, dim,
+                                          1.0,
+                                          bcast_dw, multiplier,
+                                          1.0,
+                                          dw);
+        } else if (data_format == "NHWC") {
+            math::Gemv<float, CPUContext>(CblasTrans, dim, channels,
+                                          1.0,
+                                          bcast_dw, multiplier,
+                                          1.0,
+                                          dw);
+        } else LOG(FATAL) << "Unknown data format: " << data_format;
+    }
+}
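
For reference, the reduction this kernel implements (per-channel case; with `channel_shared` the per-channel sums are further collapsed to a scalar by the `Dot` call). The ones vector `multiplier` together with `beta = 1.0` in the `Gemv` performs the sum over `dim` while accumulating into `dw`:

```latex
\frac{\partial L}{\partial w_c}
    \mathrel{+}= \sum_{n=1}^{N}\sum_{j=1}^{\text{dim}}
    \frac{\partial L}{\partial y_{n,c,j}} \, x_{n,c,j} \,
    \mathbb{1}\!\left[x_{n,c,j} \le 0\right]
```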
 /******************** activation.relu ********************/

 template<> void Relu<float, CPUContext>(const int count,
@@ -99,6 +217,32 @@ template<> void ReluGrad<float, CPUContext>(const int count,
     }
 }

+/******************** activation.selu ********************/
+
+template<> void SElu<float, CPUContext>(const int count,
+                                        const float* x,
+                                        float* y) {
+#ifdef WITH_OMP
+    #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
+#endif
+    for (int i = 0; i < count; ++i) {
+        y[i] = 1.0507 * std::max(x[i], float(0))
+             + 1.7581 * (std::exp(std::min(x[i], float(0))) - float(1));
+    }
+}
+
+template<> void SEluGrad<float, CPUContext>(const int count,
+                                            const float* dy,
+                                            const float* y,
+                                            float* dx) {
+#ifdef WITH_OMP
+    #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
+#endif
+    for (int i = 0; i < count; ++i) {
+        dx[i] = y[i] > 0 ? 1.0507 * dy[i] : (1.7581 + y[i]) * dy[i];
+    }
+}
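
The hard-coded constants come from the SELU paper: λ ≈ 1.0507 and α ≈ 1.6733, so λ·α ≈ 1.7581. Writing the negative-branch derivative in terms of y is what lets SEluGrad work from the output alone:

```latex
y = \lambda \begin{cases} x & x > 0 \\ \alpha\,(e^{x} - 1) & x \le 0 \end{cases}
\qquad
\frac{dy}{dx} = \begin{cases}
\lambda \approx 1.0507 & x > 0 \\
\lambda\alpha\,e^{x} = y + \lambda\alpha \approx y + 1.7581 & x \le 0
\end{cases}
```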
 /******************** activation.sigmoid ********************/

 template <typename T>
@@ -210,22 +354,22 @@ template<> void TanhGrad<float, CPUContext>(const int count,

 /******************** arithmetic.bias_add ********************/

 template<> void BiasAdd<float, CPUContext>(const int count,
                                            const int outer_dim,
                                            const int dim,
                                            const int inner_dim,
                                            const string& format,
                                            const float* bias,
                                            const float* bias_multiplier,
                                            float* y) {
     if (format == "NCHW") {
         const int y_offset = dim * inner_dim;
         for (int n = 0; n < outer_dim; ++n) {
             math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans,
                                           dim, inner_dim, 1,
                                           1.0,
                                           bias, bias_multiplier,
                                           1.0,
                                           y);
             y += y_offset;
         }
@@ -1824,7 +1968,7 @@ template<> void ROIAlign<float, CPUContext>(const float spatial_scale,
                                             const int pool_h, const int pool_w,
                                             Tensor* x,
                                             Tensor* roi,
-                                            Tensor* mask_h, Tensor* mask_w,
+                                            Tensor* mask,
                                             Tensor* y) {
     NOT_IMPLEMENTED;
 }
@@ -1833,7 +1977,7 @@ template<> void ROIAlignGrad<float, CPUContext>(const float spatial_scale,
                                                 const int pool_h, const int pool_w,
                                                 Tensor* dy,
                                                 Tensor* roi,
-                                                Tensor* mask_h, Tensor* mask_w,
+                                                Tensor* mask,
                                                 Tensor* dx) {
     NOT_IMPLEMENTED;
 }
...