Commit 6683676d by Ting PAN

GroupNormalization Support

1 parent 18b664b1
......@@ -16,15 +16,16 @@ class ReshapeOp final : public Operator<Context> {
public:
ReshapeOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
shape(OperatorBase::GetRepeatedArg<int>("shape")) {
new_shape.resize(shape.size());
shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) {
GET_ARGUMENTS_WITH_DESC(int, shape);
}
void RunOnDevice() override;
protected:
vector<int> shape;
vector<TIndex> new_shape;
DECLARE_ARGUMENTS_WITH_DESC(int, shape);
string shape_like_desc;
vector<TIndex> require_shape, new_shape;
};
template <class Context>
......@@ -38,6 +39,8 @@ class ReshapeGradientOp final : public Operator<Context> {
void RunOnDevice() override;
};
DEFINE_ARGUMENTS_WITH_DESC(int, ReshapeOp, shape);
} // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
\ No newline at end of file
......@@ -105,7 +105,7 @@ class FusedBatchNormGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { }
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void Setup();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class GroupNormOp : public Operator<Context> {
public:
GroupNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float momentum, eps;
Tensor mean, num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format, mode;
int use_stats;
bool use_global_stats, is_recomputing;
};
template <class Context>
class GroupNormGradientOp final : public Operator<Context> {
public:
GroupNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats;
};
template <class Context>
class FusedGroupNormOp : public Operator<Context> {
public:
FusedGroupNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float momentum, eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats, is_recomputing;
};
template <class Context>
class FusedGroupNormGradientOp : public Operator<Context> {
public:
FusedGroupNormGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void Setup();
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
\ No newline at end of file
......@@ -16,6 +16,14 @@
\sigma_{B}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{B})^{2} \\
\hat{x}_{i} = \frac{x_{i} - \mu_{B}}{\sqrt{\sigma_{B}^{2} + \epsilon}} \\ y_{i} = \gamma\hat{x}_{i} + \beta \\ \,
.. |groupnorm_function| mathmacro:: \\ \, \\ \mu_{G} = \frac{1}{m} \sum_{i=1}^{m}x_{i} \\
\sigma_{G}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{G})^{2} \\
\hat{x}_{i} = \frac{x_{i} - \mu_{G}}{\sqrt{\sigma_{G}^{2} + \epsilon}} \\ \,
.. |groupnorm_scale_function| mathmacro:: \\ \, \\ \mu_{G} = \frac{1}{m} \sum_{i=1}^{m}x_{i} \\
\sigma_{G}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{G})^{2} \\
\hat{x}_{i} = \frac{x_{i} - \mu_{G}}{\sqrt{\sigma_{G}^{2} + \epsilon}} \\ y_{i} = \gamma\hat{x}_{i} + \beta \\ \,
.. |batchrenorm_function| mathmacro:: \\ \, \\ \mu_{B} = \frac{1}{m} \sum_{i=1}^{m}x_{i} \\
\sigma_{B}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{B})^{2} \\
\hat{x}_{i} = \frac{x_{i} - \mu_{B}}{\sqrt{\sigma_{B}^{2} + \epsilon}} \cdot r + d \\ \,
......
......@@ -113,6 +113,8 @@ List Brief
`BatchNorm`_ Batch Normalization, introduced by `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`BatchRenorm`_ Batch Renormalization, introduced by `[Ioffe, 2017] <https://arxiv.org/abs/1702.03275>`_.
`FusedBatchNorm`_ Batch Normalization, with scale procedure after normalization.
`GroupNorm`_ Group Normalization, introduced by `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
`FusedGroupNorm`_ Group Normalization, with scale procedure after normalization.
`InstanceNorm`_ Instance Normalization, introduced by `[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_.
`L2Norm`_ L2 Normalization, introduced by `[Liu et.al, 2015] <https://arxiv.org/abs/1506.04579>`_.
================== ======================================================================
......@@ -253,6 +255,8 @@ List Brief
.. _BatchNorm: operators/norm.html#dragon.operators.norm.BatchNorm
.. _BatchRenorm: operators/norm.html#dragon.operators.norm.BatchRenorm
.. _FusedBatchNorm: operators/norm.html#dragon.operators.norm.FusedBatchNorm
.. _GroupNorm: operators/norm.html#dragon.operators.norm.GroupNorm
.. _FusedGroupNorm: operators/norm.html#dragon.operators.norm.FusedGroupNorm
.. _InstanceNorm: operators/norm.html#dragon.operators.norm.InstanceNorm
.. _L2Norm: operators/norm.html#dragon.operators.norm.L2Norm
......
......@@ -73,9 +73,11 @@ List Brief
`ArgMaxLayer`_ The implementation of ``ArgMaxLayer``.
`BatchNormLayer`_ The implementation of ``BatchNormLayer``.
`BatchRenormLayer`_ The implementation of ``BatchRenormLayer``.
`GroupNormLayer`_ The implementation of ``GroupNormLayer``.
`InstanceNormLayer`_ The implementation of ``InstanceNormLayer``.
`ScaleLayer`_ The implementation of ``ScaleLayer``.
`BNLayer`_ The implementation of ``BNLayer``.
`GNLayer`_ The implementation of ``GNLayer``.
`NormalizeLayer`_ The implementation of ``NormalizeLayer``.
`TileLayer`_ The extended implementation of ``TileLayer``.
`ExpandDimsLayer`_ The implementation of ``ExpandDimsLayer``.
......@@ -181,9 +183,11 @@ API Reference
.. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer
.. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer
.. _BatchRenormLayer: #dragon.vm.caffe.layers.common.BatchRenormLayer
.. _GroupNormLayer: #dragon.vm.caffe.layers.common.GroupNormLayer
.. _InstanceNormLayer: #dragon.vm.caffe.layers.common.InstanceNormLayer
.. _ScaleLayer: #dragon.vm.caffe.layers.common.ScaleLayer
.. _BNLayer: #dragon.vm.caffe.layers.common.BNLayer
.. _GNLayer: #dragon.vm.caffe.layers.common.GNLayer
.. _NormalizeLayer: #dragon.vm.caffe.layers.common.NormalizeLayer
.. _TileLayer: #dragon.vm.caffe.layers.common.TileLayer
.. _ExpandDimsLayer: #dragon.vm.caffe.layers.common.ExpandDimsLayer
......
......@@ -648,15 +648,21 @@ def Flatten(inputs, axis=0, num_axes=-1, keep_axes=None, **kwargs):
return output
def Reshape(inputs, shape, **kwargs):
def Reshape(inputs, shape, shape_like=None, **kwargs):
"""Reshape the dimensions of input.
``shape`` can be a list of integers or Tensors.
Set ``shape`` to ``None`` if you want to use ``shape_like``.
Parameters
----------
inputs : Tensor
The input tensor.
shape : list or tuple
shape : list, tuple or None
The new shape.
shape_like : Tensor or None
The tensor whose dimensions indicate the output shape.
Returns
-------
......@@ -677,17 +683,29 @@ def Reshape(inputs, shape, **kwargs):
CheckInputs(inputs, 1)
arguments = ParseArguments(locals())
if not isinstance(shape, tuple) and not isinstance(shape, list):
raise TypeError('The type of dims must be a tuple or list.')
if shape is not None:
AddArgumentsWithDesc(arguments, shape, 'shape', 'int32', as_target=True)
elif shape_like is not None:
if not isinstance(shape_like, Tensor):
raise TypeError('The shape_like should be a Tensor.')
arguments['shape_like'] = shape_like.name
output = Tensor.CreateOperator(nout=1, op_type='Reshape', **arguments)
if inputs.shape is not None:
output.shape = [1] * len(shape)
for i, s in enumerate(shape):
if s == -1: output.shape[i] = 1
elif s == 0: output.shape[i] = inputs.shape[i]
else: output.shape[i] = s
possible_to_infer_shape = shape is not None
if shape is not None:
for dim in shape:
if isinstance(dim, Tensor):
possible_to_infer_shape = False
if shape_like is not None:
possible_to_infer_shape = False
if possible_to_infer_shape:
output.shape = [1] * len(shape)
for i, s in enumerate(shape):
if s == -1: output.shape[i] = 1
elif s == 0: output.shape[i] = inputs.shape[i]
else: output.shape[i] = s
return output
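
A brief usage sketch of the two paths may help here; it assumes the usual ``import dragon.ops as ops`` entry point and the ``Tensor(name, shape=...)`` constructor, and the tensor names and shapes below are hypothetical.

import dragon.ops as ops
from dragon.core.tensor import Tensor

# Hypothetical graph inputs.
x = Tensor('x', shape=[8, 16, 32, 32])
ref = Tensor('ref', shape=[8, 16384])

# 0 keeps the corresponding input dim, -1 is inferred from the element count.
y = ops.Reshape(x, shape=[0, -1])               # -> (8, 16 * 32 * 32)

# Take the target shape from another tensor at graph-execution time.
z = ops.Reshape(x, shape=None, shape_like=ref)
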
......
......@@ -165,6 +165,103 @@ def FusedBatchNorm(inputs, axis=-1, momentum=0.9, eps=1e-3, use_stats=-1, **kwar
return output
def GroupNorm(inputs, group=32, axis=-1, momentum=0.9, eps=1e-3,
use_stats=-1, mode='DEFAULT', **kwargs):
"""Group Normalization, introduced by `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
It follows the implementation of `Caffe`_, in which the scale procedure is moved to `ops.Scale(*args, **kwargs)`_.
The number of inputs varies from ``3`` to ``4`` (``DEFAULT`` or ``CAFFE`` mode).
Parameters
----------
inputs : list of Tensor
The inputs, representing [input, mean, var] or [input, mean, var, factor].
group : int
The number of groups to split the channels into.
axis : int
The channel axis.
momentum : float
The momentum of moving average.
eps : float
The epsilon added to the variance.
use_stats : int
Whether to use global stats. Default is ``-1`` (Auto).
mode : str
The moving average mode. ``DEFAULT`` or ``CAFFE``.
Returns
-------
Tensor
The output tensor, calculated as:
|groupnorm_function|
The ``DEFAULT`` moving average of mean/var, calculated as:
|default_moving_average_function|
The ``CAFFE`` moving average of mean/var, calculated as:
|caffe_moving_average_function|
"""
CheckInputs(inputs, 3, 4)
arguments = ParseArguments(locals())
if len(inputs) > 3:
if mode != 'CAFFE':
raise ValueError('Only the CAFFE mode will take 4 inputs.')
output = Tensor.CreateOperator(nout=1, op_type='GroupNorm', **arguments)
if inputs[0].shape is not None:
output.shape = inputs[0].shape[:]
return output
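
As a reference for the |groupnorm_function| math, here is a rough NumPy sketch of the normalization itself (NCHW layout, statistics per sample and group over the grouped channels and all spatial positions); it ignores the moving-average bookkeeping and is only a sketch of the formula, not the operator's implementation.

import numpy as np

def group_norm_reference(x, group=32, eps=1e-3):
    """NCHW reference: mean/var are taken over each group of C // group
    channels together with all spatial positions."""
    N, C, H, W = x.shape
    assert C % group == 0
    xg = x.reshape(N, group, (C // group) * H * W)
    mu = xg.mean(axis=2, keepdims=True)
    var = xg.var(axis=2, keepdims=True)
    x_hat = (xg - mu) / np.sqrt(var + eps)
    return x_hat.reshape(N, C, H, W)

# e.g. group_norm_reference(np.random.randn(2, 64, 8, 8), group=32)
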
def FusedGroupNorm(inputs, group=32, axis=-1, momentum=0.9, eps=1e-3, use_stats=-1, **kwargs):
"""Group Normalization, with scale procedure after normalization.
Parameters
----------
inputs : list of Tensor
The inputs, representing [input, mean, var, scale, bias].
group : int
The number of groups to split the channels into.
axis : int
The channel axis.
momentum : float
The momentum of moving average.
eps : float
The epsilon added to the variance.
use_stats : int
Whether to use global stats. Default is ``-1`` (Auto).
Returns
-------
Tensor
The output tensor, calculated as:
|groupnorm_scale_function|
The moving average of mean/var, calculated as:
|default_moving_average_function|
"""
CheckInputs(inputs, 5)
arguments = ParseArguments(locals())
output = Tensor.CreateOperator(nout=1, op_type='FusedGroupNorm', **arguments)
if inputs[0].shape is not None:
output.shape = inputs[0].shape[:]
return output
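
The fused variant only adds the affine step from |groupnorm_scale_function|: a per-channel scale and bias applied to the normalized values. A minimal sketch, reusing ``group_norm_reference`` from the previous example (``gamma`` and ``beta`` are hypothetical per-channel parameters):

import numpy as np

def fused_group_norm_reference(x, gamma, beta, group=32, eps=1e-3):
    """x: (N, C, H, W); gamma, beta: (C,) applied per channel after normalization."""
    x_hat = group_norm_reference(x, group=group, eps=eps)
    return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)
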
def InstanceNorm(inputs, axis=-1, eps=1e-3, **kwargs):
"""Instance Normalization, introduced by `[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_
......
......@@ -92,7 +92,9 @@ GramMatrix = math.GramMatrix
# normalization
BatchNorm = norm.BatchNorm
BatchRenorm = norm.BatchRenorm
GroupNorm = norm.GroupNorm
FusedBatchNorm = norm.FusedBatchNorm
FusedGroupNorm = norm.FusedGroupNorm
InstanceNorm = norm.InstanceNorm
L2Norm = norm.L2Norm
......
......@@ -39,6 +39,8 @@ from .common import InnerProductLayer, \
BatchNormLayer, \
BatchRenormLayer,\
BNLayer, \
GroupNormLayer, \
GNLayer, \
ConcatLayer, \
CropLayer, \
PythonLayer, \
......
......@@ -412,6 +412,47 @@ class BatchRenormLayer(Layer):
return ops.BatchRenorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
class GroupNormLayer(Layer):
"""The implementation of ``GroupNormLayer``.
Parameters
----------
group : int
Refer ``GroupNormParameter.group``.
use_global_stats : boolean
Refer ``GroupNormParameter.use_global_stats``.
moving_average_fraction : float
Refer ``GroupNormParameter.moving_average_fraction``.
eps : float
Refer ``GroupNormParameter.eps``.
"""
def __init__(self, LayerParameter):
super(GroupNormLayer, self).__init__(LayerParameter)
param = LayerParameter.group_norm_param
self._param = {'group': int(param.group),
'use_stats': int(param.use_global_stats)
if param.HasField('use_global_stats') else -1,
'momentum': param.moving_average_fraction,
'eps': param.eps,
'axis': 1,
'mode': 'CAFFE'}
scope = LayerParameter.name
# mean, var and factor are initialized to 0 so the statistics are accumulated from scratch
mean = Tensor(scope + '/param:0').Constant(value=0.0)
var = Tensor(scope + '/param:1').Constant(value=0.0)
factor = Tensor(scope + '/param:2').Constant(value=0.0)
# in dragon, setting diff to None skips the gradient computation automatically,
# whereas in BVLC Caffe you must set lr_mult = 0 manually
self._blobs.append({'data': mean, 'diff': None})
self._blobs.append({'data': var, 'diff': None})
self._blobs.append({'data': factor, 'diff': None})
def Setup(self, bottom):
super(GroupNormLayer, self).Setup(bottom)
return ops.GroupNorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
class InstanceNormLayer(Layer):
"""
The implementation of ``InstanceNormLayer``.
......@@ -518,6 +559,59 @@ class BNLayer(Layer):
return ops.FusedBatchNorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
class GNLayer(Layer):
"""The implementation of ``GNLayer``.
Parameters
----------
group : int
Refer ``GroupNormParameter.group``.
use_global_stats : boolean
Refer ``GroupNormParameter.use_global_stats``.
moving_average_fraction : float
Refer ``GroupNormParameter.moving_average_fraction``.
eps : float
Refer ``GroupNormParameter.eps``.
filler : FillerParameter
The filler of scale parameter. Refer `ScaleParameter.filler`_.
bias_filler : FillerParameter
The filler of bias parameter. Refer `ScaleParameter.bias_filler`_.
"""
def __init__(self, LayerParameter):
super(GNLayer, self).__init__(LayerParameter)
gn_param = LayerParameter.group_norm_param
scale_param = LayerParameter.scale_param
self._param = {'group': int(gn_param.group),
'use_stats': int(gn_param.use_global_stats)
if gn_param.HasField('use_global_stats') else -1,
'momentum': gn_param.moving_average_fraction,
'eps': gn_param.eps,
'axis': 1}
scope = LayerParameter.name
mean = Tensor(scope + '/param:0').Constant(value=0.0)
var = Tensor(scope + '/param:1').Constant(value=0.0)
scale = Tensor(scope + '/param:2')
scale_diff = Tensor(scope + '/param:2_grad')
bias = Tensor(scope + '/param:3')
bias_diff = Tensor(scope + '/param:3_grad')
if scale_param.HasField('filler'):
self.Fill(scale, scale_param, 'filler')
else: scale.Constant(value=1.0)
self.Fill(bias, scale_param, 'bias_filler')
self.norm_blobs = [{'data': mean, 'diff': None},
{'data': var, 'diff': None}]
self.scale_blobs = [{'data': scale, 'diff': scale_diff},
{'data': bias, 'diff': bias_diff}]
self._blobs.extend(self.norm_blobs)
self._blobs.extend(self.scale_blobs)
def Setup(self, bottom):
super(GNLayer, self).Setup(bottom)
return ops.FusedGroupNorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
class NormalizeLayer(Layer):
"""The implementation of ``NormalizeLayer``.
......
......@@ -423,6 +423,7 @@ message LayerParameter {
optional DenseConcatParameter dense_concat_param = 163;
optional FocalLossParameter focal_loss_param = 164;
optional GatherParameter gather_param = 165;
optional GroupNormParameter group_norm_param = 166;
}
// Message that stores parameters used to apply transformation
......@@ -1512,3 +1513,17 @@ message GatherParameter {
optional int32 axis = 1 [default = 0];
}
message GroupNormParameter {
// If false, accumulate global mean/variance values via a moving average. If
// true, use those accumulated values instead of computing mean/variance
// across the batch.
optional bool use_global_stats = 1;
// How much does the moving average decay each iteration?
optional float moving_average_fraction = 2 [default = 0.9];
// Small value to add to the variance estimate so that we don't divide by
// zero.
optional float eps = 3 [default = 1e-3];
optional uint32 group = 5 [default = 32]; // The number of groups
}
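
A hedged sketch of how the new message could be filled from a prototxt snippet; the generated-bindings path ``dragon.vm.caffe.proto.caffe_pb2`` and the layer ``type`` string are assumptions here, not taken from this change.

from google.protobuf import text_format
from dragon.vm.caffe.proto import caffe_pb2   # assumed path of the generated bindings

layer = caffe_pb2.LayerParameter()
text_format.Merge("""
name: "gn1"
type: "GroupNorm"            # hypothetical type string mapping to GroupNormLayer
bottom: "conv1"
top: "gn1"
group_norm_param {
  group: 32                  # number of groups
  eps: 0.001
  moving_average_fraction: 0.9
}
""", layer)
print(layer.group_norm_param.group)   # -> 32
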
......@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules()
setup(name = 'dragon',
version='0.2.1.12',
version='0.2.1.13',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/neopenx/Dragon',
author='Ting Pan',
......
#include "operators/ndarray/reshape_op.h"
#include "core/workspace.h"
namespace dragon {
string dim_string(const vector<TIndex>& shape) {
std::stringstream ss;
ss << "(";
for (int i = 0; i < shape.size() - 1; i++) ss << shape[i] << ",";
ss << shape[shape.size() - 1] << ")";
return ss.str();
}
template <class Context>
void ReshapeOp<Context>::RunOnDevice() {
if (shape_desc.size() > 0 || shape_value.size() > 0) {
require_shape.resize(std::max(shape_desc.size(), shape_value.size()));
for (int i = 0; i < require_shape.size(); i++)
require_shape[i] = shape(i);
} else if (shape_like_desc.size() > 0) {
Tensor* shape_like_tensor = ws()->GetTensor(shape_like_desc);
require_shape.resize(shape_like_tensor->ndim());
for (int i = 0; i < require_shape.size(); i++)
require_shape[i] = shape_like_tensor->dim(i);
} else { LOG(FATAL) << "Missing the required shape."; }
vector<TIndex> Xdims = input(0).dims();
new_shape.resize(require_shape.size());
int infer_dim = -1;
TIndex total_count = 1;
for (int i = 0; i < shape.size(); i++) {
// handle unchanged dim
if (shape[i] == 0) {
for (int i = 0; i < require_shape.size(); i++) {
if (require_shape[i] == 0) {
// handle unchanged dim
CHECK_LT(i, (int)Xdims.size())
<< "\nDim(" << i << ") is out of the Xdims range of (0, "
<< Xdims.size() << ").";
new_shape[i] = Xdims[i];
}
// handle reset dim
else if (shape[i] > 0) {
new_shape[i] = shape[i];
}
// handle inferred dim
else {
} else if (require_shape[i] > 0) {
// handle reset dim
new_shape[i] = require_shape[i];
} else {
// handle inferred dim
CHECK_EQ(infer_dim, -1)
<< "\nDim(" << infer_dim << ") required infer before"
<< "\ncould not infer for dim(" << i << ") both.";
......@@ -35,7 +54,8 @@ void ReshapeOp<Context>::RunOnDevice() {
for (int i = 0; i < new_shape.size(); i++) {
if (new_shape[i] == -1) {
CHECK_EQ(input(0).count() % total_count, 0)
<< "\nCan not change the total size.";
<< "\nCan not change the total size: "
<< input(0).dim_string() << " -> " << dim_string(new_shape);
new_shape[i] = input(0).count() / total_count;
total_count *= new_shape[i];
break;
......@@ -43,7 +63,8 @@ void ReshapeOp<Context>::RunOnDevice() {
}
}
CHECK_EQ(total_count, input(0).count())
<< "\nCan not change the total size.";
<< "\nCan not change the total size."
<< input(0).dim_string() << " -> " << dim_string(new_shape);
output(0)->Reshape(new_shape);
output(0)->Share(input(0));
}
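
For readers skimming the kernel above, a pure-Python mirror of its inference rule (a sketch, not the operator itself): ``0`` copies the input dim, positive entries are taken verbatim, and a single ``-1`` is inferred so that the total element count is preserved.

def infer_reshape(x_dims, require_shape):
    """Mirror of ReshapeOp: 0 = keep, > 0 = reset, -1 = infer (at most once)."""
    new_shape, infer_dim = [], -1
    for i, s in enumerate(require_shape):
        if s == 0:
            new_shape.append(x_dims[i])    # unchanged dim
        elif s > 0:
            new_shape.append(s)            # reset dim
        else:
            assert infer_dim == -1, 'at most one dim may be inferred'
            infer_dim = i
            new_shape.append(1)            # placeholder, filled below
    count = 1
    for d in x_dims:
        count *= d
    total = 1
    for d in new_shape:
        total *= d
    if infer_dim >= 0:
        assert count % total == 0, 'can not change the total size'
        new_shape[infer_dim] = count // total
        total *= new_shape[infer_dim]
    assert total == count, 'can not change the total size'
    return new_shape

# infer_reshape([8, 16, 32, 32], [0, -1]) -> [8, 16384]
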
......
......@@ -19,9 +19,9 @@ void BatchNormOp<Context>::TrainingRunWithType() {
auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata);
......@@ -127,9 +127,9 @@ void BatchNormOp<Context>::InferenceRunWithType() {
auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(input(0).count(), Ydata, Xdata);
......@@ -169,12 +169,12 @@ void BatchNormOp<Context>::InferenceRunWithType() {
// divide by stddev
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
1.0, NMul_data, tVar_data,
0.0, NC_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
1.0, NC_data, SMul_data,
0.0, Std_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
1.0, NMul_data, tVar_data,
0.0, NC_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
1.0, NC_data, SMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1,
1.0, NSMul_data, tVar_data,
......@@ -248,9 +248,9 @@ void BatchNormGradientOp<Context>::TrainingRunWithType() {
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
if (data_format == "NCHW") {
......@@ -337,9 +337,9 @@ void BatchNormGradientOp<Context>::InferenceRunWithType() {
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
if (data_format == "NCHW") {
......
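
The reordering of the multiplier pointers above is cosmetic, but the pair of ``Gemm`` calls they feed is worth spelling out: for NCHW they broadcast a per-channel vector (here ``tVar``) first across the batch and then across the spatial positions, using vectors of ones as the other factor. A rough NumPy equivalent of that broadcast, with illustrative names only:

import numpy as np

def tile_per_channel(vec_c, N, S):
    """Expand a length-C vector to an (N * C, S) buffer the way the two
    chained Gemm calls do: ones(N, 1) @ vec(1, C) gives (N, C), then
    (N * C, 1) @ ones(1, S) gives (N * C, S)."""
    C = vec_c.shape[0]
    nc = np.ones((N, 1)) @ vec_c.reshape(1, C)        # first Gemm -> NC_data
    return nc.reshape(N * C, 1) @ np.ones((1, S))     # second Gemm -> Std_data

# e.g. dividing an (N, C, S) activation by the per-channel stddev:
# Y /= tile_per_channel(std_c, N, S).reshape(N, C, S)
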
......@@ -23,9 +23,9 @@ void FusedBatchNormOp<Context>::TrainingRunWithType() {
auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata);
......@@ -153,9 +153,9 @@ void FusedBatchNormOp<Context>::InferenceRunWithType() {
auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(input(0).count(), Ydata, Xdata);
......@@ -296,9 +296,9 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() {
auto* Std_data = stddev->template mutable_data<T, Context>();
auto* tMean_data = mean->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* XNorm_data = x_norm->template data<T, Context>();
......@@ -436,9 +436,9 @@ void FusedBatchNormGradientOp<Context>::InferenceRunWithType() {
auto* dYdata = input(-1).template data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
// gradient w.r.t. scale
......