Commit 6683676d by Ting PAN

GroupNormalization Support

1 parent 18b664b1
@@ -16,15 +16,16 @@ class ReshapeOp final : public Operator<Context> {
 public:
    ReshapeOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
-         shape(OperatorBase::GetRepeatedArg<int>("shape")) {
-       new_shape.resize(shape.size());
+         shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) {
+       GET_ARGUMENTS_WITH_DESC(int, shape);
    }
    void RunOnDevice() override;

 protected:
-   vector<int> shape;
-   vector<TIndex> new_shape;
+   DECLARE_ARGUMENTS_WITH_DESC(int, shape);
+   string shape_like_desc;
+   vector<TIndex> require_shape, new_shape;
};

template <class Context>
@@ -38,6 +39,8 @@ class ReshapeGradientOp final : public Operator<Context> {
    void RunOnDevice() override;
};

+DEFINE_ARGUMENTS_WITH_DESC(int, ReshapeOp, shape);

} // namespace dragon

#endif // DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
\ No newline at end of file
@@ -105,7 +105,7 @@ class FusedBatchNormGradientOp : public Operator<Context> {
        : Operator<Context>(op_def, ws),
          axis(OperatorBase::GetSingleArg<int>("axis", -1)),
          eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
-         use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { }
+         use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
    void Setup();
...
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class GroupNormOp : public Operator<Context> {
public:
GroupNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float momentum, eps;
Tensor mean, num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format, mode;
int use_stats;
bool use_global_stats, is_recomputing;
};
template <class Context>
class GroupNormGradientOp final : public Operator<Context> {
public:
GroupNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats;
};
template <class Context>
class FusedGroupNormOp : public Operator<Context> {
public:
FusedGroupNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float momentum, eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats, is_recomputing;
};
template <class Context>
class FusedGroupNormGradientOp : public Operator<Context> {
public:
FusedGroupNormGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void Setup();
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
\ No newline at end of file
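For readers new to the operator, here is a minimal NumPy sketch (not part of this commit) of the statistics the header above bookkeeps; the names N, C, S, group and CGS mirror the member variables, while the function itself is only an illustration, not the op's actual CPU/CUDA kernel:

import numpy as np

def group_norm_reference(x, group=32, eps=1e-3):
    # x is NCHW; channels are split into `group` groups and mean/var are
    # reduced over each group's channels and all spatial positions,
    # i.e. over CGS = (C // group) * S elements for each of the NG pairs.
    N, C, H, W = x.shape
    S = H * W
    assert C % group == 0, "C must be divisible by the group size"
    xg = x.reshape(N, group, (C // group) * S)     # (N, G, CGS)
    mean = xg.mean(axis=2, keepdims=True)          # one mean per (n, g)
    var = xg.var(axis=2, keepdims=True)            # one var per (n, g)
    x_norm = (xg - mean) / np.sqrt(var + eps)
    return x_norm.reshape(N, C, H, W)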
@@ -16,6 +16,14 @@
                      \sigma_{B}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{B})^{2} \\
                      \hat{x}_{i} = \frac{x_{i} - \mu_{B}}{\sqrt{\sigma_{B}^{2} + \epsilon}} \\ y_{i} = \gamma\hat{x}_{i} + \beta \\ \,

+.. |groupnorm_function| mathmacro:: \\ \, \\ \mu_{G} = \frac{1}{m} \sum_{i=1}^{m}x_{i} \\
+                      \sigma_{G}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{G})^{2} \\
+                      \hat{x}_{i} = \frac{x_{i} - \mu_{G}}{\sqrt{\sigma_{G}^{2} + \epsilon}} \\ \,
+
+.. |groupnorm_scale_function| mathmacro:: \\ \, \\ \mu_{G} = \frac{1}{m} \sum_{i=1}^{m}x_{i} \\
+                      \sigma_{G}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{G})^{2} \\
+                      \hat{x}_{i} = \frac{x_{i} - \mu_{G}}{\sqrt{\sigma_{G}^{2} + \epsilon}} \\ y_{i} = \gamma\hat{x}_{i} + \beta \\ \,
+
.. |batchrenorm_function| mathmacro:: \\ \, \\ \mu_{B} = \frac{1}{m} \sum_{i=1}^{m}x_{i} \\
                      \sigma_{B}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{B})^{2} \\
                      \hat{x}_{i} = \frac{x_{i} - \mu_{B}}{\sqrt{\sigma_{B}^{2} + \epsilon}} \cdot r + d \\ \,
...
@@ -113,6 +113,8 @@ List Brief
`BatchNorm`_       Batch Normalization, introduced by `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`BatchRenorm`_     Batch Renormalization, introduced by `[Ioffe, 2017] <https://arxiv.org/abs/1702.03275>`_.
`FusedBatchNorm`_  Batch Normalization, with scale procedure after normalization.
+`GroupNorm`_       Group Normalization, introduced by `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
+`FusedGroupNorm`_  Group Normalization, with scale procedure after normalization.
`InstanceNorm`_    Instance Normalization, introduced by `[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_.
`L2Norm`_          L2 Normalization, introduced by `[Liu et.al, 2015] <https://arxiv.org/abs/1506.04579>`_.
================== ======================================================================
@@ -253,6 +255,8 @@ List Brief
.. _BatchNorm: operators/norm.html#dragon.operators.norm.BatchNorm
.. _BatchRenorm: operators/norm.html#dragon.operators.norm.BatchRenorm
.. _FusedBatchNorm: operators/norm.html#dragon.operators.norm.FusedBatchNorm
+.. _GroupNorm: operators/norm.html#dragon.operators.norm.GroupNorm
+.. _FusedGroupNorm: operators/norm.html#dragon.operators.norm.FusedGroupNorm
.. _InstanceNorm: operators/norm.html#dragon.operators.norm.InstanceNorm
.. _L2Norm: operators/norm.html#dragon.operators.norm.L2Norm
...
@@ -73,9 +73,11 @@ List Brief
`ArgMaxLayer`_          The implementation of ``ArgMaxLayer``.
`BatchNormLayer`_       The implementation of ``BatchNormLayer``.
`BatchRenormLayer`_     The implementation of ``BatchRenormLayer``.
+`GroupNormLayer`_       The implementation of ``GroupNormLayer``.
`InstanceNormLayer`_    The implementation of ``InstanceNormLayer``.
`ScaleLayer`_           The implementation of ``ScaleLayer``.
`BNLayer`_              The implementation of ``BNLayer``.
+`GNLayer`_              The implementation of ``GNLayer``.
`NormalizeLayer`_       The implementation of ``NormalizeLayer``.
`TileLayer`_            The extended implementation of ``TileLayer``.
`ExpandDimsLayer`_      The implementation of ``ExpandDimsLayer``.
@@ -181,9 +183,11 @@ API Reference
.. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer
.. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer
.. _BatchRenormLayer: #dragon.vm.caffe.layers.common.BatchRenormLayer
+.. _GroupNormLayer: #dragon.vm.caffe.layers.common.GroupNormLayer
.. _InstanceNormLayer: #dragon.vm.caffe.layers.common.InstanceNormLayer
.. _ScaleLayer: #dragon.vm.caffe.layers.common.ScaleLayer
.. _BNLayer: #dragon.vm.caffe.layers.common.BNLayer
+.. _GNLayer: #dragon.vm.caffe.layers.common.GNLayer
.. _NormalizeLayer: #dragon.vm.caffe.layers.common.NormalizeLayer
.. _TileLayer: #dragon.vm.caffe.layers.common.TileLayer
.. _ExpandDimsLayer: #dragon.vm.caffe.layers.common.ExpandDimsLayer
...
@@ -648,15 +648,21 @@ def Flatten(inputs, axis=0, num_axes=-1, keep_axes=None, **kwargs):
    return output


-def Reshape(inputs, shape, **kwargs):
+def Reshape(inputs, shape, shape_like=None, **kwargs):
    """Reshape the dimensions of input.

+    ``shape`` could be a list of numbers or Tensors.
+
+    Set ``shape`` to ``None``, if you want to use ``shape_like``.
+
    Parameters
    ----------
    inputs : Tensor
        The input tensor.
-    shape : list or tuple
+    shape : list, tuple or None
        The new shape.
+    shape_like : Tensor or None
+        The tensor for indicating the output shape.

    Returns
    -------
@@ -677,17 +683,29 @@ def Reshape(inputs, shape, **kwargs):
    CheckInputs(inputs, 1)
    arguments = ParseArguments(locals())

-    if not isinstance(shape, tuple) and not isinstance(shape, list):
-        raise TypeError('The type of dims must be a tuple or list.')
+    if shape is not None:
+        AddArgumentsWithDesc(arguments, shape, 'shape', 'int32', as_target=True)
+    elif shape_like is not None:
+        if not isinstance(shape_like, Tensor):
+            raise TypeError('The shape_like should be a Tensor.')
+        arguments['shape_like'] = shape_like.name

    output = Tensor.CreateOperator(nout=1, op_type='Reshape', **arguments)

    if inputs.shape is not None:
-        output.shape = [1] * len(shape)
-        for i, s in enumerate(shape):
-            if s == -1: output.shape[i] = 1
-            elif s == 0: output.shape[i] = inputs.shape[i]
-            else: output.shape[i] = s
+        possible_to_infer_shape = True
+        if shape is not None:
+            for dim in shape:
+                if isinstance(dim, Tensor):
+                    possible_to_infer_shape = False
+        if shape_like is not None:
+            possible_to_infer_shape = False
+        if possible_to_infer_shape:
+            output.shape = [1] * len(shape)
+            for i, s in enumerate(shape):
+                if s == -1: output.shape[i] = 1
+                elif s == 0: output.shape[i] = inputs.shape[i]
+                else: output.shape[i] = s

    return output
...
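A hedged usage sketch of the extended Reshape signature follows; the import paths and the Tensor constructor arguments are assumptions for illustration and are not part of this commit:

import dragon.ops as ops
from dragon.core.tensor import Tensor

x = Tensor('x', shape=[8, 3, 32, 32])

# keep dim 0, infer the remaining one: (8, 3, 32, 32) -> (8, 3072)
y1 = ops.Reshape(x, shape=[0, -1])

# follow another tensor's shape at graph-execution time
ref = Tensor('ref', shape=[8, 3, 1024])
y2 = ops.Reshape(x, shape=None, shape_like=ref)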
@@ -165,6 +165,103 @@ def FusedBatchNorm(inputs, axis=-1, momentum=0.9, eps=1e-3, use_stats=-1, **kwar
    return output


+def GroupNorm(inputs, group=32, axis=-1, momentum=0.9, eps=1e-3,
+              use_stats=-1, mode='DEFAULT', **kwargs):
+    """Group Normalization, introduced by `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
+
+    It follows the implementation of `Caffe`_, that scale procedure is moved to `ops.Scale(*args, **kwargs)`_.
+
+    The number of inputs vary from ``3`` to ``4`` (``DEFAULT`` or ``CAFFE`` mode).
+
+    Parameters
+    ----------
+    inputs : list of Tensor
+        The inputs, represent [input, mean, var] or [input, mean, var, factor].
+    group : int
+        The group size.
+    axis : int
+        The channel axis.
+    momentum : float
+        The momentum of moving average.
+    eps : float
+        The eps.
+    use_stats : int
+        Whether to use global stats. Default is ``-1`` (Auto).
+    mode : str
+        The moving average mode. ``DEFAULT`` or ``CAFFE``.
+
+    Returns
+    -------
+    Tensor
+        The output tensor, calculated as:
+
+        |groupnorm_function|
+
+        The ``DEFAULT`` moving average of mean/var, calculated as:
+
+        |default_moving_average_function|
+
+        The ``CAFFE`` moving average of mean/var, calculated as:
+
+        |caffe_moving_average_function|
+
+    """
+    CheckInputs(inputs, 3, 4)
+    arguments = ParseArguments(locals())
+
+    if len(inputs) > 3:
+        if mode != 'CAFFE':
+            raise ValueError('Only the CAFFE mode will take 4 inputs.')
+
+    output = Tensor.CreateOperator(nout=1, op_type='GroupNorm', **arguments)
+
+    if inputs[0].shape is not None:
+        output.shape = inputs[0].shape[:]
+
+    return output
+
+
+def FusedGroupNorm(inputs, group=32, axis=-1, momentum=0.9, eps=1e-3, use_stats=-1, **kwargs):
+    """Group Normalization, with scale procedure after normalization.
+
+    Parameters
+    ----------
+    inputs : list of Tensor
+        The inputs, represent [input, mean, var, scale, bias].
+    group : int
+        The group size.
+    axis : int
+        The channel axis.
+    momentum : float
+        The momentum of moving average.
+    eps : float
+        The eps.
+    use_stats : int
+        Whether to use global stats. Default is ``-1`` (Auto).
+
+    Returns
+    -------
+    Tensor
+        The output tensor, calculated as:
+
+        |groupnorm_scale_function|
+
+        The moving average of mean/var, calculated as:
+
+        |default_moving_average_function|
+
+    """
+    CheckInputs(inputs, 5)
+    arguments = ParseArguments(locals())
+
+    output = Tensor.CreateOperator(nout=1, op_type='FusedGroupNorm', **arguments)
+
+    if inputs[0].shape is not None:
+        output.shape = inputs[0].shape[:]
+
+    return output
+
+
def InstanceNorm(inputs, axis=-1, eps=1e-3, **kwargs):
    """Instance Normalization, introduced by `[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_
...
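A hedged usage sketch of the two new operators follows; the tensor names and the buffer setup are assumptions for illustration only and are not part of this commit:

import dragon.ops as ops
from dragon.core.tensor import Tensor

x = Tensor('x', shape=[8, 64, 28, 28])
mean = Tensor('gn/mean').Constant(value=0.0)    # history mean buffer
var = Tensor('gn/var').Constant(value=0.0)      # history var buffer

# DEFAULT mode: 3 inputs, scale/shift left to a following ops.Scale
y = ops.GroupNorm([x, mean, var], group=32, axis=1)

# fused variant: normalization plus scale/shift in one op, 5 inputs
scale = Tensor('gn/scale').Constant(value=1.0)
bias = Tensor('gn/bias').Constant(value=0.0)
y_fused = ops.FusedGroupNorm([x, mean, var, scale, bias], group=32, axis=1)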
@@ -92,7 +92,9 @@ GramMatrix = math.GramMatrix
# normalization
BatchNorm = norm.BatchNorm
BatchRenorm = norm.BatchRenorm
+GroupNorm = norm.GroupNorm
FusedBatchNorm = norm.FusedBatchNorm
+FusedGroupNorm = norm.FusedGroupNorm
InstanceNorm = norm.InstanceNorm
L2Norm = norm.L2Norm
...
@@ -39,6 +39,8 @@ from .common import InnerProductLayer, \
                    BatchNormLayer, \
                    BatchRenormLayer,\
                    BNLayer, \
+                   GroupNormLayer, \
+                   GNLayer, \
                    ConcatLayer, \
                    CropLayer, \
                    PythonLayer, \
...
@@ -412,6 +412,47 @@ class BatchRenormLayer(Layer):
        return ops.BatchRenorm(bottom + [blob['data'] for blob in self._blobs], **self._param)


+class GroupNormLayer(Layer):
+    """The implementation of ``GroupNormLayer``.
+
+    Parameters
+    ----------
+    group : int
+        Refer ``GroupNormParameter.group``.
+    use_global_stats : boolean
+        Refer ``GroupNormParameter.use_global_stats``.
+    moving_average_fraction : float
+        Refer ``GroupNormParameter.moving_average_fraction``.
+    eps : float
+        Refer ``GroupNormParameter.eps``.
+
+    """
+    def __init__(self, LayerParameter):
+        super(GroupNormLayer, self).__init__(LayerParameter)
+        param = LayerParameter.group_norm_param
+        self._param = {'group': int(param.group),
+                       'use_stats': int(param.use_global_stats)
+                           if param.HasField('use_global_stats') else -1,
+                       'momentum': param.moving_average_fraction,
+                       'eps': param.eps,
+                       'axis': 1,
+                       'mode': 'CAFFE'}
+        scope = LayerParameter.name
+        # mean, var, factor are set to 0 in order to do statistics
+        mean = Tensor(scope + '/param:0').Constant(value=0.0)
+        var = Tensor(scope + '/param:1').Constant(value=0.0)
+        factor = Tensor(scope + '/param:2').Constant(value=0.0)
+        # in dragon, set diff as None will ignore computing grad automatically
+        # but in bvlc-caffe1, you must set lr_mult = 0 manually
+        self._blobs.append({'data': mean, 'diff': None})
+        self._blobs.append({'data': var, 'diff': None})
+        self._blobs.append({'data': factor, 'diff': None})
+
+    def Setup(self, bottom):
+        super(GroupNormLayer, self).Setup(bottom)
+        return ops.GroupNorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
+
+
class InstanceNormLayer(Layer):
    """
    The implementation of ``InstanceNormLayer``.
@@ -518,6 +559,59 @@ class BNLayer(Layer):
        return ops.FusedBatchNorm(bottom + [blob['data'] for blob in self._blobs], **self._param)


+class GNLayer(Layer):
+    """The implementation of ``GNLayer``.
+
+    Parameters
+    ----------
+    group : int
+        Refer ``GroupNormParameter.group``.
+    use_global_stats : boolean
+        Refer ``GroupNormParameter.use_global_stats``.
+    moving_average_fraction : float
+        Refer ``GroupNormParameter.moving_average_fraction``.
+    eps : float
+        Refer ``GroupNormParameter.eps``.
+    filler : FillerParameter
+        The filler of scale parameter. Refer `ScaleParameter.filler`_.
+    bias_filler : FillerParameter
+        The filler of bias parameter. Refer `ScaleParameter.bias_filler`_.
+
+    """
+    def __init__(self, LayerParameter):
+        super(GNLayer, self).__init__(LayerParameter)
+        gn_param = LayerParameter.group_norm_param
+        scale_param = LayerParameter.scale_param
+        self._param = {'group': int(gn_param.group),
+                       'use_stats': int(gn_param.use_global_stats)
+                           if gn_param.HasField('use_global_stats') else -1,
+                       'momentum': gn_param.moving_average_fraction,
+                       'eps': gn_param.eps,
+                       'axis': 1}
+        scope = LayerParameter.name
+        mean = Tensor(scope + '/param:0').Constant(value=0.0)
+        var = Tensor(scope + '/param:1').Constant(value=0.0)
+        scale = Tensor(scope + '/param:2')
+        scale_diff = Tensor(scope + '/param:2_grad')
+        bias = Tensor(scope + '/param:3')
+        bias_diff = Tensor(scope + '/param:3_grad')
+        if scale_param.HasField('filler'):
+            self.Fill(scale, scale_param, 'filler')
+        else: scale.Constant(value=1.0)
+        self.Fill(bias, scale_param, 'bias_filler')
+        self.norm_blobs = [{'data': mean, 'diff': None},
+                           {'data': var, 'diff': None}]
+        self.scale_blobs = [{'data': scale, 'diff': scale_diff},
+                            {'data': bias, 'diff': bias_diff}]
+        self._blobs.extend(self.norm_blobs)
+        self._blobs.extend(self.scale_blobs)
+
+    def Setup(self, bottom):
+        super(GNLayer, self).Setup(bottom)
+        return ops.FusedGroupNorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
+
+
class NormalizeLayer(Layer):
    """The implementation of ``NormalizeLayer``.
...
@@ -423,6 +423,7 @@ message LayerParameter {
    optional DenseConcatParameter dense_concat_param = 163;
    optional FocalLossParameter focal_loss_param = 164;
    optional GatherParameter gather_param = 165;
+   optional GroupNormParameter group_norm_param = 166;
}

// Message that stores parameters used to apply transformation
@@ -1512,3 +1513,17 @@ message GatherParameter {
    optional int32 axis = 1 [default = 0];
}

+message GroupNormParameter {
+    // If false, accumulate global mean/variance values via a moving average. If
+    // true, use those accumulated values instead of computing mean/variance
+    // across the batch.
+    optional bool use_global_stats = 1;
+    // How much does the moving average decay each iteration?
+    optional float moving_average_fraction = 2 [default = 0.9];
+    // Small value to add to the variance estimate so that we don't divide by
+    // zero.
+    optional float eps = 3 [default = 1e-3];
+    optional uint32 group = 5 [default = 32]; // The group size
+}
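As a hedged illustration of the new message, the snippet below fills it from Python through the generated protobuf classes; the caffe_pb2 import path is an assumption and not part of this commit:

from dragon.vm.caffe.proto import caffe_pb2

layer = caffe_pb2.LayerParameter(name='conv1/gn')
layer.group_norm_param.group = 32
layer.group_norm_param.eps = 1e-3
layer.group_norm_param.moving_average_fraction = 0.9
# leaving use_global_stats unset maps to use_stats = -1 (auto) in the layers above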
@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules()

setup(name = 'dragon',
-     version='0.2.1.12',
+     version='0.2.1.13',
      description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
      url='https://github.com/neopenx/Dragon',
      author='Ting Pan',
...
#include "operators/ndarray/reshape_op.h" #include "operators/ndarray/reshape_op.h"
#include "core/workspace.h"
namespace dragon { namespace dragon {
string dim_string(const vector<TIndex>& shape) {
std::stringstream ss;
ss << "(";
for (int i = 0; i < shape.size() - 1; i++) ss << shape[i] << ",";
ss << shape[shape.size() - 1] << ")";
return ss.str();
}
template <class Context> template <class Context>
void ReshapeOp<Context>::RunOnDevice() { void ReshapeOp<Context>::RunOnDevice() {
if (shape_desc.size() > 0 || shape_value.size() > 0) {
require_shape.resize(std::max(shape_desc.size(), shape_value.size()));
for (int i = 0; i < require_shape.size(); i++)
require_shape[i] = shape(i);
} else if (shape_like_desc.size() > 0) {
Tensor* shape_like_tensor = ws()->GetTensor(shape_like_desc);
require_shape.resize(shape_like_tensor->ndim());
for (int i = 0; i < require_shape.size(); i++)
require_shape[i] = shape_like_tensor->dim(i);
} else { LOG(FATAL) << "Missing the require shape."; }
vector<TIndex> Xdims = input(0).dims(); vector<TIndex> Xdims = input(0).dims();
new_shape.resize(require_shape.size());
int infer_dim = -1; int infer_dim = -1;
TIndex total_count = 1; TIndex total_count = 1;
for (int i = 0; i < shape.size(); i++) { for (int i = 0; i < require_shape.size(); i++) {
// handle unchanged dim if (require_shape[i] == 0) {
if (shape[i] == 0) { // handle unchanged dim
CHECK_LT(i, (int)Xdims.size()) CHECK_LT(i, (int)Xdims.size())
<< "\nDim(" << i << ") is out of the Xdims range of (0, " << "\nDim(" << i << ") is out of the Xdims range of (0, "
<< Xdims.size() << ")."; << Xdims.size() << ").";
new_shape[i] = Xdims[i]; new_shape[i] = Xdims[i];
} } else if (require_shape[i] > 0) {
// handle reseted dim // handle reseted dim
else if (shape[i] > 0) { new_shape[i] = require_shape[i];
new_shape[i] = shape[i]; } else {
} // handle inferred dim
// handle inferred dim
else {
CHECK_EQ(infer_dim, -1) CHECK_EQ(infer_dim, -1)
<< "\nDim(" << infer_dim << ") required infer before" << "\nDim(" << infer_dim << ") required infer before"
<< "\ncould not infer for dim(" << i << ") both."; << "\ncould not infer for dim(" << i << ") both.";
...@@ -35,7 +54,8 @@ void ReshapeOp<Context>::RunOnDevice() { ...@@ -35,7 +54,8 @@ void ReshapeOp<Context>::RunOnDevice() {
for (int i = 0; i < new_shape.size(); i++) { for (int i = 0; i < new_shape.size(); i++) {
if (new_shape[i] == -1) { if (new_shape[i] == -1) {
CHECK_EQ(input(0).count() % total_count, 0) CHECK_EQ(input(0).count() % total_count, 0)
<< "\nCan not change the total size."; << "\nCan not change the total size: "
<< input(0).dim_string() << " -> " << dim_string(new_shape);
new_shape[i] = input(0).count() / total_count; new_shape[i] = input(0).count() / total_count;
total_count *= new_shape[i]; total_count *= new_shape[i];
break; break;
...@@ -43,7 +63,8 @@ void ReshapeOp<Context>::RunOnDevice() { ...@@ -43,7 +63,8 @@ void ReshapeOp<Context>::RunOnDevice() {
} }
} }
CHECK_EQ(total_count, input(0).count()) CHECK_EQ(total_count, input(0).count())
<< "\nCan not change the total size."; << "\nCan not change the total size."
<< input(0).dim_string() << " -> " << dim_string(new_shape);
output(0)->Reshape(new_shape); output(0)->Reshape(new_shape);
output(0)->Share(input(0)); output(0)->Share(input(0));
} }
......
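To make the dim-resolution rules above easier to follow, here is a plain Python paraphrase (illustration only, not part of the commit): 0 keeps the input dim, a positive value resets it, and at most one -1 is inferred from the remaining size.

def resolve_shape(x_dims, require_shape):
    new_shape, infer_dim, total = [], -1, 1
    for i, d in enumerate(require_shape):
        if d == 0:            # handle unchanged dim
            new_shape.append(x_dims[i])
        elif d > 0:           # handle reseted dim
            new_shape.append(d)
        else:                 # handle inferred dim (at most one)
            assert infer_dim == -1, 'could not infer two dims'
            infer_dim = i
            new_shape.append(-1)
        if new_shape[-1] > 0: total *= new_shape[-1]
    x_count = 1
    for d in x_dims: x_count *= d
    if infer_dim >= 0:
        assert x_count % total == 0, 'can not change the total size'
        new_shape[infer_dim] = x_count // total
        total *= new_shape[infer_dim]
    assert total == x_count, 'can not change the total size'
    return new_shape

# e.g. resolve_shape([8, 3, 32, 32], [0, -1]) -> [8, 3072]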
@@ -19,9 +19,9 @@ void BatchNormOp<Context>::TrainingRunWithType() {
    auto* tVar_data = var->template mutable_data<T, Context>();
    auto* Xdata = input(0).template data<T, Context>();
    auto* Ydata = output(0)->template mutable_data<T, Context>();
-   auto* NSMul_data = multiplier->template data<T, Context>();
-   auto* SMul_data = spatial_multiplier->template data<T, Context>();
    auto* NMul_data = num_multiplier->template data<T, Context>();
+   auto* SMul_data = spatial_multiplier->template data<T, Context>();
+   auto* NSMul_data = multiplier->template data<T, Context>();
    auto* NC_data = num_by_chans.template mutable_data<T, Context>();
    auto* Std_data = stddev->template mutable_data<T, Context>();
    ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata);
@@ -127,9 +127,9 @@ void BatchNormOp<Context>::InferenceRunWithType() {
    auto* tVar_data = var->template mutable_data<T, Context>();
    auto* Xdata = input(0).template data<T, Context>();
    auto* Ydata = output(0)->template mutable_data<T, Context>();
-   auto* NSMul_data = multiplier->template data<T, Context>();
-   auto* SMul_data = spatial_multiplier->template data<T, Context>();
    auto* NMul_data = num_multiplier->template data<T, Context>();
+   auto* SMul_data = spatial_multiplier->template data<T, Context>();
+   auto* NSMul_data = multiplier->template data<T, Context>();
    auto* NC_data = num_by_chans.template mutable_data<T, Context>();
    auto* Std_data = stddev->template mutable_data<T, Context>();
    ctx().template Copy<T, Context, Context>(input(0).count(), Ydata, Xdata);
@@ -169,12 +169,12 @@ void BatchNormOp<Context>::InferenceRunWithType() {
    // divide by stddev
    if (data_format == "NCHW") {
        math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
                               1.0, NMul_data, tVar_data,
                               0.0, NC_data);
        math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
                               1.0, NC_data, SMul_data,
                               0.0, Std_data);
    } else if (data_format == "NHWC") {
        math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1,
                               1.0, NSMul_data, tVar_data,
@@ -248,9 +248,9 @@ void BatchNormGradientOp<Context>::TrainingRunWithType() {
    auto* dXdata = output(0)->template mutable_data<T, Context>();
    auto* Std_data = stddev->template mutable_data<T, Context>();
    auto* tVar_data = var->template mutable_data<T, Context>();
-   auto* NSMul_data = multiplier->template data<T, Context>();
-   auto* SMul_data = spatial_multiplier->template data<T, Context>();
    auto* NMul_data = num_multiplier->template data<T, Context>();
+   auto* SMul_data = spatial_multiplier->template data<T, Context>();
+   auto* NSMul_data = multiplier->template data<T, Context>();
    auto* NC_data = num_by_chans.template mutable_data<T, Context>();

    if (data_format == "NCHW") {
@@ -337,9 +337,9 @@ void BatchNormGradientOp<Context>::InferenceRunWithType() {
    auto* dXdata = output(0)->template mutable_data<T, Context>();
    auto* Std_data = stddev->template mutable_data<T, Context>();
    auto* tVar_data = var->template mutable_data<T, Context>();
-   auto* NSMul_data = multiplier->template data<T, Context>();
-   auto* SMul_data = spatial_multiplier->template data<T, Context>();
    auto* NMul_data = num_multiplier->template data<T, Context>();
+   auto* SMul_data = spatial_multiplier->template data<T, Context>();
+   auto* NSMul_data = multiplier->template data<T, Context>();
    auto* NC_data = num_by_chans.template mutable_data<T, Context>();

    if (data_format == "NCHW") {
...
@@ -23,9 +23,9 @@ void FusedBatchNormOp<Context>::TrainingRunWithType() {
    auto* tVar_data = var->template mutable_data<T, Context>();
    auto* Xdata = input(0).template data<T, Context>();
    auto* Ydata = output(0)->template mutable_data<T, Context>();
-   auto* NSMul_data = multiplier->template data<T, Context>();
-   auto* SMul_data = spatial_multiplier->template data<T, Context>();
    auto* NMul_data = num_multiplier->template data<T, Context>();
+   auto* SMul_data = spatial_multiplier->template data<T, Context>();
+   auto* NSMul_data = multiplier->template data<T, Context>();
    auto* NC_data = num_by_chans.template mutable_data<T, Context>();
    auto* Std_data = stddev->template mutable_data<T, Context>();
    ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata);
@@ -153,9 +153,9 @@ void FusedBatchNormOp<Context>::InferenceRunWithType() {
    auto* tVar_data = var->template mutable_data<T, Context>();
    auto* Xdata = input(0).template data<T, Context>();
    auto* Ydata = output(0)->template mutable_data<T, Context>();
-   auto* NSMul_data = multiplier->template data<T, Context>();
-   auto* SMul_data = spatial_multiplier->template data<T, Context>();
    auto* NMul_data = num_multiplier->template data<T, Context>();
+   auto* SMul_data = spatial_multiplier->template data<T, Context>();
+   auto* NSMul_data = multiplier->template data<T, Context>();
    auto* NC_data = num_by_chans.template mutable_data<T, Context>();
    auto* Std_data = stddev->template mutable_data<T, Context>();
    ctx().template Copy<T, Context, Context>(input(0).count(), Ydata, Xdata);
@@ -296,9 +296,9 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() {
    auto* Std_data = stddev->template mutable_data<T, Context>();
    auto* tMean_data = mean->template mutable_data<T, Context>();
    auto* tVar_data = var->template mutable_data<T, Context>();
-   auto* NSMul_data = multiplier->template data<T, Context>();
-   auto* SMul_data = spatial_multiplier->template data<T, Context>();
    auto* NMul_data = num_multiplier->template data<T, Context>();
+   auto* SMul_data = spatial_multiplier->template data<T, Context>();
+   auto* NSMul_data = multiplier->template data<T, Context>();
    auto* NC_data = num_by_chans.template mutable_data<T, Context>();
    auto* XNorm_data = x_norm->template data<T, Context>();
@@ -436,9 +436,9 @@ void FusedBatchNormGradientOp<Context>::InferenceRunWithType() {
    auto* dYdata = input(-1).template data<T, Context>();
    auto* Sdata = input(3).template data<T, Context>();
    auto* tVar_data = var->template mutable_data<T, Context>();
-   auto* NSMul_data = multiplier->template data<T, Context>();
-   auto* SMul_data = spatial_multiplier->template data<T, Context>();
    auto* NMul_data = num_multiplier->template data<T, Context>();
+   auto* SMul_data = spatial_multiplier->template data<T, Context>();
+   auto* NSMul_data = multiplier->template data<T, Context>();
    auto* NC_data = num_by_chans.template mutable_data<T, Context>();

    // gradient w.r.t. scale
...